"""Utilities for models.""" |
|
|
|
|
|
import torch |
|
|
from megatron.model.norms import LayerNorm, RMSNorm, ScaleNorm |
|
|
from megatron.model.fused_softmax import SoftmaxFusionTypes |
|
|
from types import GeneratorType |
|
|
|
|
|
|
|
|
def get_params_for_weight_decay_optimization(module, neox_args): |
|
|
"""Divide params into with-weight-decay and without-weight-decay groups. |
|
|
Layernorms and biases will have no weight decay but the rest will. |
|
|
""" |
|
|
weight_decay_params = {"params": []} |
|
|
no_weight_decay_params = {"params": [], "weight_decay": 0.0} |
|
|
for module_ in module.modules(): |
|
|
if any( |
|
|
[ |
|
|
isinstance(module_, LayerNorm), |
|
|
isinstance(module_, RMSNorm), |
|
|
isinstance(module_, ScaleNorm), |
|
|
] |
|
|
) or ( |
|
|
neox_args.weight_decay == 0.0 |
|
|
): |
|
|
no_weight_decay_params["params"].extend( |
|
|
[p for p in list(module_._parameters.values()) if p is not None] |
|
|
) |
|
|
else: |
|
|
weight_decay_params["params"].extend( |
|
|
[ |
|
|
p |
|
|
for n, p in list(module_._parameters.items()) |
|
|
if p is not None and n != "bias" |
|
|
] |
|
|
) |
|
|
no_weight_decay_params["params"].extend( |
|
|
[ |
|
|
p |
|
|
for n, p in list(module_._parameters.items()) |
|
|
if p is not None and n == "bias" |
|
|
] |
|
|
) |
|
|
    if neox_args.weight_decay == 0.0:
        # when weight decay is disabled, every parameter was already collected
        # into the no-weight-decay group above, so return that single group
        return [no_weight_decay_params]
    return weight_decay_params, no_weight_decay_params
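

# Example usage (illustrative sketch; `model`, `neox_args.lr`, and the choice of
# AdamW are assumptions about the caller, not requirements of this function):
#
#     param_groups = get_params_for_weight_decay_optimization(model, neox_args)
#     optimizer = torch.optim.AdamW(
#         param_groups, lr=neox_args.lr, weight_decay=neox_args.weight_decay
#     )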


def exists(x):
    return x is not None


class Lambda(torch.nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)
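

# Example usage (illustrative sketch): wrap an arbitrary callable as a module so it
# can sit inside an nn.Sequential, e.g.
#
#     double = Lambda(lambda t: t * 2)
#     y = double(torch.ones(3))  # tensor([2., 2., 2.])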


class SequentialWrapper(torch.nn.Module):
    """
    Converts a deepspeed PipelineModule into an nn.Sequential-like model while
    retaining activation checkpointing.
    """

    def __init__(
        self,
        layers,
        activation_checkpoint_interval,
        activation_checkpoint_func,
        parent_class_name=None,
    ):
        super().__init__()
        self.sequential = torch.nn.Sequential(*layers)
        self.activation_checkpoint_interval = activation_checkpoint_interval
        self.parent_class_name = parent_class_name
        self.activation_checkpoint_func = activation_checkpoint_func

    def _is_checkpointable(self, funcs):
        # only checkpoint layer groups that actually hold parameters; for the GPT2
        # pipeline model, that means groups made up solely of transformer layers
        if self.parent_class_name == "GPT2ModelPipe":
            return all(
                "ParallelTransformerLayerPipe" in f.__class__.__name__ for f in funcs
            )
        params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
        return any(len(list(p)) > 0 for p in params)

    def inference_mode(self, use_cache=True):
        """
        Sets up the model for inference by turning on k/v caching (if specified)
        and setting `parallel output` of the final layer to False, so logits are
        gathered across model parallel ranks.

        :param use_cache: (bool) True to use k/v caching during inference, False otherwise
        """
        _set_use_cache(self.sequential, use_cache)

    def train_mode(self):
        """
        Sets up the model for training by turning off k/v caching.
        """
        _set_use_cache(self.sequential, False)

    def forward(self, forward_input):
        def exec_range_func(start, end):
            """Helper function to be used with checkpoint()
            Adapted from torch.utils.checkpoint:checkpoint_sequential()
            """

            def exec_func(*inputs):
                if len(inputs) == 1:
                    inputs = inputs[0]
                for layer in self.sequential[start:end]:
                    inputs = layer(inputs)
                return inputs

            return exec_func

        if self.activation_checkpoint_interval == 0:
            # no activation checkpointing: run the whole stack in one pass
            func = exec_range_func(0, len(self.sequential))
            x = func(forward_input)
        else:
            # run the layers in chunks of `activation_checkpoint_interval`,
            # checkpointing each chunk that contains parameters
            num_layers = len(self.sequential)
            x = forward_input
            for start_idx in range(0, num_layers, self.activation_checkpoint_interval):
                end_idx = min(
                    start_idx + self.activation_checkpoint_interval, num_layers
                )

                funcs = self.sequential[start_idx:end_idx]

                if not isinstance(x, tuple):
                    x = (x,)

                if self._is_checkpointable(funcs):
                    x = self.activation_checkpoint_func(
                        exec_range_func(start_idx, end_idx), *x
                    )
                else:
                    x = exec_range_func(start_idx, end_idx)(*x)
        return x
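

# Example construction (illustrative sketch; `layer_specs` and the use of
# deepspeed.checkpointing.checkpoint are assumptions about the caller, not
# requirements of this class):
#
#     model = SequentialWrapper(
#         layers=layer_specs,  # an iterable of torch.nn.Module instances
#         activation_checkpoint_interval=2,
#         activation_checkpoint_func=deepspeed.checkpointing.checkpoint,
#         parent_class_name="GPT2ModelPipe",
#     )
#     logits = model(tokens)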


def recursive_setattr(m, attr, value, assert_type=None, type_filter=None):
    """
    Recursively set attributes on a pytorch module or an iterable of modules.
    If `assert_type` is provided, assert that `value` is an instance of that type.
    If `type_filter` is provided, only set the attribute on modules matching that type.
    """
    if assert_type is not None:
        assert isinstance(value, assert_type), "Value is not the correct type."

    if isinstance(m, (list, GeneratorType)):
        for i in m:
            recursive_setattr(i, attr, value, assert_type, type_filter)
    elif isinstance(m, torch.nn.Module):
        if hasattr(m, attr):
            if type_filter is None or isinstance(m, type_filter):
                setattr(m, attr, value)
        if hasattr(m, "children"):
            recursive_setattr(m.children(), attr, value, assert_type, type_filter)
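

# Example usage (illustrative sketch; `model` is a placeholder for the caller's module):
#
#     # enable key/value caching on every submodule that exposes a `use_cache` flag
#     recursive_setattr(model, "use_cache", True, assert_type=bool)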


def _set_use_cache(modules, value: bool):
    """
    Recursively sets `use_cache` to `value` on a list of pytorch modules, if they have a `use_cache` attribute.
    `use_cache` determines whether past key/value activations are cached during inference.
    """
    recursive_setattr(modules, "use_cache", value, assert_type=bool)


def configure_sparse_attention(neox_args, attention_type, num_attention_heads, mpu):
    from deepspeed.ops.sparse_attention import (
        SparseSelfAttention,
        VariableSparsityConfig,
        FixedSparsityConfig,
        BigBirdSparsityConfig,
        BSLongformerSparsityConfig,
    )
    from deepspeed.ops.sparse_attention.sparsity_config import (
        LocalSlidingWindowSparsityConfig,
    )

    if attention_type == "sparse_fixed":
        sparsity_config = FixedSparsityConfig(
            num_heads=num_attention_heads,
            block=neox_args.sparsity_config.get("block", 16),
            different_layout_per_head=neox_args.sparsity_config.get(
                "different_layout_per_head", False
            ),
            num_local_blocks=neox_args.sparsity_config.get("num_local_blocks", 4),
            num_global_blocks=neox_args.sparsity_config.get("num_global_blocks", 1),
            num_different_global_patterns=neox_args.sparsity_config.get(
                "num_different_global_patterns", 1
            ),
            attention="unidirectional",
            horizontal_global_attention=False,
        )
    elif attention_type == "sparse_variable":
        sparsity_config = VariableSparsityConfig(
            num_heads=num_attention_heads,
            block=neox_args.sparsity_config.get("block", 16),
            different_layout_per_head=neox_args.sparsity_config.get(
                "different_layout_per_head", False
            ),
            num_random_blocks=neox_args.sparsity_config.get("num_random_blocks", 0),
            local_window_blocks=neox_args.sparsity_config.get(
                "local_window_blocks", [4]
            ),
            global_block_indices=neox_args.sparsity_config.get(
                "global_block_indices", [0]
            ),
            global_block_end_indices=neox_args.sparsity_config.get(
                "global_block_end_indices", None
            ),
            attention="unidirectional",
            horizontal_global_attention=False,
        )
elif attention_type == "local": |
|
|
|
|
|
num_local_blocks = neox_args.sparsity_config.get( |
|
|
"num_local_blocks", |
|
|
neox_args.sparsity_config.get("num_sliding_window_blocks", 4), |
|
|
) |
|
|
sparsity_config = LocalSlidingWindowSparsityConfig( |
|
|
num_heads=num_attention_heads, |
|
|
block=neox_args.sparsity_config.get("block", 16), |
|
|
num_sliding_window_blocks=num_local_blocks, |
|
|
attention="unidirectional", |
|
|
) |
|
|
elif attention_type == "bigbird": |
|
|
sparsity_config = BigBirdSparsityConfig( |
|
|
num_heads=num_attention_heads, |
|
|
block=neox_args.sparsity_config.get("block", 16), |
|
|
different_layout_per_head=neox_args.sparsity_config.get( |
|
|
"different_layout_per_head", False |
|
|
), |
|
|
num_random_blocks=neox_args.sparsity_config.get("num_random_blocks", 1), |
|
|
num_sliding_window_blocks=neox_args.sparsity_config.get( |
|
|
"num_sliding_window_blocks", 3 |
|
|
), |
|
|
num_global_blocks=neox_args.sparsity_config.get("num_global_blocks", 1), |
|
|
attention="unidirectional", |
|
|
) |
|
|
elif attention_type == "bslongformer": |
|
|
sparsity_config = BSLongformerSparsityConfig( |
|
|
num_heads=num_attention_heads, |
|
|
block=neox_args.sparsity_config.get("block", 16), |
|
|
different_layout_per_head=neox_args.sparsity_config.get( |
|
|
"different_layout_per_head", False |
|
|
), |
|
|
num_sliding_window_blocks=neox_args.sparsity_config.get( |
|
|
"num_sliding_window_blocks", 3 |
|
|
), |
|
|
global_block_indices=neox_args.sparsity_config.get( |
|
|
"global_block_indices", [0] |
|
|
), |
|
|
global_block_end_indices=neox_args.sparsity_config.get( |
|
|
"global_block_end_indices", None |
|
|
), |
|
|
attention="unidirectional", |
|
|
) |
|
|
else: |
|
|
raise ValueError(f"Attention type {attention_type} not recognized") |
|
|
return SparseSelfAttention( |
|
|
sparsity_config=sparsity_config, |
|
|
max_seq_length=neox_args.seq_length, |
|
|
attn_mask_mode="add", |
|
|
mpu=mpu, |
|
|
) |
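

# Example configuration (illustrative sketch): for `attention_type="sparse_fixed"`,
# `neox_args.sparsity_config` might be a dict such as the following (every key is
# optional; the .get() defaults above apply when a key is omitted):
#
#     {
#         "block": 16,
#         "different_layout_per_head": False,
#         "num_local_blocks": 4,
#         "num_global_blocks": 1,
#         "num_different_global_patterns": 1,
#     }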


def get_fusion_type(neox_args):
    fusion_type = SoftmaxFusionTypes.none
    if neox_args.scaled_upper_triang_masked_softmax_fusion:
        fusion_type = SoftmaxFusionTypes.upper_triang
    elif neox_args.scaled_masked_softmax_fusion:
        fusion_type = SoftmaxFusionTypes.general
    return fusion_type
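

# Example (illustrative sketch; the flags are the neox_args fields read above):
#
#     neox_args.scaled_upper_triang_masked_softmax_fusion = True
#     get_fusion_type(neox_args)  # -> SoftmaxFusionTypes.upper_triang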