# Copyright (c) 2021, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Merge model parallel partitions."""

import os
import sys

# Make the parent directory importable so the ``megatron`` imports below resolve
# when this file is run as a script.
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))
)

import torch

from megatron import mpu
from megatron.checkpointing import ensure_directory_exists
from megatron.checkpointing import get_checkpoint_name
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.global_vars import rebuild_tokenizer
from megatron.global_vars import _parse_args


def split_into_partitions(tensor, num_partitions, partition_dim, stride):
    """Split ``tensor`` into ``num_partitions`` strided shards.

    The tensor is cut along ``partition_dim`` into equally sized chunks and
    chunk ``i`` is assigned to shard ``i % num_partitions``; each shard is the
    concatenation of its chunks in their original order.

    Args:
        tensor: Tensor to split.
        num_partitions: Number of shards to produce.
        partition_dim: Dimension along which to split.
        stride: Number of chunks each shard is assembled from.

    Returns:
        A list of ``num_partitions`` tensors.
    """
    # NOTE(review): ``mpu.utils.divide`` is presumably an exact integer
    # division that rejects non-divisible sizes -- confirm against megatron.
    shard_size = mpu.utils.divide(tensor.size(partition_dim), num_partitions)
    chunk_size = mpu.utils.divide(shard_size, stride)

    chunks = torch.split(tensor, chunk_size, dim=partition_dim)

    # Every ``num_partitions``-th chunk, starting at ``rank``, belongs to ``rank``.
    return [
        torch.cat(chunks[rank::num_partitions], dim=partition_dim)
        for rank in range(num_partitions)
    ]


