# Copyright (c) 2021 EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for models."""
import torch
from megatron.model.norms import LayerNorm, RMSNorm, ScaleNorm
from megatron.model.fused_softmax import SoftmaxFusionTypes
from types import GeneratorType
def get_params_for_weight_decay_optimization(module, neox_args):
"""Divide params into with-weight-decay and without-weight-decay groups.
Layernorms and biases will have no weight decay but the rest will.
"""
weight_decay_params = {"params": []}
no_weight_decay_params = {"params": [], "weight_decay": 0.0}
    for module_ in module.modules():
        # norm layers never get weight decay; if weight decay is disabled
        # globally, route every parameter into the no-decay group
        if isinstance(module_, (LayerNorm, RMSNorm, ScaleNorm)) or (
            neox_args.weight_decay == 0.0
        ):
no_weight_decay_params["params"].extend(
[p for p in list(module_._parameters.values()) if p is not None]
)
else:
weight_decay_params["params"].extend(
[
p
for n, p in list(module_._parameters.items())
if p is not None and n != "bias"
]
)
no_weight_decay_params["params"].extend(
[
p
for n, p in list(module_._parameters.items())
if p is not None and n == "bias"
]
)
if neox_args.weight_decay == 0.0:
# only return a single param group
# with onebitadam, we want to minimize the calls to compressed_allreduce. Every param group calls it once.
# to avoid this, only use a single param group when weight decay is off.
return [no_weight_decay_params]
    return [weight_decay_params, no_weight_decay_params]
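
# Illustrative usage (a sketch, not from the original file; `model` and the
# learning rate are assumed placeholders): the returned groups can be passed
# straight to a torch optimizer, which applies the per-group
# `weight_decay=0.0` override to the no-decay group.
#
#   param_groups = get_params_for_weight_decay_optimization(model, neox_args)
#   optimizer = torch.optim.AdamW(
#       param_groups, lr=1e-4, weight_decay=neox_args.weight_decay
#   )
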
def exists(x):
return x is not None


class Lambda(torch.nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        return self.func(x)
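
# Example (a sketch; names assumed): Lambda lets an arbitrary function sit
# inside an nn.Sequential alongside regular modules.
#
#   halve = Lambda(lambda x: x * 0.5)
#   halve(torch.ones(2))  # -> tensor([0.5000, 0.5000])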


class SequentialWrapper(torch.nn.Module):
    """
    Used to convert a deepspeed PipelineModule to an nn.Sequential-like model
    whilst retaining activation checkpointing.
    """

def __init__(
self,
layers,
activation_checkpoint_interval,
activation_checkpoint_func,
parent_class_name=None,
):
super().__init__()
self.sequential = torch.nn.Sequential(*layers)
self.activation_checkpoint_interval = activation_checkpoint_interval
self.parent_class_name = parent_class_name
self.activation_checkpoint_func = activation_checkpoint_func

    def _is_checkpointable(self, funcs):
        """Whether this span of layers should be activation-checkpointed.

        For GPT2ModelPipe models, only transformer layers are checkpointed;
        otherwise, any span containing at least one parameter is.
        """
        if self.parent_class_name == "GPT2ModelPipe":
            return all(
                "ParallelTransformerLayerPipe" in f.__class__.__name__ for f in funcs
            )
        params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
        return any(len(list(p)) > 0 for p in params)

    def inference_mode(self, use_cache=True):
        """
        Sets up the model for inference by turning on k/v caching (if specified).

        :param use_cache: (bool) whether to cache past key/value activations during inference
        """
        _set_use_cache(self.sequential, use_cache)

    def train_mode(self):
        """
        Sets up the model for training by turning off k/v caching.
        """
        _set_use_cache(self.sequential, False)
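
    # Example (sketch): toggling between generation and training behavior.
    #
    #   wrapper.inference_mode(use_cache=True)  # cache k/v during generation
    #   wrapper.train_mode()                    # no caching while training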

    def forward(self, forward_input):
        def exec_range_func(start, end):
            """Helper function to be used with checkpoint()
            Adapted from torch.utils.checkpoint:checkpoint_sequential()
            """

            def exec_func(*inputs):
                # Single tensor inputs need to be unwrapped
                if len(inputs) == 1:
                    inputs = inputs[0]
                for layer in self.sequential[start:end]:
                    inputs = layer(inputs)
                return inputs

            return exec_func

if self.activation_checkpoint_interval == 0:
func = exec_range_func(0, len(self.sequential))
x = func(forward_input)
else:
num_layers = len(self.sequential)
x = forward_input
for start_idx in range(0, num_layers, self.activation_checkpoint_interval):
end_idx = min(
start_idx + self.activation_checkpoint_interval, num_layers
)
funcs = self.sequential[start_idx:end_idx]
# Since we either pass tensors or tuples of tensors without unpacking, we
# need to be careful not to double-wrap tensors with tuple.
if not isinstance(x, tuple):
x = (x,)
if self._is_checkpointable(funcs):
x = self.activation_checkpoint_func(
exec_range_func(start_idx, end_idx), *x
)
else:
x = exec_range_func(start_idx, end_idx)(*x)
return x
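
# Sketch of the checkpointing schedule (illustrative values): with 8 layers and
# activation_checkpoint_interval=4, forward() runs two segments, layers [0:4]
# and [4:8], each recomputed in the backward pass by activation_checkpoint_func;
# with interval 0 the whole stack runs in one uncheckpointed call. A typical
# construction (the layer list is an assumed placeholder) looks like:
#
#   wrapper = SequentialWrapper(
#       layers=layer_modules,  # assumed: the modules of a PipelineModule
#       activation_checkpoint_interval=4,
#       activation_checkpoint_func=deepspeed.checkpointing.checkpoint,
#   )
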
def recursive_setattr(m, attr, value, assert_type=None, type_filter=None):
    """
    Recursively set an attribute on a pytorch module or an iterable of modules.
    If `assert_type` is provided, assert that `value` is an instance of that type.
    If `type_filter` is provided, only set the attribute on modules of that type.
    """
if assert_type is not None:
assert isinstance(value, assert_type), "Value is not the correct type."
# if m is a list or a generator, iterate over the elements
if isinstance(m, (list, GeneratorType)):
for i in m:
recursive_setattr(i, attr, value, assert_type, type_filter)
elif isinstance(m, torch.nn.Module):
if hasattr(m, attr):
if type_filter is None or isinstance(m, type_filter):
setattr(m, attr, value)
if hasattr(m, "children"):
recursive_setattr(m.children(), attr, value, assert_type, type_filter)
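
# Example (a sketch; `model` is an assumed nn.Module): set a flag only on
# Dropout submodules, asserting the value is a bool.
#
#   recursive_setattr(model, "inplace", True, assert_type=bool,
#                     type_filter=torch.nn.Dropout)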


def _set_use_cache(modules, value: bool):
    """
    Recursively sets `use_cache` to `value` on a list of pytorch modules, if they
    have a `use_cache` attribute. `use_cache` controls whether past key/value
    activations are cached during inference.
    """
    recursive_setattr(modules, "use_cache", value, assert_type=bool)


def configure_sparse_attention(neox_args, attention_type, num_attention_heads, mpu):
from deepspeed.ops.sparse_attention import (
SparseSelfAttention,
VariableSparsityConfig,
FixedSparsityConfig,
BigBirdSparsityConfig,
BSLongformerSparsityConfig,
)
from deepspeed.ops.sparse_attention.sparsity_config import (
LocalSlidingWindowSparsityConfig,
)
if attention_type == "sparse_fixed":
        # you can think of the local window size as `block` * `num_local_blocks`:
        # e.g. for a local window of 256 tokens, set `block` to 16 and
        # `num_local_blocks` to 16
sparsity_config = FixedSparsityConfig(
num_heads=num_attention_heads,
block=neox_args.sparsity_config.get("block", 16),
different_layout_per_head=neox_args.sparsity_config.get(
"different_layout_per_head", False
),
num_local_blocks=neox_args.sparsity_config.get("num_local_blocks", 4),
num_global_blocks=neox_args.sparsity_config.get("num_global_blocks", 1),
num_different_global_patterns=neox_args.sparsity_config.get(
"num_different_global_patterns", 1
),
attention="unidirectional",
horizontal_global_attention=False,
)
elif attention_type == "sparse_variable":
sparsity_config = VariableSparsityConfig(
num_heads=num_attention_heads,
block=neox_args.sparsity_config.get("block", 16),
different_layout_per_head=neox_args.sparsity_config.get(
"different_layout_per_head", False
),
num_random_blocks=neox_args.sparsity_config.get("num_random_blocks", 0),
local_window_blocks=neox_args.sparsity_config.get(
"local_window_blocks", [4]
),
global_block_indices=neox_args.sparsity_config.get(
"global_block_indices", [0]
),
global_block_end_indices=neox_args.sparsity_config.get(
"global_block_end_indices", None
),
attention="unidirectional",
horizontal_global_attention=False,
)
elif attention_type == "local":
# can configure with `num_local_blocks` or `num_sliding_window_blocks`
num_local_blocks = neox_args.sparsity_config.get(
"num_local_blocks",
neox_args.sparsity_config.get("num_sliding_window_blocks", 4),
)
sparsity_config = LocalSlidingWindowSparsityConfig(
num_heads=num_attention_heads,
block=neox_args.sparsity_config.get("block", 16),
num_sliding_window_blocks=num_local_blocks,
attention="unidirectional",
)
elif attention_type == "bigbird":
sparsity_config = BigBirdSparsityConfig(
num_heads=num_attention_heads,
block=neox_args.sparsity_config.get("block", 16),
different_layout_per_head=neox_args.sparsity_config.get(
"different_layout_per_head", False
),
num_random_blocks=neox_args.sparsity_config.get("num_random_blocks", 1),
num_sliding_window_blocks=neox_args.sparsity_config.get(
"num_sliding_window_blocks", 3
),
num_global_blocks=neox_args.sparsity_config.get("num_global_blocks", 1),
attention="unidirectional",
)
elif attention_type == "bslongformer":
sparsity_config = BSLongformerSparsityConfig(
num_heads=num_attention_heads,
block=neox_args.sparsity_config.get("block", 16),
different_layout_per_head=neox_args.sparsity_config.get(
"different_layout_per_head", False
),
num_sliding_window_blocks=neox_args.sparsity_config.get(
"num_sliding_window_blocks", 3
),
global_block_indices=neox_args.sparsity_config.get(
"global_block_indices", [0]
),
global_block_end_indices=neox_args.sparsity_config.get(
"global_block_end_indices", None
),
attention="unidirectional",
)
else:
raise ValueError(f"Attention type {attention_type} not recognized")
return SparseSelfAttention(
sparsity_config=sparsity_config,
max_seq_length=neox_args.seq_length,
attn_mask_mode="add",
mpu=mpu,
)
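
# Illustrative sparsity_config (a sketch; values assumed, not defaults): with
# attention_type="sparse_fixed", block=16 and num_local_blocks=16 give each
# token a causal local window of 16 * 16 = 256 positions.
#
#   neox_args.sparsity_config = {"block": 16, "num_local_blocks": 16}
#   sparse_attn = configure_sparse_attention(
#       neox_args, "sparse_fixed", num_attention_heads=16, mpu=mpu
#   )
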
def get_fusion_type(neox_args):
fusion_type = SoftmaxFusionTypes.none
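    # the upper-triangular (causal) kernel takes precedence when both flags are set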
if neox_args.scaled_upper_triang_masked_softmax_fusion:
fusion_type = SoftmaxFusionTypes.upper_triang
elif neox_args.scaled_masked_softmax_fusion:
fusion_type = SoftmaxFusionTypes.general
return fusion_type