# Copyright (c) 2021, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Merge model parallel partitions."""
import os
import sys
sys.path.append(
os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
)
import torch
from megatron import mpu
from megatron.checkpointing import ensure_directory_exists
from megatron.checkpointing import get_checkpoint_name
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.global_vars import rebuild_tokenizer
from megatron.global_vars import _parse_args


def split_into_partitions(tensor, num_partitions, partition_dim, stride):
    """Split tensor along partition_dim into num_partitions strided partitions."""
per_partition_size = mpu.utils.divide(tensor.size(partition_dim), num_partitions)
per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
partitions_list = torch.split(
tensor, per_partition_per_stride_size, dim=partition_dim
)
partitions = []
for i in range(num_partitions):
partition = torch.cat(partitions_list[i::num_partitions], dim=partition_dim)
partitions.append(partition)
return partitions
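
# Illustrative sketch (hypothetical tensor, not part of the original script):
# with num_partitions=2 and stride=2 along dim 0, rows [r0..r7] are first cut
# into stride chunks [r0 r1], [r2 r3], [r4 r5], [r6 r7] and then regrouped as
#   partition 0 = [r0 r1 r4 r5]   (chunks 0 and 2)
#   partition 1 = [r2 r3 r6 r7]   (chunks 1 and 3)
# so each partition receives one chunk from every stride group, matching the
# layout of fused weights such as QKV.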


def merge_partitions(merged, partitions, partition_dim, stride):
    """Merge partitions into the preallocated merged tensor along partition_dim."""
# Number and size of each partition.
num_partitions = len(partitions)
per_partition_size = None
for partition in partitions:
if per_partition_size is None:
per_partition_size = partition.size(partition_dim)
else:
assert per_partition_size == partition.size(partition_dim)
def concat_partitions(partitions_):
with torch.no_grad():
if (per_partition_size * num_partitions) == merged.size(partition_dim):
torch.cat(partitions_, dim=partition_dim, out=merged)
else:
print(
" ***WARNING*** sizes do not match. Will cut "
"the merged partitions by {} along dimension {} "
"to reduce the size from {} to {} ...".format(
(per_partition_size * num_partitions)
- merged.size(partition_dim),
partition_dim,
per_partition_size * num_partitions,
merged.size(partition_dim),
)
)
merged_ = torch.cat(partitions_, dim=partition_dim)
merged_split = torch.split(
merged_, merged.size(partition_dim), dim=partition_dim
)
merged_ = merged_split[0]
assert merged_.size(partition_dim) == merged.size(partition_dim)
merged.data.copy_(merged_.data)
# If stride is 1, then do simple concatenation.
if stride == 1:
concat_partitions(partitions)
return
    # For non-unity strides, first split based on stride and then group.
per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
# Chunk and build a list.
chunks = None
for i, partition in enumerate(partitions):
chunk = torch.split(partition, per_partition_per_stride_size, dim=partition_dim)
if chunks is None:
            chunks = [0] * (num_partitions * len(chunk))  # placeholder list; every slot is overwritten below
chunks[i::num_partitions] = chunk
    # Concatenate.
concat_partitions(chunks)
return
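
# Continuing the sketch above: merging [r0 r1 r4 r5] and [r2 r3 r6 r7] with
# stride=2 re-splits each partition into stride chunks, interleaves them via
# chunks[i::num_partitions], and concatenates back to [r0..r7]; that is,
# merge_partitions is the inverse of split_into_partitions (test_split_merge
# below checks exactly this round trip).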


def get_model(model_type):
    """Build a half-precision (fp16) model of the requested type."""
if model_type == "GPT2":
from pretrain_gpt2 import model_provider
else:
raise Exception("unrecognized model type: {}".format(model_type))
model = model_provider()
model = model.half()
    return model


def get_parallel_checkpoint_name(path):
    """Read the checkpoint tracker file and return (checkpoint_name, iteration)."""
tracker_filename = get_checkpoint_tracker_filename(path)
iteration = 0
with open(tracker_filename, "r") as f:
metastring = f.read().strip()
iteration = int(metastring)
assert iteration > 0
checkpoint_name = get_checkpoint_name(path, iteration)
return checkpoint_name, iteration
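
# Note (Megatron convention, stated here for orientation): the tracker is a
# small text file in the checkpoint root holding the iteration number of the
# latest save. This helper expects a positive integer, so it will fail on
# trackers that instead contain the string "release".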


def test_split_merge():
    """Round-trip sanity check: split a tensor, merge it back, expect zero error."""
print("testing split and merge ...")
    # [QKV.ROW-COL]: three stride groups (1.x, 2.x, 3.x) of four rows each,
    # laid out like a fused QKV weight.
tensor = torch.FloatTensor(
[
[1.11, 1.12, 1.13, 1.14, 1.15],
[1.21, 1.22, 1.23, 1.24, 1.25],
[1.31, 1.32, 1.33, 1.34, 1.35],
[1.41, 1.42, 1.43, 1.44, 1.45],
[2.11, 2.12, 2.13, 2.14, 2.15],
[2.21, 2.22, 2.23, 2.24, 2.25],
[2.31, 2.32, 2.33, 2.34, 2.35],
[2.41, 2.42, 2.43, 2.44, 2.45],
[3.11, 3.12, 3.13, 3.14, 3.15],
[3.21, 3.22, 3.23, 3.24, 3.25],
[3.31, 3.32, 3.33, 3.34, 3.35],
[3.41, 3.42, 3.43, 3.44, 3.45],
]
)
num_partitions = 2
partition_dim = 0
stride = 3
partitions = split_into_partitions(tensor, num_partitions, partition_dim, stride)
merged = torch.zeros_like(tensor)
merge_partitions(merged, partitions, partition_dim, stride)
max_error = (merged - tensor).abs().max()
print(" > max error (should be zero): {}".format(max_error))


def get_mp_merge_args(parser):
"""Provide extra arguments required for merging."""
group = parser.add_argument_group(title="mp merge")
group.add_argument(
"--model-type",
type=str,
required=True,
choices=["BERT", "GPT2", "RACE", "MNLI", "QQP"],
help="Type of the model.",
)
    return parser


def main():
# Args
args = _parse_args(extra_args_provider=get_mp_merge_args)
model_type = args.model_type
orig_model_parallel_size = args.model_parallel_size
args.model_parallel_size = 1
tokenizer = rebuild_tokenizer(args)
print("\n merging model parallel partitions ...")
print(" > number of partitions: {}".format(orig_model_parallel_size))
print(" > checkpoint path: {}".format(args.load))
print(" > model parameters:")
print(" number of tokens ................ {} ".format(tokenizer.vocab_size))
print(" number of layers ................ {}".format(args.num_layers))
print(" hidden size ..................... {}".format(args.hidden_size))
print(" number of attention heads ....... {}".format(args.num_attention_heads))
print(
" maximum position embeddings ..... {}".format(args.max_position_embeddings)
)
# Full model.
print("> building the full model ...")
mpu.initialize.set_model_parallel_world_size(1)
mpu.initialize.set_model_parallel_rank(0)
merged_model = get_model(model_type)
# Build and load partitions.
partitions = []
iteration = 0
args.model_parallel_size = orig_model_parallel_size
tokenizer = rebuild_tokenizer(args)
mpu.initialize.set_model_parallel_world_size(args.model_parallel_size)
for rank in range(args.model_parallel_size):
mpu.initialize.set_model_parallel_rank(rank)
checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
print("> loading {} ...".format(checkpoint_name))
model_ = get_model(model_type)
sd = torch.load(checkpoint_name, map_location="cpu")
model_.load_state_dict(sd["model"])
partitions.append(model_)
    # Parameter generators so we can loop through them simultaneously.
merged_params_gen = merged_model.named_parameters()
partitions_params_gen = [partition.named_parameters() for partition in partitions]
while True:
try:
# Get the params and check names.
name, merged_param = next(merged_params_gen)
print(" > working on {} ...".format(name))
print(
" merged type: {}, size: {}".format(
merged_param.dtype, list(merged_param.size())
)
)
partitions_param = []
for rank, partition_params_gen in enumerate(partitions_params_gen):
partition_name, partition_param = next(partition_params_gen)
assert partition_name == name
partitions_param.append(partition_param)
print(
" partition {} type: {}, size: {}".format(
rank, partition_param.dtype, list(partition_param.size())
)
)
# For the non-parallel parameters, simply copy the rank 0 values.
if not hasattr(merged_param, "model_parallel"):
print(" none-parallel parameter, simple copy from rank 0")
with torch.no_grad():
merged_param.data.copy_(partitions_param[0].data)
# For parallel parameters, merge the values
else:
print(
" parallel parameter merge with stride {} along "
"dimension {}".format(
merged_param.stride, merged_param.partition_dim
)
)
merge_partitions(
merged_param,
partitions_param,
merged_param.partition_dim,
merged_param.stride,
)
except StopIteration:
break
# Save the model.
args.model_parallel_size = 1
mpu.initialize.set_model_parallel_rank(0)
sd = {}
sd["model"] = merged_model.state_dict()
sd["iteration"] = iteration
merged_path = os.path.join(args.load, "merged")
checkpoint_name = get_checkpoint_name(merged_path, iteration)
ensure_directory_exists(checkpoint_name)
print("> saving merged model to {}".format(checkpoint_name))
torch.save(sd, checkpoint_name)
print("done :-)")
if __name__ == "__main__":
main()