-def merge_partitions(merged, partitions, partition_dim, stride): - - # Number and size of each partition. - num_partitions = len(partitions) - per_partition_size = None - for partition in partitions: - if per_partition_size is None: - per_partition_size = partition.size(partition_dim) - else: - assert per_partition_size == partition.size(partition_dim) - - def concat_partitions(partitions_): - with torch.no_grad(): - if (per_partition_size * num_partitions) == merged.size(partition_dim): - torch.cat(partitions_, dim=partition_dim, out=merged) - else: - print( - " ***WARNING*** sizes do not match. Will cut " - "the merged partitions by {} along dimension {} " - "to reduce the size from {} to {} ...".format( - (per_partition_size * num_partitions) - - merged.size(partition_dim), - partition_dim, - per_partition_size * num_partitions, - merged.size(partition_dim), - ) - ) - merged_ = torch.cat(partitions_, dim=partition_dim) - merged_split = torch.split( - merged_, merged.size(partition_dim), dim=partition_dim - ) - merged_ = merged_split[0] - assert merged_.size(partition_dim) == merged.size(partition_dim) - merged.data.copy_(merged_.data) - - # If stride is 1, then do simple concatenation. - if stride == 1: - concat_partitions(partitions) - return - - # For none unity strides, first split based on stride and then group. - per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride) - # Chunk and build a list. - chunks = None - for i, partition in enumerate(partitions): - chunk = torch.split(partition, per_partition_per_stride_size, dim=partition_dim) - - if chunks is None: - chunks = [0] * (num_partitions * len(chunk)) - chunks[i::num_partitions] = chunk - - # Concatinate. 
- concat_partitions(chunks) - - return - - -def get_model(model_type): - - if model_type == "GPT2": - from pretrain_gpt2 import model_provider - else: - raise Exception("unrecognized model type: {}".format(model_type)) - - model = model_provider() - model = model.half() - - return model - - -def get_parallel_checkpoint_name(path): - - tracker_filename = get_checkpoint_tracker_filename(path) - iteration = 0 - with open(tracker_filename, "r") as f: - metastring = f.read().strip() - iteration = int(metastring) - assert iteration > 0 - checkpoint_name = get_checkpoint_name(path, iteration) - - return checkpoint_name, iteration - - -def test_split_merge(): - - print("testing split and merge ...") - - # [QKV.ROW-COL] - tensor = torch.FloatTensor( - [ - [1.11, 1.12, 1.13, 1.14, 1.15], - [1.21, 1.22, 1.23, 1.24, 1.25], - [1.31, 1.32, 1.33, 1.34, 1.35], - [1.41, 1.42, 1.43, 1.44, 1.45], - [2.11, 2.12, 2.13, 2.14, 2.15], - [2.21, 2.22, 2.23, 2.24, 2.25], - [2.31, 2.32, 2.33, 2.34, 2.35], - [2.41, 2.42, 2.43, 2.44, 2.45], - [3.11, 3.12, 3.13, 3.14, 3.15], - [3.21, 3.22, 3.23, 3.24, 3.25], - [3.31, 3.32, 3.33, 3.34, 3.35], - [3.41, 3.42, 3.43, 3.44, 3.45], - ] - ) - - num_partitions = 2 - partition_dim = 0 - stride = 3 - partitions = split_into_partitions(tensor, num_partitions, partition_dim, stride) - - merged = torch.zeros_like(tensor) - merge_partitions(merged, partitions, partition_dim, stride) - - max_error = (merged - tensor).abs().max() - print(" > max error (should be zero): {}".format(max_error)) - - -def get_mp_merge_args(parser): - """Provide extra arguments required for merging.""" - group = parser.add_argument_group(title="mp merge") - - group.add_argument( - "--model-type", - type=str, - required=True, - choices=["BERT", "GPT2", "RACE", "MNLI", "QQP"], - help="Type of the model.", - ) - - return parser - - -def main(): - - # Args - args = _parse_args(extra_args_provider=get_mp_merge_args) - model_type = args.model_type - orig_model_parallel_size = 
args.model_parallel_size - args.model_parallel_size = 1 - tokenizer = rebuild_tokenizer(args) - - print("\n merging model parallel partitions ...") - print(" > number of partitions: {}".format(orig_model_parallel_size)) - print(" > checkpoint path: {}".format(args.load)) - print(" > model parameters:") - print(" number of tokens ................ {} ".format(tokenizer.vocab_size)) - print(" number of layers ................ {}".format(args.num_layers)) - print(" hidden size ..................... {}".format(args.hidden_size)) - print(" number of attention heads ....... {}".format(args.num_attention_heads)) - print( - " maximum position embeddings ..... {}".format(args.max_position_embeddings) - ) - - # Full model. - print("> building the full model ...") - mpu.initialize.set_model_parallel_world_size(1) - mpu.initialize.set_model_parallel_rank(0) - merged_model = get_model(model_type) - - # Build and load partitions. - partitions = [] - iteration = 0 - args.model_parallel_size = orig_model_parallel_size - tokenizer = rebuild_tokenizer(args) - mpu.initialize.set_model_parallel_world_size(args.model_parallel_size) - for rank in range(args.model_parallel_size): - mpu.initialize.set_model_parallel_rank(rank) - checkpoint_name, iteration = get_parallel_checkpoint_name(args.load) - print("> loading {} ...".format(checkpoint_name)) - model_ = get_model(model_type) - sd = torch.load(checkpoint_name, map_location="cpu") - model_.load_state_dict(sd["model"]) - partitions.append(model_) - - # Parameter generators so we can loop through them semiltaneouly. - merged_params_gen = merged_model.named_parameters() - partitions_params_gen = [partition.named_parameters() for partition in partitions] - while True: - try: - - # Get the params and check names. 
- name, merged_param = next(merged_params_gen) - print(" > working on {} ...".format(name)) - print( - " merged type: {}, size: {}".format( - merged_param.dtype, list(merged_param.size()) - ) - ) - partitions_param = [] - for rank, partition_params_gen in enumerate(partitions_params_gen): - partition_name, partition_param = next(partition_params_gen) - assert partition_name == name - partitions_param.append(partition_param) - print( - " partition {} type: {}, size: {}".format( - rank, partition_param.dtype, list(partition_param.size()) - ) - ) - - # For the non-parallel parameters, simply copy the rank 0 values. - if not hasattr(merged_param, "model_parallel"): - print(" none-parallel parameter, simple copy from rank 0") - with torch.no_grad(): - merged_param.data.copy_(partitions_param[0].data) - # For parallel parameters, merge the values - else: - print( - " parallel parameter merge with stride {} along " - "dimension {}".format( - merged_param.stride, merged_param.partition_dim - ) - ) - merge_partitions( - merged_param, - partitions_param, - merged_param.partition_dim, - merged_param.stride, - ) - - except StopIteration: - break - - # Save the model. - args.model_parallel_size = 1 - mpu.initialize.set_model_parallel_rank(0) - sd = {} - sd["model"] = merged_model.state_dict() - sd["iteration"] = iteration - merged_path = os.path.join(args.load, "merged") - checkpoint_name = get_checkpoint_name(merged_path, iteration) - ensure_directory_exists(checkpoint_name) - print("> saving merged model to {}".format(checkpoint_name)) - torch.save(sd, checkpoint_name) - - print("done :-)") - - -if __name__ == "__main__": - - main()