diff --git a/.gitignore b/.gitignore index 3b4e6fc..acfadd4 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,8 @@ # pre-commit config ./.pre-commit-config.yaml + +# checkpoint +*.checkpoint_dir +*.distcp +.metadata \ No newline at end of file diff --git a/python/example/nanogpt_4D_finetune/README.md b/python/example/nanogpt_4D_finetune/README.md index 4de2467..33ea83a 100644 --- a/python/example/nanogpt_4D_finetune/README.md +++ b/python/example/nanogpt_4D_finetune/README.md @@ -17,9 +17,17 @@ cd data/shakespeare && python3 prepare.py Then, to finetune the Shakespeare dataset in an environment of multiple GPUs, run ``` -torchrun --standalone --nproc_per_node={Number of GPUs} finetune_4D.py config/finetune_shakespeare.py --compile=False --dp_size={DP Size} --tp_size={TP Size} +torchrun --standalone --nproc_per_node={Number of GPUs} finetune_4D.py config/finetune_shakespeare.py --compile=False --dp_size={DP Size} --tp_size={TP Size} --save_checkpoint_path={path to save checkpoints} ``` -where `DP Size` and `TP Size` denote the the degrees of Data and Tensor Parallelism that suit your environment. +where `DP Size` and `TP Size` denote the the degrees of Data and Tensor Parallelism that suit your environment, `save_checkpoint_path` is path to save checkpoints during the training. + +If you want to resume training from a checkpoint, add `--load_checkpoint_path={path to load checkpoint}` in the command. + +For example: +``` +torchrun --standalone --nproc_per_node={Number of GPUs} finetune_4D.py config/finetune_shakespeare.py --compile=False --dp_size={DP Size} --tp_size={TP Size} --save_checkpoint_path=./nanogpt_checkpoint_dir --load_checkpoint_path=./nanogpt_checkpoint_dir/iter_5 +``` + To produce the single GPU result, run ``` @@ -53,4 +61,4 @@ For the bf16 runs, in `base_train.py`, instead of using `torch.amp.autocast`, we 2. veScale does not focus on fp16, as fp16 is ancient in industry. -3. Checkpointing is not supported. \ No newline at end of file +3. Checkpointing is supported now. \ No newline at end of file diff --git a/python/example/nanogpt_4D_finetune/finetune_4D.py b/python/example/nanogpt_4D_finetune/finetune_4D.py index 4201a22..3f9bafd 100644 --- a/python/example/nanogpt_4D_finetune/finetune_4D.py +++ b/python/example/nanogpt_4D_finetune/finetune_4D.py @@ -42,7 +42,7 @@ from vescale.optim.distributed_optimizer import DistributedOptimizer from vescale.optim.base_optimizer import BasicOptimizer, GradOptimizerHookBase from sharding_plan import nanoGPT_plan - +import vescale # ----------------------------------------------------------------------------- # default config values designed to train a gpt2 (124M) on OpenWebText @@ -93,7 +93,8 @@ dp_size = 4 tp_size = 1 DDP_grads_in_fp32 = True - +save_checkpoint_path = "./nanogpt_checkpoint_dir" +load_checkpoint_path = "" config = {} @@ -208,8 +209,7 @@ def get_batch(split, bsz=batch_size, lbsz=local_batch_size): for k in ["n_layer", "n_head", "n_embd", "block_size", "bias", "vocab_size"]: model_args[k] = getattr(model.config, k) elif init_from == "resume": - print("WARNING: checkpointing is not supported") - print(f"Resuming the training process from: {out_dir}") + print(f"Resuming the training process from: {load_checkpoint_path}") # determine the vocab size we'll use for from-scratch training if meta_vocab_size is None: print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") @@ -326,6 +326,11 @@ def get_lr(it): global config wandb.init(project=wandb_project, name=wandb_run_name, config=config) + # Load checkpoint + if load_checkpoint_path: + checkpoint_state = {"model": model, "optimizer": optimizer} + with mesh: + vescale.checkpoint.load(load_checkpoint_path, checkpoint_state) # training loop X, Y = get_batch("train") # fetch the very first batch t0 = time.time() @@ -354,6 +359,12 @@ def get_lr(it): "mfu": running_mfu * 100, # convert to percentage } ) + if iter_num > 0: + # When iter_num == 0, the training does not start sotoptimizer state is empty, + # Don't save checkpoint + checkpoint_state = {"model": model, "optimizer": optimizer} + with mesh: + vescale.checkpoint.save(os.path.join(save_checkpoint_path, f"iter_{iter_num}"), checkpoint_state) if iter_num == 0 and eval_only: break diff --git a/python/vescale/__init__.py b/python/vescale/__init__.py index 242c534..7605c00 100644 --- a/python/vescale/__init__.py +++ b/python/vescale/__init__.py @@ -25,6 +25,7 @@ import torch import torch.utils._pytree as pytree +import vescale.checkpoint as checkpoint from vescale.dmodule.api import parallelize_module, is_dmodule, PlacementsInterface from vescale.dtensor.api import normalize_placements, distribute_tensor, from_local, redistribute_dtensor, to_local from vescale.dtensor.device_mesh import DeviceMesh, init_device_mesh @@ -44,6 +45,7 @@ "set_plan_overriding_policy", "get_plan_overriding_policy", "auto_parallelize_module", + "checkpoint", "DTensor", "DeviceMesh", "init_device_mesh", diff --git a/python/vescale/checkpoint/README.md b/python/vescale/checkpoint/README.md new file mode 100644 index 0000000..ca4fcd7 --- /dev/null +++ b/python/vescale/checkpoint/README.md @@ -0,0 +1,48 @@ +# vescale.checkpoint + +`vescale.checkpoint` is an automatic distributed checkpointing system for LLM training and inference. + +## Why `vescale.checkpoint`? + +1. Manually managing distributed checkpointing, such as writing model saving/loading/resharding scripts under complex distributed environments, is painful and error-prone. + +2. `torch.save` and `torch.load` lacks the capability of managing checkpointing in distributed settings, let alone resharding checkpoints for different distributed settings. +Although existing systems extend `torch.save` for saving checkpoints on multiple GPUs or machines, the saved checkpoints are heavily coupled with a single distributed setting like the degrees of data, tensor and pipeline parallelism. Consequently, existing systems with `torch.load` fail to load checkpoints with varying degrees of parallelism, which is common in elastic training or switching between training and fine-tuning. + +3. `PyTorch Distirbuted Checkpoint` indeed supports checkpoint resharding to some extent. Nonetheless, it currently only supports resharding for the simplest data parallelism, but not for the complex tensor nor pipeline parallelism, which are commonly used in 3D parallelism of LLM training. Furthermore, it does not support load-time resharding for Distributed Optimizer, nor provide decent performance optimizations. + +## What is `vescale.checkpoint`? + +`vescale.checkpoint` offers simple and straightforward APIs, +enabling users to load and save distributed model (e.g., `DModule`) and optimizer (e.g., `DistributedOptimizer`) seamlessly, +abstracting away the complexities of underlying details such as process rank and device mesh. + +`vescale.checkpoint` supports load-time checkpoint resharding when varying the degrees of data, tensor, or pipeline (TODO) parallelism for both veScale model (e.g., `DModule`) and optimizer (e.g., `DistributedOptimizer`). + +`vescale.checkpoint` incorporates [fast checkpointing](https://arxiv.org/abs/2402.15627) and various I/O optimization techinques, enhancing I/O efficiency during LLM training. + +`vescale.checkpoint` will be a part of `OmniStore` project, a new open-source project coming soon. + +`vescale.checkpoint` is built on top of `PyTorch Distributed Checkpoint` with significant differences as discussed above. + +## How to use `vescale.checkpoint`? + +- Saving checkpoint: + +``` +# prepare checkpoint state for the model and optimizer +checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer } +# save the checkpoint +vescale.checkpoint.save("/user/vescale/gpt/", checkpoint_state) +``` +- Loading checkpoint (under different world size or 3D parallelism degrees): +``` +# prepare checkpoint state for the model and optimizer +checkpoint_state = { "model": distributed_model, "optimizer": distributed_optimizer } +# load the checkpoint +vescale.checkpoint.load("/user/vescale/gpt/", checkpoint_state) +``` + +- More examples can be found under `/test/checkpoint` and `/python/example`. + +- Original examples can be found in PyTorch [Distributed Checkpoint](https://github.com/pytorch/pytorch/tree/main/torch/distributed/checkpoint) \ No newline at end of file diff --git a/python/vescale/checkpoint/__init__.py b/python/vescale/checkpoint/__init__.py new file mode 100644 index 0000000..646609f --- /dev/null +++ b/python/vescale/checkpoint/__init__.py @@ -0,0 +1,44 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. +################################################################################ +# The "checkpoint" folder is ONLY USED for "open source" version veScale +# If you use veScale in ByteDance, please use OmniStore + +from .api.vescale_checkpointer import VeScaleCheckpointer +from .api.meta_type import CheckpointState + + +def save(path: str, checkpoint_state: CheckpointState): + """ + Save a checkpoint to a given path + Args: + path: Defines the storage path for checkpoint. + checkpoint_state: A dictionary contains key-value pairs for model and optimizer. + - Model: Identified by 'model' key, value should be a model instance. + - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. + Example: + >>> checkpoint_state = { "model": distributd_model, "optimizer": distributed_optimizer } + >>> vescale.checkpoint.save("/user/vescale/gpt/", checkpoint_state) + """ + VeScaleCheckpointer.save(path, checkpoint_state) + + +def load(path: str, checkpoint_state: CheckpointState): + """ + Load a checkpoint from a given path + Args: + path: Defines the storage path for checkpoint. + checkpoint_state: A dictionary contains key-value pairs for model and optimizer. + - Model: Identified by 'model' key, value should be a model instance. + - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. + Example: + >>> checkpoint_state = { "model": distributd_model, "optimizer": distributed_optimizer } + >>> vescale.checkpoint.load("/user/vescale/gpt/", checkpoint_state) + """ + VeScaleCheckpointer.load(path, checkpoint_state) diff --git a/python/vescale/checkpoint/api/base_checkpointer.py b/python/vescale/checkpoint/api/base_checkpointer.py new file mode 100644 index 0000000..8d911a2 --- /dev/null +++ b/python/vescale/checkpoint/api/base_checkpointer.py @@ -0,0 +1,47 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +from .meta_type import CheckpointState + + +class BaseCheckpointer: + """ + The Checkpointer class offers APIs that enable users to save and load state dictionarie. + It is designed for extension across various training frameworks. + """ + + @classmethod + def save(cls, path: str, checkpoint_state: CheckpointState): + """ + A Method for saving checkpoint + Args: + path: Defines the storage path for checkpoint. + checkpoint_state: A dictionary contains key-value pairs for model, optimizer and dataloader(TODO). + - Model: Identified by 'model' key, value should be a model instance. + - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. + """ + pass + + def load(cls, path: str, checkpoint_state: CheckpointState): + """ + A Method for loading checkpoint + Args: + path: Defines the storage path for checkpoint. + checkpoint_state: A dictionary contains key-value pairs for model, optimizer and dataloader(TODO). + - Model: Identified by 'model' key, value should be a model instance. + - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. + """ + pass diff --git a/python/vescale/checkpoint/api/meta_type.py b/python/vescale/checkpoint/api/meta_type.py new file mode 100644 index 0000000..5b85740 --- /dev/null +++ b/python/vescale/checkpoint/api/meta_type.py @@ -0,0 +1,38 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. +################################################################################ +# meta_type.py saves all constants and data types commonly used in omnistore + +from enum import Enum +from typing import Dict, Any, TypeVar +from typing_extensions import Protocol, runtime_checkable + + +STATE_DICT_TYPE = Dict[str, Any] + +MODEL_STR = "model" +OPTIMIZER_STR = "optimizer" +SHM_PATH = "/dev/shm" + + +class SupportedStrategy(Enum): + Megatron = 0 + FSDP = 1 + VeScale = 2 + + +@runtime_checkable +class Stateful(Protocol): + def state_dict(self) -> Dict[str, Any]: ... + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ... + + +T = TypeVar("T", bound=Stateful) +CheckpointState = Dict[str, T] diff --git a/python/vescale/checkpoint/api/vescale_checkpointer.py b/python/vescale/checkpoint/api/vescale_checkpointer.py new file mode 100644 index 0000000..b7fe0d8 --- /dev/null +++ b/python/vescale/checkpoint/api/vescale_checkpointer.py @@ -0,0 +1,156 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +from .base_checkpointer import BaseCheckpointer +from .meta_type import CheckpointState, MODEL_STR, OPTIMIZER_STR +from ..save_state_dict import save_state_dict +from ..load_state_dict import load_state_dict +from ..planner.vescale.vescale_planner import VeScaleSavePlanner, VeScaleLoadPlanner + + +from ..utilities import bfile +import os +from vescale.optim.distributed_optimizer import initialize_optimizer_state +import torch.distributed as dist +from ..utilities.logger import get_omnistore_logger + +logger = get_omnistore_logger() + +VESCALE_SUPPORTED_TYPES = {MODEL_STR, OPTIMIZER_STR} + + +def deduplicate_2d_list(lst): + seen = set() + deduplicated_list = [] + for item in lst: + # Convert the inner list to a tuple for hashing + tuple_item = tuple(sorted(item)) # Sorting to treat [1, 2] and [2, 1] as the same + if tuple_item not in seen: + seen.add(tuple_item) + # Convert back to list to preserve original type + deduplicated_list.append(item) + return deduplicated_list + + +class VeScaleCheckpointer(BaseCheckpointer): + """ + The Checkpointer class for VeScale, A PyTorch Native Auto Parallelism Framework + """ + + save_planner = VeScaleSavePlanner() + load_planner = VeScaleLoadPlanner() + + @classmethod + def save(cls, path: str, checkpoint_state: CheckpointState): + # Check if we support saving the components + for key in checkpoint_state.keys(): + if key not in VESCALE_SUPPORTED_TYPES: + raise ValueError(f"{key} is not supported by VeScaleCheckpointer") + if bfile.is_local_path(path): + logger.warning( + "The local path for checkpointing should be accessible to all ranks. It can be a NFS/FUSE path" + ) + + # Start saving checkpoint + for key, value in checkpoint_state.items(): + if key == MODEL_STR: + # Get model path + model_path = os.path.join(path, MODEL_STR) + # Create a "model" folder on under root path + if dist.get_rank() == 0: + bfile.makedirs(model_path) + dist.barrier() + # Save model + save_state_dict( + state_dict=value.state_dict(), + path=model_path, + process_group=None, + coordinator_rank=0, + no_dist=False, + planner=cls.save_planner, + ) + elif key == OPTIMIZER_STR: + # Create a "optimizer" folder on under root path + # to save different parts of optimizer + optim_root_path = os.path.join(path, OPTIMIZER_STR) + if dist.get_rank() == 0: + bfile.makedirs(optim_root_path) + dist.barrier() + + # Get optimizer path based on PP rank + optimizer_path = os.path.join(optim_root_path, "pp_0") + # Create optimizer folder on under root path + if dist.get_rank() == 0: + bfile.makedirs(optimizer_path) + dist.barrier() + + save_state_dict( + state_dict=value.state_dict(), + path=optimizer_path, + process_group=None, + coordinator_rank=0, + no_dist=False, + planner=cls.save_planner, + ) + + @classmethod + def load(cls, path: str, checkpoint_state: CheckpointState): + # Add warning + if bfile.is_local_path(path): + logger.warning( + "The local path for checkpointing should be accessible to all ranks. It can be a NFS/FUSE path" + ) + # Check if we support loading the component. + for key in checkpoint_state.keys(): + if key not in VESCALE_SUPPORTED_TYPES: + raise ValueError(f"{key} is not supported by VeScaleCheckpointer") + + # Start loading checkpoint + for key, value in checkpoint_state.items(): + if key == MODEL_STR: + # Get model path + model_path = os.path.join(path, MODEL_STR) + # Get model state dictionary + model_state = value.state_dict() + # Load model + load_state_dict( + state_dict=model_state, + path=model_path, + process_group=None, + coordinator_rank=0, + no_dist=False, + planner=cls.load_planner, + ) + # Load back to model + value.load_state_dict(model_state) + elif key == OPTIMIZER_STR: + optimizer_path = os.path.join(path, f"{OPTIMIZER_STR}", "pp_0") + # Initialize optimizer states + initialize_optimizer_state(value) + # Get optimizer state + optimizer_state = value.state_dict() + # Load optimizer state dictionary + load_state_dict( + state_dict=optimizer_state, + path=optimizer_path, + process_group=None, + coordinator_rank=0, + no_dist=False, + planner=cls.load_planner, + ) + # Load back to optimizer + value.load_state_dict(optimizer_state) + dist.barrier() diff --git a/python/vescale/checkpoint/load_state_dict.py b/python/vescale/checkpoint/load_state_dict.py new file mode 100644 index 0000000..c8ba977 --- /dev/null +++ b/python/vescale/checkpoint/load_state_dict.py @@ -0,0 +1,92 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. +################################################################################ + +from typing import Optional + +import torch +import torch.distributed as dist +from torch.distributed.checkpoint.planner import LoadPlanner +from torch.distributed.checkpoint.utils import _DistWrapper +from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner +from torch.distributed.checkpoint.filesystem import FileSystemReader +from .api.meta_type import STATE_DICT_TYPE +import time +from .utilities.logger import get_omnistore_logger + +logger = get_omnistore_logger() + +META_DATA_FILE = ".metadata" + + +def load_state_dict( + state_dict: STATE_DICT_TYPE, + path: str, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[LoadPlanner] = None, +) -> None: + load_start_time = time.time() + """ + [veScale version] Loads a distributed ``state_dict`` in SPMD style. Fix sub-group storage. + """ + + storage_reader = FileSystemReader(path) + + torch._C._log_api_usage_once("omnistore.checkpoint.vescale_checkpoint.load_state_dict") + # Step 0: create distributed world based on process group and coordinator rank + distW = _DistWrapper(process_group, not no_dist, coordinator_rank) + if process_group: + distW.coordinator_rank = dist.get_global_rank(process_group, distW.coordinator_rank) + if planner is None: + planner = DefaultLoadPlanner() + plan_start_time = time.time() + + # Step 1: all processes create local read plan, + # then coordinator gathers all local plans and create global plan. + def local_step(): + assert planner is not None + meta_read_start_time = time.time() + metadata = storage_reader.read_metadata() + meat_read_cost_time = time.time() - meta_read_start_time + logger.info(f"Finish read meta file. Cost time: {meat_read_cost_time}s") + planner.set_up_planner(state_dict, metadata, distW.is_coordinator) + storage_reader.set_up_storage_reader(metadata, distW.is_coordinator) + + local_plan = planner.create_local_plan() + local_plan = storage_reader.prepare_local_plan(local_plan) + return local_plan + + def global_step(all_local_plans): + assert planner is not None + all_local_plans = planner.create_global_plan(all_local_plans) + all_local_plans = storage_reader.prepare_global_plan(all_local_plans) + return all_local_plans + + central_plan = distW.reduce_scatter("plan", local_step, global_step) + load_ckpt_plan_cost_time = time.time() - plan_start_time + logger.info(f"Finish planning. Cost time: {load_ckpt_plan_cost_time}s") + + read_start_time = time.time() + + # Step 2: all processes read data from path + def read_data(): + assert planner is not None + final_local_plan = planner.finish_plan(central_plan) + all_reads = storage_reader.read_data(final_local_plan, planner) + all_reads.wait() + return None + + _ = distW.all_gather("read", read_data) + read_cost_time = time.time() - read_start_time + logger.info(f"Finish reading. Cost time: {read_cost_time}s") + + load_ckpt_cost_time = time.time() - load_start_time + logger.info(f"Finish loading. Cost time: {load_ckpt_cost_time}s") diff --git a/python/vescale/checkpoint/planner/vescale/__init__.py b/python/vescale/checkpoint/planner/vescale/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/vescale/checkpoint/planner/vescale/vescale_planner.py b/python/vescale/checkpoint/planner/vescale/vescale_planner.py new file mode 100644 index 0000000..3bec62d --- /dev/null +++ b/python/vescale/checkpoint/planner/vescale/vescale_planner.py @@ -0,0 +1,275 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. +################################################################################ +import io +import dataclasses +import logging +import torch +from typing import Any, Dict, Union, List, Tuple +from torch.distributed.checkpoint.default_planner import ( + DefaultSavePlanner, + DefaultLoadPlanner, +) +import math +import torch.distributed as dist +from torch.distributed.checkpoint.planner import ( + SavePlan, + LoadPlan, + ReadItem, + WriteItem, + WriteItemType, +) +from vescale.optim.distributed_optimizer import OptimizerStateSpec +from torch.distributed._shard.sharded_tensor.api import ShardedTensor +from torch.distributed.checkpoint.metadata import MetadataIndex, Metadata +from vescale.dtensor import DTensor +from .vescale_planner_helpers import ( + _create_write_items, + _create_read_items, + find_state_dict_object, +) + +from vescale.dtensor.device_mesh import mesh_resources + +logger: logging.Logger = logging.getLogger(__file__) + +__all__ = [ + "VeScaleSavePlanner", + "VeScaleLoadPlanner", + "create_default_local_load_plan", + "create_default_local_save_plan", +] + + +def sort_rank_ranges(process_list: List[Tuple]) -> List[Tuple]: + """ + Decide which rank is receiver and writer + Let rank with most parameters receives and writes tensors + for the best communication cost + If two ranks has the same data size, choose the smaller rank + Args: + A process list with tuples, each tuple is (rank, data_size) + Returns: + A sorted list, data size are sorted in descending order, + if two ranks has the same data size, ranks are in the asceonding order + """ + sorted_process_list = sorted(process_list, key=lambda x: (-x[1], x[0])) + return sorted_process_list + + +def custom_dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]: + """ + A function to remove duplicate tensors to write + when creating global writing plan for saving checkpoint + """ + all_plans = list(all_plans) + key_to_plan: Dict[MetadataIndex, List[int]] = {} + for plan_idx, plan in enumerate(all_plans): + for write_item in plan.items: + # NOTE: the only difference from pytorch official + if write_item.type != WriteItemType.SHARD: + key_to_plan.setdefault(write_item.index, []).append(plan_idx) + + replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1} + + # Remove duplicates by always keeping the first entry. + # Compute the per-rank remove set. + plan_to_keys: Dict[int, List[MetadataIndex]] = {} + for key, plans in replicated_items.items(): + for plan_idx in plans[1:]: + plan_to_keys.setdefault(plan_idx, []).append(key) + logger.info("Duplicate keys to remove: %s", plan_to_keys) + + for plan_idx, keys in plan_to_keys.items(): + key_set = set(keys) + # rewrite items and remove elements + new_items = [write_item for write_item in all_plans[plan_idx].items if write_item.index not in key_set] + all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items) + + return all_plans + + +class VeScaleLoadPlanner(DefaultLoadPlanner): + """ + A planner class for loading vescale checkpoint using PyTorch DCP + """ + + def __init__(self): + super().__init__() + + def create_local_plan(self) -> LoadPlan: + return create_default_local_load_plan(self.state_dict, self.metadata) + + def resolve_tensor(self, read_item: ReadItem): + tensor = self.lookup_tensor(read_item.dest_index) + return self.transform_tensor(read_item, tensor) + + def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor: + """ + This is an extension from the planner interface to make it easy to extend the default planner + """ + return find_state_dict_object(self.state_dict, index) + + +def create_default_local_load_plan(state_dict: Dict[str, Any], metadata: Metadata) -> LoadPlan: + """ + A function for creating local loading plan for loading checkpoint + """ + requests = [] + for fqn, obj in state_dict.items(): + md = metadata.state_dict_metadata[fqn] + if isinstance(obj, DTensor): + if obj.device_mesh.get_coordinate() is not None: + requests += _create_read_items(fqn, md, obj) + elif isinstance(obj, ShardedTensor): + # For veScale DOptimizer, it will provide empty shards + # if current process does not own the shard of tensor + local_shards = obj.local_shards() + total_size = 0 + for local_shard in local_shards: + for size in local_shard.metadata.shard_sizes: + size += total_size + if size > 0: + requests += _create_read_items(fqn, md, obj) + elif isinstance(obj, OptimizerStateSpec): + # If the state is distributed on multiple dp ranks + # Read with local_shape, then in DOptimizer then + # get flaaten to 1D and get the part belonging to current dp rank + if obj.dp_ranks_ranges: + obj.local_tensor = torch.zeros( + obj.local_shape, dtype=obj.local_tensor.dtype, device=obj.local_tensor.device + ) + requests += _create_read_items(fqn, md, obj) + else: + # If the state is owned by only one dp rank + # Read directly + obj.local_tensor = obj.local_tensor.reshape(obj.local_shape) + requests += _create_read_items(fqn, md, obj) + else: + requests += _create_read_items(fqn, md, obj) + return LoadPlan(requests) + + +class VeScaleSavePlanner(DefaultSavePlanner): + """ + A planner class for saving vescale checkpoint using PyTorch DCP + """ + + def __init__(self): + super().__init__() + + def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]: + object = self.lookup_object(write_item.index) + return self.transform_object(write_item, object) + + def create_local_plan(self) -> SavePlan: + plan = create_default_local_save_plan(self.state_dict, self.is_coordinator) + if self.flatten_state_dict: + plan = dataclasses.replace(plan, planner_data=self.mappings) + self.plan = plan + return self.plan + + def lookup_object(self, index: MetadataIndex) -> Any: + return find_state_dict_object(self.state_dict, index) + + def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: + self.dedup_replicated_tensors = True + # all_plans = custom_dedup_tensors(all_plans) + rst_value = super().create_global_plan(all_plans) + return rst_value + + +def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: bool) -> SavePlan: + """ + A function for creating local saving plan for saving checkpoint + """ + requests = [] + device_mesh = mesh_resources.get_current_mesh() + dp_device_mesh = device_mesh["DP"] + for fqn, obj in state_dict.items(): + # Since DTensor supports submesh, adding extra check to ensure _create_write_items() + # gets called only when the current rank is part of the mesh for the corresponding DTensor. + if isinstance(obj, DTensor): + if obj.device_mesh.get_coordinate() is not None: + requests += _create_write_items(fqn, obj) + elif isinstance(obj, ShardedTensor): + # For veScale DOptimizer, it will provide empty shards + # if current process does not own the shard of tensor + local_shards = obj.local_shards() + total_size = 0 + for local_shard in local_shards: + for size in local_shard.metadata.shard_sizes: + size += total_size + if size > 0: + requests += _create_write_items(fqn, obj) + elif isinstance(obj, OptimizerStateSpec): + # Create write requests if the process is the real writer + if obj.dp_ranks_ranges: + process_list = [] + for rank, param_range in obj.dp_ranks_ranges.items(): + process_list.append((rank, len(param_range))) + sorted_list = sort_rank_ranges(process_list) + writer_rank = sorted_list[0][0] + p2p_ops = [] + recv_tensors = {} + + # Case 1: I am writer + # Receive tensors + + if dist.get_rank() == writer_rank: + for k, param_range in obj.dp_ranks_ranges.items(): + if k != dist.get_rank(): + recv_tensor = torch.zeros( + (len(param_range),), dtype=obj.local_tensor.dtype, device=obj.local_tensor.device + ) + recv_op = dist.P2POp( + op=dist.irecv, + tensor=recv_tensor, + peer=k, + group=dp_device_mesh.get_dim_groups(0), + ) + recv_tensors[k] = recv_tensor + p2p_ops.append(recv_op) + else: + # Case 2: I am not writer + # Send my tensor + send_op = dist.P2POp( + op=dist.isend, + tensor=obj.local_tensor, + peer=writer_rank, + group=dp_device_mesh.get_dim_groups(0), + ) + p2p_ops.append(send_op) + + reqs = dist.batch_isend_irecv(p2p_ops) + + for req in reqs: + req.wait() + + if writer_rank == dist.get_rank(): + new_local_tensor = torch.zeros( + (math.prod(obj.local_shape),), dtype=obj.local_tensor.dtype, device=obj.local_tensor.device + ) + new_local_tensor[obj.dp_ranks_ranges[writer_rank].start : obj.dp_ranks_ranges[writer_rank].end] = ( + obj.local_tensor + ) + for k, param_range in obj.dp_ranks_ranges.items(): + if k != writer_rank: + new_local_tensor[param_range.start : param_range.end] = recv_tensors[k] + obj.local_tensor = new_local_tensor + + obj.local_tensor = obj.local_tensor.reshape(obj.local_shape) + requests += _create_write_items(fqn, obj) + else: + obj.local_tensor = obj.local_tensor.reshape(obj.local_shape) + requests += _create_write_items(fqn, obj) + elif isinstance(obj, (torch.Tensor)) or is_coordinator: + requests += _create_write_items(fqn, obj) + + return SavePlan(requests) diff --git a/python/vescale/checkpoint/planner/vescale/vescale_planner_helpers.py b/python/vescale/checkpoint/planner/vescale/vescale_planner_helpers.py new file mode 100644 index 0000000..041f3f3 --- /dev/null +++ b/python/vescale/checkpoint/planner/vescale/vescale_planner_helpers.py @@ -0,0 +1,284 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. +################################################################################ + +from typing import Any, List +import torch +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed.checkpoint.planner import WriteItem, WriteItemType, ReadItem, LoadItemType, TensorWriteData +from torch.distributed.checkpoint.metadata import ( + STATE_DICT_TYPE, + STORAGE_TYPES, + MetadataIndex, + ChunkStorageMetadata, + BytesStorageMetadata, + TensorStorageMetadata, +) +from torch.distributed._shard.sharded_tensor import TensorProperties +from torch.distributed._shard.sharded_tensor.shard import Shard +from torch.distributed.checkpoint.resharding import ( + _check_shard_metadata_pair_overlap, + _shards_get_overlap_region_wrt_saved_tensor, +) + +from vescale.dtensor import DTensor +from vescale.dtensor._utils import compute_local_shape, compute_local_offset +from vescale.optim.distributed_optimizer import OptimizerStateSpec + + +def _create_write_items_for_dtensor(fqn, tensor: DTensor) -> WriteItem: + sizes = torch.Size(compute_local_shape(tensor.shape, tensor.device_mesh, tensor.placements)) + offsets = torch.Size(compute_local_offset(tensor.shape, tensor.device_mesh, tensor.placements)) + + return WriteItem( + index=MetadataIndex(fqn=fqn, offset=offsets), + type=WriteItemType.SHARD, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata(offsets=offsets, sizes=sizes), + properties=TensorProperties.create_from_tensor(tensor._local_tensor), # keep out of autograd + size=tensor.size(), + ), + ) + + +def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata: + sizes = torch.Size(compute_local_shape(tensor.shape, tensor.device_mesh, tensor.placements)) + offsets = torch.Size(compute_local_offset(tensor.shape, tensor.device_mesh, tensor.placements)) + return ChunkStorageMetadata(offsets=offsets, sizes=sizes) + + +def _sharded_tensor_metadata(sharded_tensor: ShardedTensor, shard_md: ShardMetadata) -> TensorWriteData: + return TensorWriteData( + chunk=_chunk_for_shard(shard_md), + properties=sharded_tensor.metadata().tensor_properties, + size=sharded_tensor.metadata().size, + ) + + +def _create_write_item_for_shard(fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata) -> WriteItem: + offsets = torch.Size(shard_md.shard_offsets) + return WriteItem( + index=MetadataIndex(fqn, offsets), + type=WriteItemType.SHARD, + tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md), + ) + + +def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem: + offsets = torch.Size([0] * len(tensor.size())) + return WriteItem( + index=MetadataIndex(fqn, offsets), + type=WriteItemType.TENSOR, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()), + properties=TensorProperties.create_from_tensor(tensor), + size=tensor.size(), + ), + ) + + +def _create_write_item_for_optimizer_state(fqn, object: OptimizerStateSpec) -> WriteItem: + sizes = object.local_shape + offsets = object.global_offset + + return WriteItem( + index=MetadataIndex(fqn=fqn, offset=offsets), + type=WriteItemType.SHARD, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata(offsets=offsets, sizes=sizes), + properties=TensorProperties.create_from_tensor(object.local_tensor), + size=object.global_shape, + ), + ) + + +def _create_write_item_for_bytesio(fqn: str, bytes: Any): + return WriteItem( + index=MetadataIndex(fqn), + type=WriteItemType.BYTE_IO, + ) + + +def _create_write_items(fqn: str, object: Any) -> List[WriteItem]: + if isinstance(object, DTensor): + return [_create_write_items_for_dtensor(fqn, object)] + elif isinstance(object, ShardedTensor): + return [_create_write_item_for_shard(fqn, object, shard.metadata) for shard in object.local_shards()] + elif isinstance(object, torch.Tensor): + return [_create_write_item_for_tensor(fqn, object)] + elif isinstance(object, OptimizerStateSpec): + return [_create_write_item_for_optimizer_state(fqn, object)] + else: + return [_create_write_item_for_bytesio(fqn, object)] + + +def _create_read_item_for_tensor(dest_index, dest_offsets, storage_index, storage_offsets, lengths): + return ReadItem( + type=LoadItemType.TENSOR, + dest_index=dest_index, + dest_offsets=torch.Size(dest_offsets), + storage_index=storage_index, + storage_offsets=torch.Size(storage_offsets), + lengths=torch.Size(lengths), + ) + + +def create_read_items_for_chunk_list( + fqn: str, + checkpoint_md: TensorStorageMetadata, + local_chunks: List[ChunkStorageMetadata], +) -> List[ReadItem]: + """ + Creates a list of ``ReadItem`` based on the checkpoint and local chunks. + + This applies the resharding algorithm and computes the reads needed + to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``. + + Args: + fqn (str) : The state_dict FQN to pass to ``ReadItem``. + checkpoint_md (TensorStorageMetadata): metadata for a given tensor + from a checkpoint. + local_chunks (List[ChunkStorageMetadata]): Local chunks that needs to be + loaded. + + Returns: + A list of ``ReadItem`` that will satisfy all input chunks. + """ + read_items = [] + # this is a naive quadratic algo that can be optimized later + for idx, shard in enumerate(local_chunks): + for storage_idx, storage_md in enumerate(checkpoint_md.chunks): + if not _check_shard_metadata_pair_overlap(shard, storage_md): + continue + + storage_offsets = [] + dest_offsets = [] + lengths = [] + for ( + dim, + offset_for_saved_tensor, + offset_for_current_tensor, + length, + ) in _shards_get_overlap_region_wrt_saved_tensor(saved_shard=storage_md, current_shard=shard): + storage_offsets.append(offset_for_saved_tensor) + dest_offsets.append(offset_for_current_tensor) + lengths.append(length) + + read_items.append( + _create_read_item_for_tensor( + dest_index=MetadataIndex(fqn, shard.offsets, idx), + dest_offsets=dest_offsets, + storage_index=MetadataIndex(fqn, storage_md.offsets, storage_idx), + storage_offsets=storage_offsets, + lengths=lengths, + ) + ) + return read_items + + +def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata: + return ChunkStorageMetadata(offsets=torch.Size([0] * len(tensor.size())), sizes=tensor.size()) + + +def _create_read_item_for_byteio(dest_index, dest_offset, storage_index, storage_offset, length): + return ReadItem( + type=LoadItemType.BYTE_IO, + dest_index=dest_index, + dest_offsets=torch.Size((dest_offset,)), + storage_index=storage_index, + storage_offsets=torch.Size((storage_offset,)), + lengths=torch.Size((length,)), + ) + + +def _create_chunk_from_optimizer_spec(obj: OptimizerStateSpec) -> ChunkStorageMetadata: + return ChunkStorageMetadata(offsets=obj.global_offset, sizes=obj.local_shape) + + +def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]: + if not isinstance(md, BytesStorageMetadata): + if isinstance(obj, DTensor): + local_chunks = [_create_chunk_from_dtensor(obj)] + elif isinstance(obj, ShardedTensor): + local_chunks = [_chunk_for_shard(shard.metadata) for shard in obj.local_shards()] + elif isinstance(obj, torch.Tensor): + local_chunks = [_create_chunk_from_tensor(obj)] + elif isinstance(obj, OptimizerStateSpec): + local_chunks = [_create_chunk_from_optimizer_spec(obj)] + else: + raise ValueError( + f"Invalid checkpoint metadata for {fqn}, " + f"expected BytesStorageMetadata but found {type(md)}" + ) + return create_read_items_for_chunk_list(fqn, md, local_chunks) + else: + return [ + _create_read_item_for_byteio( + dest_index=MetadataIndex(fqn), + dest_offset=0, + storage_index=MetadataIndex(fqn), + storage_offset=0, + length=0, + ) + ] + + +def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata: + return ChunkStorageMetadata( + offsets=torch.Size(shard_md.shard_offsets), + sizes=torch.Size(shard_md.shard_sizes), + ) + + +def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard: + if index.offset is None: + raise ValueError(f"Cannot lookup {index.fqn} since its a ShardedTensor and no offset was provided") + + shards = tensor.local_shards() + # index fast path + if index.index is not None: + if len(shards) > index.index and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset: + return shards[index.index] + + for shard in shards: + if torch.Size(shard.metadata.shard_offsets) == index.offset: + return shard + raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'") + + +def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor: + if isinstance(tensor, DTensor): + return tensor._local_tensor # keep out of autograd + if isinstance(tensor, ShardedTensor): + return _find_shard(tensor, index).tensor + if index.offset is not None: + # special case looking up a tensor by origin + if index.offset == torch.Size([0] * len(tensor.size())): + return tensor + raise ValueError(f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'") + return tensor + + +def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any: + # Called when real writing happened + # The filesystem writer calls resolve_data , then it will + # call find_state_dict_object + if index.fqn not in state_dict: + raise ValueError(f"Could not find FQN: '{index.fqn}'") + obj = state_dict[index.fqn] + + if isinstance(obj, torch.Tensor): + return find_tensor_shard(obj, index) + elif isinstance(obj, OptimizerStateSpec): + return obj.local_tensor + elif index.offset is not None: + raise ValueError( + f"FQN: '{index.fqn}' is not a ShardedTensor, it is a {type(obj)} can't find by offset: '{index.offset}'" + ) + return obj diff --git a/python/vescale/checkpoint/save_state_dict.py b/python/vescale/checkpoint/save_state_dict.py new file mode 100644 index 0000000..569745e --- /dev/null +++ b/python/vescale/checkpoint/save_state_dict.py @@ -0,0 +1,114 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. +################################################################################ + +from typing import Optional +import torch +from .utilities.mem_checkpoint import TorchCheckpointRecorder +import torch.distributed as dist +from torch.distributed.checkpoint.filesystem import FileSystemWriter +from torch.distributed.checkpoint.planner import SavePlanner +from torch.distributed.checkpoint.utils import _DistWrapper +from torch.distributed.checkpoint.metadata import Metadata +from torch.distributed.checkpoint.default_planner import DefaultSavePlanner +from .api.meta_type import STATE_DICT_TYPE +from .utilities.logger import get_omnistore_logger +import time +import atexit + +logger = get_omnistore_logger() +_io_workers = None + + +def _clean_up(): + if _io_workers: + _io_workers.terminate() + _io_workers.join() + + +atexit.register(_clean_up) + + +def save_state_dict( + state_dict: STATE_DICT_TYPE, + path: str, + # storage_writer: StorageWriter, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[SavePlanner] = None, + strategy=None, +) -> Metadata: + """ + [veScale version] Saves a distributed model in SPMD style. Fix sub-group storage. + Args and usage is the same as `torch.distributed.checkpoint.save_state_dict`. + """ + save_ckpt_start_time = time.time() + torch._C._log_api_usage_once("omnistore.checkpoint.vescale_checkpoint.save_state_dict") + + # Step 0: create distributed world based on process group and coordinator rank + distW = _DistWrapper(process_group, not no_dist, coordinator_rank) + if process_group: + distW.coordinator_rank = dist.get_global_rank(process_group, distW.coordinator_rank) + if planner is None: + planner = DefaultSavePlanner() + assert planner is not None + + global_metatadata = None + + storage_writer = FileSystemWriter(path) + + # Step 1: all processes create local write plan, + # then coordinator gathers all local plans and create global plan. + def local_step(): + assert planner is not None + planner.set_up_planner(state_dict, distW.is_coordinator) + storage_writer.set_up_storage_writer(distW.is_coordinator) + local_plan = planner.create_local_plan() + local_plan = storage_writer.prepare_local_plan(local_plan) + return local_plan + + def global_step(all_local_plans): + nonlocal global_metatadata + + assert planner is not None + all_local_plans, global_metatadata = planner.create_global_plan(all_local_plans) + all_local_plans = storage_writer.prepare_global_plan(all_local_plans) + return all_local_plans + + plan_start_time = time.time() + central_plan = distW.reduce_scatter("plan", local_step, global_step) + plan_cost_time = time.time() - plan_start_time + logger.info(f"Finish planning. Cost time: {plan_cost_time}s") + + # Step 2: all processes write data from GPUs to pinned memory pool, then dump to local path + # then coordinator write meta-data to local path. + def write_data(): + assert planner is not None + final_local_plan = planner.finish_plan(central_plan) + # Use pinned memory pool and mult_processing for dumping ckpt to local directory efficiently + global _io_workers + if not _io_workers: + _io_workers = torch.multiprocessing.get_context("spawn").Pool(2) + with TorchCheckpointRecorder(async_worker=_io_workers): + all_writes = storage_writer.write_data(final_local_plan, planner) + all_writes.wait() + return all_writes.value() + + def finish_checkpoint(all_results): + assert global_metatadata is not None + storage_writer.finish(metadata=global_metatadata, results=all_results) + return global_metatadata + + dump_local_start_time = time.time() + all_reduce_results = distW.all_reduce("write", write_data, finish_checkpoint) + dump_local_cost_time = time.time() - dump_local_start_time + logger.info(f"Finish dumping. Cost time: {dump_local_cost_time}s") + + return all_reduce_results diff --git a/python/vescale/checkpoint/storage/checkpoint_adapter.py b/python/vescale/checkpoint/storage/checkpoint_adapter.py new file mode 100644 index 0000000..6c32304 --- /dev/null +++ b/python/vescale/checkpoint/storage/checkpoint_adapter.py @@ -0,0 +1,317 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +from abc import abstractmethod +from tqdm import tqdm +from typing import Dict +import os +import re +import torch +import torch.distributed as dist # current we need to use mpi launch +from vescale import DeviceMesh, DTensor +from .checkpoint_format import LLMHandWriteFormat +from typing import Optional, List, Any +from torch.distributed.distributed_c10d import ( + ProcessGroup, + get_rank, + get_world_size, +) +from torch.distributed.checkpoint._nested_dict import flatten_state_dict, unflatten_state_dict +from torch.distributed.checkpoint.metadata import ( + STATE_DICT_TYPE, +) + +from ..utilities.bfile import listdir, BFile + + +def _construct_megatron_downloading_map(filenames: List[str]): + weight_dic_pattern = r"mp_rank_\d\d_\d\d\d$" + filtered_files = [file for file in filenames if re.match(weight_dic_pattern, file)] + + download_map = {} + for file in filtered_files: + parts = file.split("_") + tp_rank = int(parts[2]) + pp_rank = int(parts[3]) + if pp_rank not in download_map: + download_map[pp_rank] = {} + download_map[pp_rank][tp_rank] = file + + return download_map + + +def _construct_reverse_pp_tp_map(vescale_path: str): + if not os.path.exists(vescale_path): + raise RuntimeError(f"vescale_path not exists. path: {vescale_path}") + files = os.listdir(vescale_path) + match = r"rank\d+.pt" + + filtered_files = [file for file in files if re.match(match, file)] + rank_map = {} + for file in filtered_files: + rank = re.search(r"\d+", file).group(0) + rank_map[rank] = os.path.join(vescale_path, file) + return rank_map + + +def _construct_pp_tp_map(megatron_path: str): + """ + construct tp pp index mapping dict + { + # for pp 0 + 0: { + # for tp 0 + 0 : "xxx.pt", + 1 : "xxx.pt" + } + } + """ + dics = listdir(megatron_path) + if len(dics) == 0: + raise RuntimeError(f"megatron_path not exists or is empty. path: {megatron_path}") + + weight_map = dict() + optim_map = dict() + + def update_dict(dic_, pp_r, tp_r, file_path): + if pp_r in dic_: + pp_dic = dic_[pp_r] + pp_dic.update({tp_r: file_path}) + else: + new_dic = {tp_r: file_path} + dic_[pp_r] = new_dic + + weight_dict = r"mp_rank_\d\d_\d\d\d$" + optim_dict = r"mp_rank_\d\d_\d\d\d_\d\d\d$" + filtered_weights_dics = [dic for dic in dics if re.match(weight_dict, dic)] + filtered_optim_dics = [dic for dic in dics if re.match(optim_dict, dic)] + + # construct weight 2-dims maps + for dic in filtered_weights_dics: + split_ul = re.split("_", dic) + tp_rank = int(split_ul[2]) + pp_rank = int(split_ul[3]) + weight_file = os.path.join(megatron_path, dic, "model_rng.pt") + update_dict(weight_map, pp_rank, tp_rank, weight_file) + + # construct optimize 2-dims maps + for dic in filtered_optim_dics: + split_ul = re.split("_", dic) + tp_rank = int(split_ul[2]) + pp_rank = int(split_ul[3]) + optim_file = os.path.join(megatron_path, dic, "optim.pt") + update_dict(optim_map, pp_rank, tp_rank, optim_file) + return weight_map, optim_map + + +def _get_megatron_tp_group(world_size, pp_size, tp_size, dp_size, cur_rank) -> tuple[ProcessGroup, ProcessGroup]: + """make sub pg group""" + return dist.new_subgroups(group_size=tp_size * dp_size) + + +def _deduce_parallel_plan_by_device_mesh(mesh: DeviceMesh): + """make rank to megatron tp_rank, pp_rank map""" + # FIXME(cery.69) : current only support data parallel is 1 + # allways parallel in last dim + tp_size = mesh.size() + # for rank = pp_rank * tp_size + tp_rank + # (rank - tp_rank) / tp_size = pp_rank + tp_rank = get_rank() % tp_size + assert (get_rank() - tp_rank) % tp_size == 0, "megatron not support pp size undivided by tp size" + pp_rank = (get_rank() - tp_rank) // tp_size + return tp_rank, pp_rank + + +def _filter_unused_tensors_and_renaming(old_state_dict: Dict[str, Any], param_resharding_plan: Dict[str, Any]): + new_state_dict = {} + + flatten_old_st, _ = flatten_state_dict(old_state_dict) + + for key, value in flatten_old_st.items(): + for pattern in param_resharding_plan.keys(): + start_index = key.find(pattern) + if start_index == -1: + continue + else: + new_state_dict[pattern] = value + print(new_state_dict.keys()) + return new_state_dict + + +################################################################## +##################### for visitor ##################### +################################################################## + + +class StateDictVisitor: + def set_device_mesh(self, mesh: DeviceMesh): + self.device_mesh = mesh + + @abstractmethod + def parsing_state_dict(self, st: dict, *args, **kwargs): + """ + flattened parsing module dict, using process function to handle each Tensor + """ + f_st, mapping = flatten_state_dict(st) + # flattened_key , value + for key, value in tqdm(f_st.items()): + if isinstance(value, (torch.Tensor, DTensor)): + self.tensor_process_func(f_st, key, value, *args, **kwargs) + new_st = unflatten_state_dict(f_st, mapping) + st.update(new_st) + + @abstractmethod + def tensor_process_func(self, parent: dict, key: str, value: Any, *args, **kwargs): + raise NotImplementedError("method abstruct method is call") + + @abstractmethod + def apply(self, state_dict: dict, *args, **kwargs): + self.parsing_state_dict(state_dict, *args, **kwargs) + + +class DefaultM2VDFSVisitor(StateDictVisitor): + def __init__(self, format: LLMHandWriteFormat): + self.format = format + super().__init__() + + def tensor_process_func(self, parent: dict, key: str, value: Any, *args, **kwargs): + assert self.format is not None, "format is not set" + tensor_placement = self.format.get_tensor_sharding_plan_by_name(key) + assert isinstance(value, torch.Tensor) + + is_requires_grad = value.requires_grad + with torch.no_grad(): # keep out of autograd + dtensor = DTensor.from_local(value, self.device_mesh, tensor_placement) + dtensor.requires_grad_(is_requires_grad) + + parent[key] = dtensor + + def apply(self, state_dict: dict, *args, **kwargs): + self.parsing_state_dict(state_dict, *args, **kwargs) + + +class DefaultV2MDFSVisitor(StateDictVisitor): + def __init__(self): + super().__init__() + + def tensor_process_func(self, parent: dict, key: str, value: DTensor, *args, **kwargs): + parent[key] = value._local_tensor # keep out of autograd + + def apply(self, state_dict: dict, *args, **kwargs): + self.parsing_state_dict(state_dict, *args, **kwargs) + + +################################################################## +##################### for api func ##################### +################################################################## + + +def convert_vescale_checkpoint_to_megatron( + vescale_path: str, megatron_path: str, visitor: StateDictVisitor, device=torch.device("cpu") +) -> STATE_DICT_TYPE: + rank_map = _construct_reverse_pp_tp_map(vescale_path) + world_size = len(rank_map) + assert world_size == get_world_size(), f"world size mismatch {world_size} vs {get_world_size()}" + rank = get_rank() + rank_file_name = rank_map[str(rank)] + rank_file_path = os.path.join(vescale_path, rank_file_name) + if os.path.exists(rank_file_path): + st = torch.load(rank_file_path, map_location=device) + + def find_device_mesh(st): + for key in st: + value = st[key] + if isinstance(value, DTensor): + mesh = value.device_mesh + return mesh + elif isinstance(value, dict): + mesh = find_device_mesh(value) + if mesh: + return mesh + return None + + device_mesh = find_device_mesh(st) + assert device_mesh is not None, "not find devicemesh in vescale format please check" + tp_rank, pp_rank = _deduce_parallel_plan_by_device_mesh(device_mesh) + visitor.apply(st) + megatron_dict = f"mp_rank_{str(tp_rank).zfill(2)}_{str(pp_rank).zfill(3)}" + tmp_path = megatron_path + megatron_save_path = os.path.join(tmp_path, megatron_dict) + os.makedirs(megatron_save_path, exist_ok=True) + megatron_save_file = os.path.join(megatron_save_path, "model_rng.pt") + if "optim" in st: + optim = st["optim"] + megatron_optim_dict = f"mp_rank_{str(tp_rank).zfill(2)}_{str(pp_rank).zfill(3)}_000" + megatron_optim_dict_path = os.path.join(tmp_path, megatron_optim_dict) + os.makedirs(megatron_optim_dict_path, exist_ok=True) + torch.save(optim, os.path.join(megatron_optim_dict_path, "optim.pt")) + del st["optim"] + torch.save(st, megatron_save_file) + # FIXME(cery.69): support dp not 1 + return st + + +def convert_megatron_checkpoint_to_vescale( + megatron_path: str, visitor: DefaultM2VDFSVisitor, device=torch.device("cpu"), vescale_path: Optional[str] = None +) -> STATE_DICT_TYPE: + weight_map, optim_map = _construct_pp_tp_map(megatron_path) + tp_equal = [(len(weight_map[pp]) == len(weight_map[0])) for pp in weight_map] + assert all(tp_equal), "megatron not support unmodified devided split plan" + tp_size = len(weight_map[0]) + pp_size = len(weight_map) + + rank = get_rank() + + for pp_rank in range(0, pp_size): + for tp_rank in range(0, tp_size): + megatron_rank = pp_rank * tp_size + tp_rank + if megatron_rank != rank: + continue + megatron_weight_pt = weight_map[pp_rank][tp_rank] + # phase 1. parse weight + with BFile(megatron_weight_pt, "rb") as f: + m_st = torch.load(f, map_location=device) + args = m_st["args"] + megatron_cur_rank = args.rank + megatron_world_size = args.world_size + megatron_tp_size = args.tensor_model_parallel_size + megatron_pp_size = args.pipeline_model_parallel_size + megatron_dp_size = args.data_parallel_size + + local_pg, _ = _get_megatron_tp_group( + megatron_world_size, megatron_pp_size, megatron_tp_size, megatron_dp_size, megatron_cur_rank + ) + device_mesh = DeviceMesh(device.__str__(), None, pg=local_pg) + visitor.set_device_mesh(device_mesh) + visitor.apply(m_st["model"], "model") + + new_st = {} + new_st["models"] = _filter_unused_tensors_and_renaming( + m_st["model"], visitor.format.default_params_sharding_plan + ) + if len(optim_map) > 0: + megatron_optim_pt_path = optim_map[pp_rank][tp_rank] + # phase 2. parse optimizer + with BFile(megatron_optim_pt_path, "rb") as f: + optim = torch.load(f, map_location=device) + visitor.apply(optim, "") + new_st["optim"] = optim + if vescale_path: + save_file = f"rank{rank}.pt" + with BFile(os.path.join(vescale_path, save_file), "wb") as f: + torch.save(new_st, f) + return new_st diff --git a/python/vescale/checkpoint/storage/checkpoint_format.py b/python/vescale/checkpoint/storage/checkpoint_format.py new file mode 100644 index 0000000..f3e262f --- /dev/null +++ b/python/vescale/checkpoint/storage/checkpoint_format.py @@ -0,0 +1,48 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +import re +from typing import Sequence + + +from vescale import Shard, Replicate, Placement + + +class LLMHandWriteFormat: + def __init__(self, params_sharding_plan): + super().__init__() + self.default_params_sharding_plan = params_sharding_plan + + def get_tensor_sharding_plan_by_name(self, name: str) -> Sequence[Placement]: + for pattern, placements in self.default_params_sharding_plan.items(): + if re.search(pattern, name): + return placements + return [Replicate()] + + +MEGATRON_GPT_RULES = { + r"model.gpt_model.language_model.embedding.word_embeddings.weight": [Shard(0)], + r"model.gpt_model.language_model.encoder.layers.\d+.mlp.dense_h_to_4h.weight": [Shard(0)], + r"model.gpt_model.language_model.encoder.layers.\d+.mlp.dense_h_to_4h_lora.weight": [Shard(0)], + r"model.gpt_model.language_model.encoder.layers.\d+.mlp.dense_4h_to_h.weight": [Shard(1)], + r"model.gpt_model.language_model.encoder.layers.\d+.mlp.dense_4h_to_h_lora.weight": [Shard(1)], + r"model.gpt_model.language_model.encoder.layers.\d+.self_attention.query_key_value.weight": [Shard(0)], + r"model.visual_encoder.blocks.\d+.attn.qkv.weight": [Shard(0)], + r"model.visual_encoder.blocks.\d+.attn.proj.weight": [Shard(1)], + r"model.visual_encoder.blocks.\d+.mlp.fc1.weight": [Shard(0)], + r"model.visual_encoder.blocks.\d+.mlp.fc2.weight": [Shard(1)], +} diff --git a/python/vescale/checkpoint/utilities/bfile.py b/python/vescale/checkpoint/utilities/bfile.py new file mode 100644 index 0000000..980dbc8 --- /dev/null +++ b/python/vescale/checkpoint/utilities/bfile.py @@ -0,0 +1,129 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +# Existing APIs all follow the rule here: +# https://www.tensorflow.org/api_docs/python/tf/io/gfile + + +import os +import enum +import contextlib +import uuid +from .logger import get_omnistore_logger +import shutil +from .server import mem_server_lib + +logger = get_omnistore_logger() +BFILE_DEFAULT_TIMEOUT = None + + +class FileType(enum.Enum): + LOCAL = 0 + LOCAL_MEM = 1 + + +def local_list_folder(folder_path: str, recursive: bool = False): + file_paths = [] + if recursive: + for root, _, files in os.walk(folder_path): + for file_name in files: + file_path = os.path.join(root, file_name) + file_paths.append(file_path) + else: + if os.path.isdir(folder_path): + file_paths.extend([os.path.join(folder_path, d) for d in os.listdir(folder_path)]) + elif os.path.isfile(folder_path): + file_paths.append(folder_path) + else: + logger.warning(f"Path {folder_path} is invalid") + + return file_paths + + +def get_schema(path: str): + if path.startswith(mem_server_lib.SCHEMA): + return FileType.LOCAL_MEM + return FileType.LOCAL + + +def rename(src, dst, overwrite=False): + t = get_schema(src) + if t == FileType.LOCAL_MEM: + return mem_server_lib.rename(src, dst, overwrite) + return os.rename(src, dst) + + +def listdir(path): + t = get_schema(path) + if t == FileType.LOCAL_MEM: + return mem_server_lib.listdir(path) + absolute_files = local_list_folder(path) + return [f[f.rfind("/") + 1 :] for f in absolute_files] + + +def remove(path): + t = get_schema(path) + if t == FileType.LOCAL_MEM: + return mem_server_lib.remove(path) + return shutil.rmtree(path, ignore_errors=True) + + +def exists(path): + t = get_schema(path) + if t == FileType.LOCAL_MEM: + return mem_server_lib.exists(path) + return os.path.exists(path) + + +def makedirs(path): + t = get_schema(path) + if t == FileType.LOCAL_MEM: + # Local mem doesn't have empty folder + return + return os.makedirs(path, exist_ok=True) + + +@contextlib.contextmanager +def BFile(name, mode="r"): + t = get_schema(name) + if t == FileType.LOCAL_MEM: + with mem_server_lib.open(name, mode) as f: + yield f + else: + with open(name, mode) as f: + yield f + + +# ---- Below is some useful utilities ----- + + +def atomic_write(path: str, content: bytes, **kwargs): + tmp_path = path + "_tmp_" + str(uuid.uuid4()) + with BFile(tmp_path, "wb", **kwargs) as f: + f.write(content) + rename(tmp_path, path, overwrite=True) + + +def safe_atomic_write(path: str, content: bytes, **kwargs): + makedirs(os.path.dirname(path)) + atomic_write(path, content, **kwargs) + + +def is_local_path(path: str): + t = get_schema(path) + if t == FileType.LOCAL_MEM or t == FileType.LOCAL: + return True + return False diff --git a/python/vescale/checkpoint/utilities/logger.py b/python/vescale/checkpoint/utilities/logger.py new file mode 100644 index 0000000..98bb964 --- /dev/null +++ b/python/vescale/checkpoint/utilities/logger.py @@ -0,0 +1,251 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +"""Utilities for loggers.""" + +from argparse import Namespace +from typing import Any, Dict, Generator, List, MutableMapping, Optional, Union +import logging +import warnings +import os +import sys + +import numpy as np +import torch + + +def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]: + """Ensure parameters are a dict or convert to dict if necessary. + Args: + params: Target to be converted to a dictionary + + Returns: + params as a dictionary + + """ + # in case converting from namespace + if isinstance(params, Namespace): + params = vars(params) + + if params is None: + params = {} + + return params + + +def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]: + """Sanitize callable params dict, e.g. ``{'a': } -> {'a': 'function_****'}``. + + Args: + params: Dictionary containing the hyperparameters + + Returns: + dictionary with all callables sanitized + """ + + def _sanitize_callable(val: Any) -> Any: + # Give them one chance to return a value. Don't go rabbit hole of recursive call + if callable(val): + try: + _val = val() + if callable(_val): + return val.__name__ + return _val + # todo: specify the possible exception + except Exception: + return getattr(val, "__name__", None) + return val + + return {key: _sanitize_callable(val) for key, val in params.items()} + + +def _flatten_dict(params: Dict[Any, Any], delimiter: str = "/") -> Dict[str, Any]: + """Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``. + + Args: + params: Dictionary containing the hyperparameters + delimiter: Delimiter to express the hierarchy. Defaults to ``'/'``. + + Returns: + Flattened dict. + + Examples: + >>> _flatten_dict({'a': {'b': 'c'}}) + {'a/b': 'c'} + >>> _flatten_dict({'a': {'b': 123}}) + {'a/b': 123} + >>> _flatten_dict({5: {'a': 123}}) + {'5/a': 123} + """ + + def _dict_generator( + input_dict: Any, prefixes: List[Optional[str]] = None + ) -> Generator[Any, Optional[List[str]], List[Any]]: + prefixes = prefixes[:] if prefixes else [] + if isinstance(input_dict, MutableMapping): + for key, value in input_dict.items(): + key = str(key) + if isinstance(value, (MutableMapping, Namespace)): + value = vars(value) if isinstance(value, Namespace) else value + yield from _dict_generator(value, prefixes + [key]) + else: + yield prefixes + [key, value if value is not None else str(None)] + else: + yield prefixes + [input_dict if input_dict is None else str(input_dict)] + + return {delimiter.join(keys): val for *keys, val in _dict_generator(params)} + + +def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: + """Returns params with non-primitvies converted to strings for logging. + + >>> params = {"float": 0.3, + ... "int": 1, + ... "string": "abc", + ... "bool": True, + ... "list": [1, 2, 3], + ... "namespace": Namespace(foo=3), + ... "layer": torch.nn.BatchNorm1d} + >>> import pprint + >>> pprint.pprint(_sanitize_params(params)) # doctest: +NORMALIZE_WHITESPACE + {'bool': True, + 'float': 0.3, + 'int': 1, + 'layer': "", + 'list': '[1, 2, 3]', + 'namespace': 'Namespace(foo=3)', + 'string': 'abc'} + """ + for k in params.keys(): + # convert relevant np scalars to python types first (instead of str) + if isinstance(params[k], (np.bool_, np.integer, np.floating)): + params[k] = params[k].item() + elif type(params[k]) not in [bool, int, float, str, torch.Tensor]: + params[k] = str(params[k]) + return params + + +def _add_prefix(metrics: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]: + """Insert prefix before each key in a dict, separated by the separator. + + Args: + metrics: Dictionary with metric names as keys and measured quantities as values + prefix: Prefix to insert before each key + separator: Separates prefix and original key name + + Returns: + Dictionary with prefix and separator inserted before each key + """ + if prefix: + metrics = {f"{prefix}{separator}{k}": v for k, v in metrics.items()} + + return metrics + + +def _name(loggers: List[Any], separator: str = "_") -> str: + if len(loggers) == 1: + return loggers[0].name + else: + # Concatenate names together, removing duplicates and preserving order + return separator.join(dict.fromkeys(str(logger.name) for logger in loggers)) + + +def _version(loggers: List[Any], separator: str = "_") -> Union[int, str]: + if len(loggers) == 1: + return loggers[0].version + else: + # Concatenate versions together, removing duplicates and preserving order + return separator.join(dict.fromkeys(str(logger.version) for logger in loggers)) + + +# from https://stackoverflow.com/questions/2183233/how-to-add-a-custom-loglevel-to-pythons-logging-facility +def _add_logging_level(level_name, level_num, method_name=None): + """ + Comprehensively adds a new logging level to the `logging` module and the + currently configured logging class. + + `level_name` becomes an attribute of the `logging` module with the value + `level_num`. `method_name` becomes a convenience method for both `logging` + itself and the class returned by `logging.getLoggerClass()` (usually just + `logging.Logger`). If `method_name` is not specified, `level_name.lower()` is + used. + + To avoid accidental clobberings of existing attributes, this method will + raise an `AttributeError` if the level name is already an attribute of the + `logging` module or if the method name is already present + + Example + ------- + >>> addLoggingLevel('TRACE', logging.DEBUG - 5) + >>> logging.getLogger(__name__).setLevel("TRACE") + >>> logging.getLogger(__name__).trace('that worked') + >>> logging.trace('so did this') + >>> logging.TRACE + 5 + + """ + if not method_name: + method_name = level_name.lower() + + if hasattr(logging, level_name): + warnings.warn(f"{level_name} already defined in logging module") + return + if hasattr(logging, method_name): + warnings.warn(f"{method_name} already defined in logging module") + return + if hasattr(logging.getLoggerClass(), method_name): + warnings.warn(f"{method_name} already defined in logger class") + return + + # This method was inspired by the answers to Stack Overflow post + # http://stackoverflow.com/q/2183233/2988730, especially + # http://stackoverflow.com/a/13638084/2988730 + def log_for_level(self, message, *args, **kwargs): + if self.isEnabledFor(level_num): + self._log(level_num, message, args, **kwargs) + + def log_to_root(message, *args, **kwargs): + logging.log(level_num, message, *args, **kwargs) + + logging.addLevelName(level_num, level_name) + setattr(logging, level_name, level_num) + setattr(logging.getLoggerClass(), method_name, log_for_level) + setattr(logging, method_name, log_to_root) + + +class OmniStoreLogger: + def __new__(cls): + if not hasattr(cls, "instance"): + level = logging.WARNING + level_str = os.environ.get("OMNISTORE_LOGGING_LEVEL", "WARNING").upper() + if level_str in logging._nameToLevel: + level = logging._nameToLevel[level_str] + formatter = logging.Formatter( + "[%(asctime)s][%(levelname)s][%(filename)s:%(lineno)d][%(module)s]" "[pid:%(process)d] - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + handler = logging.StreamHandler(stream=sys.stdout) + handler.setFormatter(formatter) + cls.instance = logging.getLogger("omnistore") + cls.instance.addHandler(handler) + cls.instance.setLevel(level) + cls.instance.propagate = False + return cls.instance + + +def get_omnistore_logger(): + """Get omnistore logger with logging level OMNISTORE_LOGGING_LEVEL, and output to stdout.""" + return OmniStoreLogger() diff --git a/python/vescale/checkpoint/utilities/mem_checkpoint.py b/python/vescale/checkpoint/utilities/mem_checkpoint.py new file mode 100644 index 0000000..3a725db --- /dev/null +++ b/python/vescale/checkpoint/utilities/mem_checkpoint.py @@ -0,0 +1,619 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +import io +import dataclasses +import os +import torch +from torch import multiprocessing +import threading +from typing import Callable, Dict, Any, DefaultDict, Tuple, List, Optional +import pickle + +from .server import server_lib +from . import bfile +from .logger import get_omnistore_logger + +logger = get_omnistore_logger() + +if hasattr(torch.storage, "TypedStorage"): + TypedStorage = torch.storage.TypedStorage +elif hasattr(torch.storage, "_TypedStorage"): + TypedStorage = torch.storage._TypedStorage + +# TypedStorage changes in pytorch 2. +if torch.__version__ >= "2": + + def untyped_storage(o): + return o.untyped_storage() + + def location_caster(o): + return o +elif torch.__version__ >= "1.11": + + def untyped_storage(o): + return o.storage()._storage + + def location_caster(o): + return o._storage if isinstance(o, TypedStorage) else o + + +try: + lib = torch.cuda.cudart() +except: + lib = None + + +def _bytes_to_tensor(b: bytes): + # Copied from `_object_to_tensor` in + # https://pytorch.org/docs/2.0/_modules/torch/distributed/distributed_c10d.html + byte_storage = torch.ByteStorage.from_buffer(b) + return torch.ByteTensor(byte_storage) + + +class PinnedStoragePool: + def __init__(self): + self._l = threading.Lock() + self._m = DefaultDict(set) + + def allocate(self, nbytes: int): + with self._l: + # We don't really need storage to have the exact size. So in theory we can find a + # bigger storage that may suit here. But so far we keep everything simple here. + s = self._m[nbytes] + if not s: + t = torch.empty([nbytes], dtype=torch.uint8) + t = t.share_memory_() + if lib is not None and nbytes != 0: + err = lib.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0) + assert err == 0, err + storage = untyped_storage(t) + s.add(storage) + return s.pop() + + def deallocate(self, s): + with self._l: + self._m[s.nbytes()].add(s) + + +GLOBAL_POOL = PinnedStoragePool() + + +class _CalledOnce: + def __init__(self, func): + self._l = threading.Lock() + self._func = func + self._res = None + self._called = False + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + with self._l: + if self._called: + return self._res + self._called = True + self._res = self._func(*args, **kwargs) + return self._res + + +_LOCATION_TAG_LOCK = threading.Lock() + + +@dataclasses.dataclass +class _SaveArgs: + obj: object + storage_tags: list + pickle_module: __module__ + args: list + kwargs: dict + + +def _serialize_obj_with_map(a: _SaveArgs, as_shared_tensor=False): + """Called to serialize an object to a byte stream or a shared tensor. + + Args: + a (_SaveArgs): The save args consist of the original tensor to serialize, + the location tags, the pickle module and other args + as_shared_tensor (bool): Whether to serialize to a shared tensor or a byte stream. + Set False if no inter process communication will happen subsequently + + Returns: + byte stream or shared tensor: The serialized object + + """ + lm = {} + for storage, tag in a.storage_tags: + lm[storage._cdata] = tag + + def location_tag(storage): + loc = lm.get(storage._cdata, None) + if loc is None: + if storage.nbytes() == 0: + # if return None, save will succeed, but load will fail afterwards + return "cpu" + raise ValueError("Unknown storage") + return loc + + with _LOCATION_TAG_LOCK: + old_location_tag = torch.serialization.location_tag + torch.serialization.location_tag = location_tag + + bio = io.BytesIO() + pickle_module = a.pickle_module or pickle + torch.save(a.obj, bio, pickle_module=pickle_module, *a.args, **a.kwargs) + + torch.serialization.location_tag = old_location_tag + b = bio.getvalue() + if not as_shared_tensor: + return b + else: + return _bytes_to_tensor(b).share_memory_() + + +def _write(f, sa): + # Serialize tensor obj directly to a byte stream, no need to convert it + # back to a shared tensor because the whole procedure happens in the same + # process + b = _serialize_obj_with_map(sa) + bfile.safe_atomic_write(f, b) + + +@dataclasses.dataclass +class _PoolArgs: + pinned_pool: PinnedStoragePool + pooled_storages: list + + +class _WriteFunc: + def __init__(self, sa: _SaveArgs, pa: _PoolArgs, async_worker): + self._sa = sa + if self._sa.pickle_module == pickle: + # This makes wa serializable. + self._sa.pickle_module = None + self._pa = pa + self._async_worker = async_worker + + self._enable_mp = async_worker is not None and sa.pickle_module is None + self._des = _CalledOnce(self._des_do_not_call_directly) + self._l = threading.RLock() + self._serialized = None + self._bytes = None + + def _des_do_not_call_directly(self): + for s in self._pa.pooled_storages: + self._pa.pinned_pool.deallocate(s) + + def __del__(self): + self._des() + + @property + def serialized(self): + with self._l: + if self._serialized is None: + if self._enable_mp: + self._serialized = self._async_worker.apply(_serialize_obj_with_map, (self._sa, True)) + else: + self._serialized = _serialize_obj_with_map(self._sa) + self._des() + return self._serialized + + @property + def bytes(self): + if self._bytes is None: + with self._l: + if self._enable_mp: + self._bytes = self.serialized.numpy().tobytes() + else: + self._bytes = self.serialized + return self._bytes + + def __call__(self, file: str = None): + if file is None: + return self.bytes + + if self._async_worker: + self._async_worker.apply(_write, (file, self._sa)) + else: + _write(file, self._sa) + self._des() + + +class TorchCheckpointRecorder: + def __init__( + self, + fast_mode=None, + async_worker: multiprocessing.Pool = None, + pinned_pool=GLOBAL_POOL, + ): + self._thread_id = threading.get_ident() + self._m = {} + + # After 1.11, typed storage is publicly accessible. + condition = torch.__version__ >= "1.11" + self._fast_mode = fast_mode if fast_mode is not None else condition + # Safety check. + assert not self._fast_mode or condition + + self._async_worker = async_worker + self._pinned_pool = pinned_pool + + def __enter__(self): + self._old_save = torch.save + torch.save = self._save_wrapper + if self._fast_mode: + self._old_warning = getattr(torch.storage, "_warn_typed_storage_removal", None) + torch.storage._warn_typed_storage_removal = lambda *args, **kwags: None + return self + + def __exit__(self, *args): + torch.save = self._old_save + if self._fast_mode: + if self._old_warning: + torch.storage._warn_typed_storage_removal = self._old_warning + + def _save_wrapper(self, obj, f, pickle_module=pickle, *args, **kwargs): + if threading.get_ident() != self._thread_id or not isinstance(f, (str, os.PathLike)): + return self._old_save(obj, f, pickle_module, *args, **kwargs) + + if self._fast_mode: + func = self._copy_to_buffer(obj, pickle_module, *args, **kwargs) + else: + func = self._save_to_buffer(obj, pickle_module, *args, **kwargs) + + self._m[str(f)] = func + + def _save_to_buffer(self, obj, *args, **kwags): + b = io.BytesIO() + self._old_save(obj, b, *args, **kwags) + + def gen_func(b): + def func(f: str = None): + if f: + return bfile.safe_atomic_write(f, b.getvalue()) + return b.getvalue() + + return func + + return gen_func(b) + + def _copy_to_buffer(self, obj, pickle_module, *args, **kwargs): + m = {} + storage_tags = [] + pooled_storages = [] + + def persistent_id(o): + if torch.is_storage(o) or isinstance(o, TypedStorage): + storage = o + if storage._cdata in m: + return storage._cdata + if storage.device.type != "cpu": + copied = self._pinned_pool.allocate(storage.nbytes()) + pooled_storages.append(copied) + copied.copy_(storage, non_blocking=False) + if isinstance(storage, TypedStorage): + copied = storage._new_wrapped_storage(copied) + else: + copied = storage.clone() + m[storage._cdata] = copied + tag = torch.serialization.location_tag(location_caster(storage)) + storage_tags.append((copied, tag)) + return storage._cdata + return + + b = io.BytesIO() + p = pickle_module.Pickler(b) + p.persistent_id = persistent_id + p.dump(obj) + b.seek(0) + up = pickle_module.Unpickler(b) + up.persistent_load = lambda i: m[i] + nobj = up.load() + + sa = _SaveArgs( + obj=nobj, + storage_tags=storage_tags, + pickle_module=pickle_module, + args=args, + kwargs=kwargs, + ) + pa = _PoolArgs(pinned_pool=self._pinned_pool, pooled_storages=pooled_storages) + + return _WriteFunc(sa, pa, self._async_worker) + + @property + def files(self) -> Dict[str, Callable[[Optional[List[str]]], Optional[bytes]]]: + return self._m + + +@dataclasses.dataclass +class Item: + file: str + src: int + dsts: List[int] + + +@dataclasses.dataclass +class Strategy: + eof: bool = False + bc_ranks_and_files: List[Item] = dataclasses.field(default_factory=list) + sr_ranks_and_files: List[Item] = dataclasses.field(default_factory=list) + + +def _choose_rank(ranks: list, existences: List[bool]): + for rank in ranks: + if existences[rank]: + return rank + for rank, exist in enumerate(existences): + if exist: + return rank + + return ranks[0] + + +def make_strategy(files: List[str], file_existences: Dict[str, List[bool]], bc_threshold): + file_to_ranks = DefaultDict(list) + + for rank, file in enumerate(files): + if not file: + continue + file_to_ranks[file].append(rank) + + s = Strategy() + + DOWNLOAD = 0 + BC = 1 + SR = 2 + + for file, ranks in file_to_ranks.items(): + mode = DOWNLOAD + + if len(ranks) >= bc_threshold: + mode = BC + else: + for rank in ranks: + if not file_existences[file][rank]: + mode = SR + item = Item(file=file, src=0, dsts=list(ranks)) + if mode == BC: + ranks_and_files = s.bc_ranks_and_files + elif mode == SR: + ranks_and_files = s.sr_ranks_and_files + else: + ranks_and_files = None + + if ranks_and_files is not None: + item.src = _choose_rank(ranks, file_existences[file]) + ranks_and_files.append(item) + + return s + + +class DistributedTorchLoader: + """Use torch distributed communication library to distribute dump. + The key idea here is to use `broadcast` to distribute common files. + """ + + def __init__( + self, + stub, + rank: int, + path_prefix_pair: Tuple[str, str] = ("", ""), + bc_thres=6, + timeout=30 * 60, + custom_strategy=make_strategy, + ): + self._stub = stub + self._rank = rank + self._ppp = path_prefix_pair + self._bc_thres = bc_thres + self._custom_strategy = custom_strategy + self._thread_id = threading.get_ident() + self._old_load = None + self._timeout = timeout + + def __enter__(self): + self._old_load = torch.load + torch.load = self._load_warpper + return self + + def __exit__(self, *args): + if any(i is not None for i in args): + raise RuntimeError( + f"[rank {self._rank}]: DistributedTorchLoader exits with Exception. The exit args are {args}" + ) + self._end_loop() + torch.load = self._old_load + + def _load_warpper(self, f, *args, **kwargs): + if threading.get_ident() != self._thread_id or not isinstance(f, (str, os.PathLike)): + return self._old_load(f, *args, **kwargs) + + f = str(f) + if f.startswith(self._ppp[0]): + f = f[len(self._ppp[0]) :] + f = self._ppp[1] + f + + b, _ = self._coordinate(f) + logger.info(f"_coordinate in _load_warpper of DistributedTorchLoader is called for file {f}") + + return self._old_load(io.BytesIO(b), *args, **kwargs) + + def _coordinate(self, input_f: str): + s = self._make_strategy(input_f) + file_to_bytes = self._download_files(s, input_f) + file_to_size = self._broadcast_file_sizes(file_to_bytes) + ret = file_to_bytes.get(input_f, None) + + for item in s.bc_ranks_and_files: + b = self._bytes_broadcast(file_to_bytes.get(item.file, None), file_to_size[item.file], item.src) + if item.file == input_f: + ret = b + + for item in s.sr_ranks_and_files: + if self._rank == item.src or self._rank in item.dsts: + b = self._bytes_sr( + file_to_bytes.get(item.file, None), + file_to_size[item.file], + item.src, + item.dsts, + ) + if item.file == input_f: + ret = b + + server_lib.barrier(self._stub, self._rank, timeout=self._timeout) + return ret, s.eof + + def _bytes_broadcast(self, b: bytes, size: int, src: int): + # We can't serialize `b` directly since python3.7's pickler has limit of 4GB + t = self._bytes_to_dist_tensor(b, size) + torch.distributed.broadcast(t, src) + b = t.cpu().numpy().tobytes() + del t + return b + + def _bytes_sr(self, b: bytes, size: int, src: int, dsts: List[int]): + t = self._bytes_to_dist_tensor(b, size) + if src == self._rank: + results = [] + for dst in dsts: + if src != dst: + results.append(torch.distributed.isend(t, dst)) + for res in results: + res.wait() + else: + torch.distributed.irecv(t).wait() + b = t.cpu().numpy().tobytes() + del t + return b + + def _bytes_to_dist_tensor(self, b, size): + pg = torch.distributed.GroupMember.WORLD + if pg.name() == torch.distributed.Backend.NCCL: + device = torch.device("cuda", torch.cuda.current_device()) + else: + device = torch.device("cpu") + if b is not None: + return _bytes_to_tensor(b).to(device) + else: + return torch.empty(size, dtype=torch.uint8, device=device) + + def _end_loop(self): + while True: + _, eof = self._coordinate("") + if eof: + break + + def _gather(self, obj): + return server_lib.gather(self._stub, 0, self._rank, obj, timeout=self._timeout) + + def _broadcast(self, obj): + return server_lib.broadcast(self._stub, 0, self._rank, obj, timeout=self._timeout) + + def _download(self, f): + if not bfile.exists(f): + error_msg = f"Unable to get {f} in {self._rank}" + logger.error(error_msg) + raise RuntimeError(error_msg) + with bfile.BFile(f, "rb") as f_obj: + return f_obj.read() + + def _download_files(self, s: Strategy, input_f: str): + from_remote = False + file_to_bytes = {} + for item in (*s.bc_ranks_and_files, *s.sr_ranks_and_files): + if item.src == self._rank: + file_to_bytes[item.file] = self._download(item.file) + if item.file == input_f: + from_remote = True + if input_f and not from_remote and input_f not in file_to_bytes: + file_to_bytes[input_f] = self._download(input_f) + return file_to_bytes + + def _broadcast_file_sizes(self, file_to_bytes): + file_to_size = {f: len(b) for f, b in file_to_bytes.items()} + file_to_size_list = self._gather(file_to_size) + if self._rank == 0: + agg_file_to_size = {} + for ele in file_to_size_list: + agg_file_to_size.update(ele) + file_to_size = agg_file_to_size + return self._broadcast(file_to_size) + + def _make_strategy(self, input_f: str) -> Strategy: + accessible_file, inaccessible_file = None, None + if bfile.is_local_path(input_f) and not bfile.exists(input_f): + inaccessible_file = input_f + else: + accessible_file = input_f + + file_tuples = self._gather((accessible_file, inaccessible_file)) + + inac_files = None + if self._rank == 0: + files = [t[0] or t[1] for t in file_tuples] + inac_files = list({t[1] for t in file_tuples if t[1]}) + inac_files = self._broadcast(inac_files) + existences = [] + for f in inac_files: + existences.append(bfile.exists(f)) + agg_existences = self._gather(existences) + s = None + if self._rank == 0: + ac_file_existence = (True,) * len(files) + file_existences = DefaultDict(lambda: ac_file_existence) + for file in inac_files: + file_existences[file] = list() + for existences in agg_existences: + for file, exist in zip(inac_files, existences): + file_existences[file].append(exist) + logger.info(f"File existences info: {file_existences}") + + eof = True + for f in files: + if f: + eof = False + if eof: + s = Strategy(eof=True) + else: + s = self._custom_strategy(files, file_existences, self._bc_thres) + return self._broadcast(s) + + +class RemappingTorchLoader: + _LOAD_LOCK = threading.Lock() + + def __init__(self, path_prefix_pair: Tuple[str, str] = ("", "")): + self._ppp = path_prefix_pair + self._old_load = None + + def __enter__(self): + RemappingTorchLoader._LOAD_LOCK.acquire() + self._old_load = torch.load + torch.load = self._loader_wrapper + return self + + def __exit__(self, *args): + torch.load = self._old_load + RemappingTorchLoader._LOAD_LOCK.release() + + def _loader_wrapper(self, f, *args, **kwargs): + if not isinstance(f, (str, os.PathLike)): + return self._old_load(f, *args, **kwargs) + f = str(f) + if f.startswith(self._ppp[0]): + f = f[len(self._ppp[0]) :] + f = self._ppp[1] + f + with bfile.BFile(f, "rb") as fi: + return self._old_load(fi, *args, **kwargs) diff --git a/python/vescale/checkpoint/utilities/server/__init__.py b/python/vescale/checkpoint/utilities/server/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/python/vescale/checkpoint/utilities/server/detached_mem_server.py b/python/vescale/checkpoint/utilities/server/detached_mem_server.py new file mode 100644 index 0000000..c42a8bf --- /dev/null +++ b/python/vescale/checkpoint/utilities/server/detached_mem_server.py @@ -0,0 +1,30 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +import argparse +import os + +from . import mem_server_lib + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--name") + args = parser.parse_args() + server = mem_server_lib.start_server(args.name) + try: + server.wait_for_termination() + finally: + os.remove(mem_server_lib.get_mem_server_sock_file(args.name)) diff --git a/python/vescale/checkpoint/utilities/server/mem_file_service.proto b/python/vescale/checkpoint/utilities/server/mem_file_service.proto new file mode 100644 index 0000000..6ca5723 --- /dev/null +++ b/python/vescale/checkpoint/utilities/server/mem_file_service.proto @@ -0,0 +1,72 @@ +// Run +// +// python -m grpc_tools.protoc -I. --python_out=. --pyi_out=. \ +// --grpc_python_out=. ./omnistore/utilities/server/mem_file_service.proto +// +// to generate new protos. + +syntax = "proto3"; + +message OmniStoreWriteRequest { + bytes content = 1; + + string name = 8; +} + +message OmniStoreWriteResponse { +} + +message OmniStoreReadRequest { + string name = 1; +} + +message OmniStoreReadResponse { + bytes content = 1; +} + +message OmniStoreRenameRequest { + string src = 1; + string dst = 2; + bool overwrite = 3; +} + +message OmniStoreRenameResponse { +} + +message OmniStoreRemoveRequest { + string name = 1; +} + +message OmniStoreRemoveResponse { +} + +message OmniStoreListdirRequest { + string name = 1; +} + +message OmniStoreListdirResponse { + repeated string names = 1; +} + +message OmniStoreExistsRequest { + string name = 1; +} + +message OmniStoreExistsResponse { + bool exists = 1; +} + +service OmniStoreMemFileService { + rpc Write(stream OmniStoreWriteRequest) returns (OmniStoreWriteResponse) { + } + rpc Read(OmniStoreReadRequest) returns (stream OmniStoreReadResponse) { + } + rpc Rename(OmniStoreRenameRequest) returns (OmniStoreRenameResponse) { + } + rpc Remove(OmniStoreRemoveRequest) returns (OmniStoreRemoveResponse) { + } + rpc Listdir(OmniStoreListdirRequest) returns (OmniStoreListdirResponse) { + } + rpc Exists(OmniStoreExistsRequest) returns (OmniStoreExistsResponse) { + } +} \ No newline at end of file diff --git a/python/vescale/checkpoint/utilities/server/mem_file_service_pb2.py b/python/vescale/checkpoint/utilities/server/mem_file_service_pb2.py new file mode 100644 index 0000000..feebf88 --- /dev/null +++ b/python/vescale/checkpoint/utilities/server/mem_file_service_pb2.py @@ -0,0 +1,66 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: omnistore/utilities/server/mem_file_service.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" + +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n1OmniStore/utilities/server/mem_file_service.proto"6\n\x15OmniStoreWriteRequest\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\x0c\x12\x0c\n\x04name\x18\x08 \x01(\t"\x18\n\x16OmniStoreWriteResponse"$\n\x14OmniStoreReadRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"(\n\x15OmniStoreReadResponse\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\x0c"E\n\x16OmniStoreRenameRequest\x12\x0b\n\x03src\x18\x01 \x01(\t\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x11\n\toverwrite\x18\x03 \x01(\x08"\x19\n\x17OmniStoreRenameResponse"&\n\x16OmniStoreRemoveRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"\x19\n\x17OmniStoreRemoveResponse"\'\n\x17OmniStoreListdirRequest\x12\x0c\n\x04name\x18\x01 \x01(\t")\n\x18OmniStoreListdirResponse\x12\r\n\x05names\x18\x01 \x03(\t"&\n\x16OmniStoreExistsRequest\x12\x0c\n\x04name\x18\x01 \x01(\t")\n\x17OmniStoreExistsResponse\x12\x0e\n\x06\x65xists\x18\x01 \x01(\x08\x32\x91\x03\n\x17OmniStoreMemFileService\x12<\n\x05Write\x12\x16.OmniStoreWriteRequest\x1a\x17.OmniStoreWriteResponse"\x00(\x01\x12\x39\n\x04Read\x12\x15.OmniStoreReadRequest\x1a\x16.OmniStoreReadResponse"\x00\x30\x01\x12=\n\x06Rename\x12\x17.OmniStoreRenameRequest\x1a\x18.OmniStoreRenameResponse"\x00\x12=\n\x06Remove\x12\x17.OmniStoreRemoveRequest\x1a\x18.OmniStoreRemoveResponse"\x00\x12@\n\x07Listdir\x12\x18.OmniStoreListdirRequest\x1a\x19.OmniStoreListdirResponse"\x00\x12=\n\x06\x45xists\x12\x17.OmniStoreExistsRequest\x1a\x18.OmniStoreExistsResponse"\x00\x62\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "omnistore.utilities.server.mem_file_service_pb2", _globals) +if _descriptor._USE_C_DESCRIPTORS is False: + DESCRIPTOR._options = None + _globals["_OMNISTOREWRITEREQUEST"]._serialized_start = 53 + _globals["_OMNISTOREWRITEREQUEST"]._serialized_end = 107 + _globals["_OMNISTOREWRITERESPONSE"]._serialized_start = 109 + _globals["_OMNISTOREWRITERESPONSE"]._serialized_end = 133 + _globals["_OMNISTOREREADREQUEST"]._serialized_start = 135 + _globals["_OMNISTOREREADREQUEST"]._serialized_end = 171 + _globals["_OMNISTOREREADRESPONSE"]._serialized_start = 173 + _globals["_OMNISTOREREADRESPONSE"]._serialized_end = 213 + _globals["_OMNISTORERENAMEREQUEST"]._serialized_start = 215 + _globals["_OMNISTORERENAMEREQUEST"]._serialized_end = 284 + _globals["_OMNISTORERENAMERESPONSE"]._serialized_start = 286 + _globals["_OMNISTORERENAMERESPONSE"]._serialized_end = 311 + _globals["_OMNISTOREREMOVEREQUEST"]._serialized_start = 313 + _globals["_OMNISTOREREMOVEREQUEST"]._serialized_end = 351 + _globals["_OMNISTOREREMOVERESPONSE"]._serialized_start = 353 + _globals["_OMNISTOREREMOVERESPONSE"]._serialized_end = 378 + _globals["_OMNISTORELISTDIRREQUEST"]._serialized_start = 380 + _globals["_OMNISTORELISTDIRREQUEST"]._serialized_end = 419 + _globals["_OMNISTORELISTDIRRESPONSE"]._serialized_start = 421 + _globals["_OMNISTORELISTDIRRESPONSE"]._serialized_end = 462 + _globals["_OMNISTOREEXISTSREQUEST"]._serialized_start = 464 + _globals["_OMNISTOREEXISTSREQUEST"]._serialized_end = 502 + _globals["_OMNISTOREEXISTSRESPONSE"]._serialized_start = 504 + _globals["_OMNISTOREEXISTSRESPONSE"]._serialized_end = 545 + _globals["_OMNISTOREMEMFILESERVICE"]._serialized_start = 548 + _globals["_OMNISTOREMEMFILESERVICE"]._serialized_end = 949 +# @@protoc_insertion_point(module_scope) diff --git a/python/vescale/checkpoint/utilities/server/mem_file_service_pb2.pyi b/python/vescale/checkpoint/utilities/server/mem_file_service_pb2.pyi new file mode 100644 index 0000000..dc71884 --- /dev/null +++ b/python/vescale/checkpoint/utilities/server/mem_file_service_pb2.pyi @@ -0,0 +1,78 @@ +from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Optional as _Optional + +DESCRIPTOR: _descriptor.FileDescriptor + +class OmniStoreWriteRequest(_message.Message): + __slots__ = ("content", "name") + CONTENT_FIELD_NUMBER: _ClassVar[int] + NAME_FIELD_NUMBER: _ClassVar[int] + content: bytes + name: str + def __init__(self, content: _Optional[bytes] = ..., name: _Optional[str] = ...) -> None: ... + +class OmniStoreWriteResponse(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class OmniStoreReadRequest(_message.Message): + __slots__ = ("name",) + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... + +class OmniStoreReadResponse(_message.Message): + __slots__ = ("content",) + CONTENT_FIELD_NUMBER: _ClassVar[int] + content: bytes + def __init__(self, content: _Optional[bytes] = ...) -> None: ... + +class OmniStoreRenameRequest(_message.Message): + __slots__ = ("src", "dst", "overwrite") + SRC_FIELD_NUMBER: _ClassVar[int] + DST_FIELD_NUMBER: _ClassVar[int] + OVERWRITE_FIELD_NUMBER: _ClassVar[int] + src: str + dst: str + overwrite: bool + def __init__(self, src: _Optional[str] = ..., dst: _Optional[str] = ..., overwrite: bool = ...) -> None: ... + +class OmniStoreRenameResponse(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class OmniStoreRemoveRequest(_message.Message): + __slots__ = ("name",) + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... + +class OmniStoreRemoveResponse(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class OmniStoreListdirRequest(_message.Message): + __slots__ = ("name",) + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... + +class OmniStoreListdirResponse(_message.Message): + __slots__ = ("names",) + NAMES_FIELD_NUMBER: _ClassVar[int] + names: _containers.RepeatedScalarFieldContainer[str] + def __init__(self, names: _Optional[_Iterable[str]] = ...) -> None: ... + +class OmniStoreExistsRequest(_message.Message): + __slots__ = ("name",) + NAME_FIELD_NUMBER: _ClassVar[int] + name: str + def __init__(self, name: _Optional[str] = ...) -> None: ... + +class OmniStoreExistsResponse(_message.Message): + __slots__ = ("exists",) + EXISTS_FIELD_NUMBER: _ClassVar[int] + exists: bool + def __init__(self, exists: bool = ...) -> None: ... diff --git a/python/vescale/checkpoint/utilities/server/mem_file_service_pb2_grpc.py b/python/vescale/checkpoint/utilities/server/mem_file_service_pb2_grpc.py new file mode 100644 index 0000000..978b388 --- /dev/null +++ b/python/vescale/checkpoint/utilities/server/mem_file_service_pb2_grpc.py @@ -0,0 +1,321 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" + +import grpc + +from . import ( + mem_file_service_pb2 as OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2, +) + + +class OmniStoreMemFileServiceStub: + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Write = channel.stream_unary( + "/OmniStoreMemFileService/Write", + request_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreWriteRequest.SerializeToString, + response_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreWriteResponse.FromString, + ) + self.Read = channel.unary_stream( + "/OmniStoreMemFileService/Read", + request_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreReadRequest.SerializeToString, + response_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreReadResponse.FromString, + ) + self.Rename = channel.unary_unary( + "/OmniStoreMemFileService/Rename", + request_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRenameRequest.SerializeToString, + response_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRenameResponse.FromString, + ) + self.Remove = channel.unary_unary( + "/OmniStoreMemFileService/Remove", + request_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRemoveRequest.SerializeToString, + response_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRemoveResponse.FromString, + ) + self.Listdir = channel.unary_unary( + "/OmniStoreMemFileService/Listdir", + request_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreListdirRequest.SerializeToString, + response_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreListdirResponse.FromString, + ) + self.Exists = channel.unary_unary( + "/OmniStoreMemFileService/Exists", + request_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreExistsRequest.SerializeToString, + response_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreExistsResponse.FromString, + ) + + +class OmniStoreMemFileServiceServicer: + """Missing associated documentation comment in .proto file.""" + + def Write(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def Read(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def Rename(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def Remove(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def Listdir(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def Exists(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_OmniStoreMemFileServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + "Write": grpc.stream_unary_rpc_method_handler( + servicer.Write, + request_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreWriteRequest.FromString, + response_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreWriteResponse.SerializeToString, + ), + "Read": grpc.unary_stream_rpc_method_handler( + servicer.Read, + request_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreReadRequest.FromString, + response_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreReadResponse.SerializeToString, + ), + "Rename": grpc.unary_unary_rpc_method_handler( + servicer.Rename, + request_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRenameRequest.FromString, + response_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRenameResponse.SerializeToString, + ), + "Remove": grpc.unary_unary_rpc_method_handler( + servicer.Remove, + request_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRemoveRequest.FromString, + response_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRemoveResponse.SerializeToString, + ), + "Listdir": grpc.unary_unary_rpc_method_handler( + servicer.Listdir, + request_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreListdirRequest.FromString, + response_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreListdirResponse.SerializeToString, + ), + "Exists": grpc.unary_unary_rpc_method_handler( + servicer.Exists, + request_deserializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreExistsRequest.FromString, + response_serializer=OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreExistsResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler("OmniStoreMemFileService", rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + +# This class is part of an EXPERIMENTAL API. +class OmniStoreMemFileService: + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Write( + request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.stream_unary( + request_iterator, + target, + "/OmniStoreMemFileService/Write", + OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreWriteRequest.SerializeToString, + OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreWriteResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def Read( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_stream( + request, + target, + "/OmniStoreMemFileService/Read", + OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreReadRequest.SerializeToString, + OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreReadResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def Rename( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/OmniStoreMemFileService/Rename", + OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRenameRequest.SerializeToString, + OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRenameResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def Remove( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/OmniStoreMemFileService/Remove", + OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRemoveRequest.SerializeToString, + OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreRemoveResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def Listdir( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/OmniStoreMemFileService/Listdir", + OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreListdirRequest.SerializeToString, + OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreListdirResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def Exists( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/OmniStoreMemFileService/Exists", + OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreExistsRequest.SerializeToString, + OmniStore_dot_utilities_dot_server_dot_mem__file__service__pb2.OmniStoreExistsResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/python/vescale/checkpoint/utilities/server/mem_server_lib.py b/python/vescale/checkpoint/utilities/server/mem_server_lib.py new file mode 100644 index 0000000..b12a8e1 --- /dev/null +++ b/python/vescale/checkpoint/utilities/server/mem_server_lib.py @@ -0,0 +1,307 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +import dataclasses +import io +import grpc +from typing import Tuple +import os +import threading +import contextlib +import pathlib +import subprocess +import time +import queue +from concurrent import futures + +from . import mem_file_service_pb2 +from . import mem_file_service_pb2_grpc + + +class _Directory(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.lock = threading.RLock() + + +@dataclasses.dataclass +class _File: + content: bytes = b"" + + +_CHUNK_SIZE = 2 * 1024 * 1024 + + +def get_mem_server_sock_file(name: str): + return f"/var/tmp/mem_server_{name}.sock" + + +class MemFileServicer(mem_file_service_pb2_grpc.OmniStoreMemFileServiceServicer): + def __init__(self): + self._d = _Directory() + + def Write(self, request_iterator, ctx: grpc.ServicerContext): + b = io.BytesIO() + name = None + for req in request_iterator: + if name is None: + if not req.name: + ctx.abort(grpc.StatusCode.INVALID_ARGUMENT, "Name must be specified.") + name = req.name + d, bn = self._iterate_dir(name, ctx, create=True) + b.write(req.content) + if name: + with d.lock: + d[bn] = _File(content=b.getvalue()) + return mem_file_service_pb2.OmniStoreWriteResponse() + + def Read(self, req, ctx: grpc.ServicerContext): + d, bn = self._iterate_dir(req.name, ctx) + with d.lock: + if bn not in d or not isinstance(d[bn], _File): + ctx.abort(grpc.StatusCode.NOT_FOUND, f"{req.name} not found.") + f: _File = d[bn] + cur = 0 + while cur < len(f.content): + yield mem_file_service_pb2.OmniStoreReadResponse(content=f.content[cur : cur + _CHUNK_SIZE]) + cur += _CHUNK_SIZE + + def Rename(self, req, ctx: grpc.ServicerContext): + src_dir, src_bn = self._iterate_dir(req.src, ctx) + dst_dir, dst_bn = self._iterate_dir(req.dst, ctx) + if src_dir != dst_dir: + ctx.abort(grpc.StatusCode.UNIMPLEMENTED, "Rename across dir is not supported.") + d = src_dir + with d.lock: + if src_bn not in src_bn: + ctx.abort(grpc.StatusCode.NOT_FOUND, f"{req.src} is not found.") + if not req.overwrite and dst_bn in d: + ctx.abort(grpc.StatusCode.ALREADY_EXISTS, f"{req.dst} already exists.") + d[dst_bn] = d[src_bn] + del d[src_bn] + return mem_file_service_pb2.OmniStoreRenameResponse() + + def Remove(self, req, ctx: grpc.ServicerContext): + d, bn = self._iterate_dir(req.name, ctx) + if bn not in d: + ctx.abort(grpc.StatusCode.NOT_FOUND, f"{req.name} not found.") + with d.lock: + del d[bn] + return mem_file_service_pb2.OmniStoreRemoveResponse() + + def Listdir(self, req, ctx: grpc.ServicerContext): + d, _ = self._iterate_dir(os.path.join(req.name, "*")) + if d is None: + return mem_file_service_pb2.OmniStoreListdirResponse() + + resp = mem_file_service_pb2.OmniStoreListdirResponse() + with d.lock: + for name in d: + resp.names.append(name) + return resp + + def Exists(self, req, ctx: grpc.ServicerContext): + d, bn = self._iterate_dir(req.name) + if d is None: + return mem_file_service_pb2.OmniStoreExistsResponse(exists=False) + with d.lock: + return mem_file_service_pb2.OmniStoreExistsResponse(exists=bn in d) + + def _iterate_dir(self, name: str, ctx: grpc.ServicerContext = None, create=False) -> Tuple[_Directory, str]: + if ctx is None: + + class FakeCtx: + def abort(*args, **kwargs): + return None, None + + ctx = FakeCtx() + name = str(pathlib.Path(name).absolute())[1:] + parts = name.split("/") + cur = self._d + for part in parts[:-1]: + with cur.lock: + if part not in cur: + if not create: + return ctx.abort(grpc.StatusCode.NOT_FOUND, f"{part} doesn't exist.") + else: + cur[part] = _Directory() + cur = cur[part] + if not isinstance(cur, _Directory): + return ctx.abort( + grpc.StatusCode.ALREADY_EXISTS, + f"{part} already exist as a file.", + ) + return cur, parts[-1] + + +def start_server(name: str, force=False): + sock = get_mem_server_sock_file(name) + if os.path.exists(sock) and not force: + raise OSError("Mem server is already running.") + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + mem_file_service_pb2_grpc.add_OmniStoreMemFileServiceServicer_to_server(MemFileServicer(), server) + server.add_insecure_port(f"unix:{sock}") + server.start() + return server + + +# --- Below is general file interface --- + +_STUB_CACHE = {} +_STUB_CACHE_LOCK = threading.Lock() +SCHEMA = "/local_mem/" + + +def get_prefix(name: str): + return SCHEMA + name + + +def _get_mem_name_and_name(path: str): + path = path[len(SCHEMA) :] + pos = path.find("/") + if pos == -1: + return path, "/" + else: + return path[:pos], path[pos:] + + +def _get_stub_and_name( + path: str, +) -> Tuple[mem_file_service_pb2_grpc.OmniStoreMemFileServiceStub, str]: + mem_name, name = _get_mem_name_and_name(path) + if mem_name not in _STUB_CACHE: + c = grpc.insecure_channel(f"unix:{get_mem_server_sock_file(mem_name)}") + with _STUB_CACHE_LOCK: + _STUB_CACHE[mem_name] = mem_file_service_pb2_grpc.OmniStoreMemFileServiceStub(c) + return _STUB_CACHE[mem_name], name + + +class _FileLike: + def __init__(self, name: str, mode: str): + if mode not in ["rb", "wb"]: + raise NotImplementedError(f"{mode} is not implemented.") + self._stub, self._name = _get_stub_and_name(name) + self._mode = mode + self._is_write = "w" in mode + if self._is_write: + self._write_async() + self._read_buf = None + + @property + def read_buf(self): + if self._read_buf is None: + self._read_buf = io.BytesIO() + for resp in self._stub.Read(mem_file_service_pb2.OmniStoreReadRequest(name=self._name)): + self._read_buf.write(resp.content) + self._read_buf.seek(0) + return self._read_buf + + def __getattr__(self, name): + if not self._is_write: + return getattr(self.read_buf, name) + + def _write_async(self): + self._q = queue.Queue() + + def streaming(): + while True: + content, eof = self._q.get() + if eof: + break + cur = 0 + while cur < len(content): + req = mem_file_service_pb2.OmniStoreWriteRequest(content=content[cur : cur + _CHUNK_SIZE]) + if cur == 0: + req.name = self._name + yield req + cur += _CHUNK_SIZE + + self._write_future = self._stub.Write.future(streaming()) + + def write(self, content): + self._q.put((content, False)) + + def close(self): + if self._is_write: + self._q.put((None, True)) + self._write_future.result() + + +@contextlib.contextmanager +def open(name, mode) -> io.FileIO: + f = _FileLike(name, mode) + try: + yield f + finally: + f.close() + + +def rename(src, dst, overwrite=False): + stub, src_name = _get_stub_and_name(src) + dst_stub, dst_name = _get_stub_and_name(dst) + if stub != dst_stub: + raise ValueError(f"Rename across mem file system is not supported. {src} {dst}") + stub.Rename(mem_file_service_pb2.OmniStoreRenameRequest(src=src_name, dst=dst_name, overwrite=overwrite)) + + +def remove(name): + stub, subname = _get_stub_and_name(name) + stub.Remove(mem_file_service_pb2.OmniStoreRemoveRequest(name=subname)) + + +def listdir(name): + try: + stub, subname = _get_stub_and_name(name) + resp = stub.Listdir(mem_file_service_pb2.OmniStoreListdirRequest(name=subname)) + return list(resp.names) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.UNAVAILABLE: + return [] + raise + + +def exists(name): + try: + stub, subname = _get_stub_and_name(name) + resp = stub.Exists(mem_file_service_pb2.OmniStoreExistsRequest(name=subname)) + return resp.exists + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.UNAVAILABLE: + return False + raise + + +# --- interface done --- + + +def start_server_in_new_process(name: str): + filename = os.path.join(os.path.dirname(os.path.abspath(__file__)), "detached_mem_server.py") + return subprocess.Popen(["python3", filename, f"--name={name}"]) + + +def wait_until_fs_ready(name: str, timeout=120): + stub, _ = _get_stub_and_name(os.path.join(SCHEMA, name)) + t0 = time.time() + while time.time() < t0 + timeout: + try: + stub.Listdir(mem_file_service_pb2.OmniStoreListdirRequest(name="/")) + return True + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.UNAVAILABLE: + time.sleep(0.1) + continue + raise + return False diff --git a/python/vescale/checkpoint/utilities/server/report_service.proto b/python/vescale/checkpoint/utilities/server/report_service.proto new file mode 100644 index 0000000..d5ead0c --- /dev/null +++ b/python/vescale/checkpoint/utilities/server/report_service.proto @@ -0,0 +1,49 @@ +// Run +// +// python -m grpc_tools.protoc -I. --python_out=. --pyi_out=. \ +// --grpc_python_out=. ./omnistore/utilities/server/report_service.proto +// +// to generate new protos. + +syntax = "proto3"; + +message OmniStoreGatherRequest { + // Used to distinguish different tasks. + string tag = 1; + int32 rank = 2; + bytes content = 3; + bool with_result = 4; +} + +message OmniStoreGatherResponse { + repeated bytes contents = 1; +} + +message OmniStoreBroadcastRequest { + string tag = 1; + int32 rank = 2; + bytes content = 3; + int32 src_rank = 4; +} + +message OmniStoreBroadcastResponse { + bytes content = 1; +} + +message OmniStoreGetStatusRequest { +} + +message OmniStoreGetStatusResponse { + bytes status = 1; +} + +service OmniStoreReportService { + rpc Gather(OmniStoreGatherRequest) returns (OmniStoreGatherResponse) { + } + + rpc Broadcast(OmniStoreBroadcastRequest) returns (OmniStoreBroadcastResponse) { + } + + rpc GetStatus(OmniStoreGetStatusRequest) returns (OmniStoreGetStatusResponse) { + } +} diff --git a/python/vescale/checkpoint/utilities/server/report_service_pb2.py b/python/vescale/checkpoint/utilities/server/report_service_pb2.py new file mode 100644 index 0000000..2a7f110 --- /dev/null +++ b/python/vescale/checkpoint/utilities/server/report_service_pb2.py @@ -0,0 +1,54 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: omnistore/utilities/server/report_service.proto +# Protobuf Python Version: 4.25.0 +"""Generated protocol buffer code.""" + +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n/omnistore/utilities/server/report_service.proto"Y\n\x16OmniStoreGatherRequest\x12\x0b\n\x03tag\x18\x01 \x01(\t\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\x0c\x12\x13\n\x0bwith_result\x18\x04 \x01(\x08"+\n\x17OmniStoreGatherResponse\x12\x10\n\x08\x63ontents\x18\x01 \x03(\x0c"Y\n\x19OmniStoreBroadcastRequest\x12\x0b\n\x03tag\x18\x01 \x01(\t\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x0f\n\x07\x63ontent\x18\x03 \x01(\x0c\x12\x10\n\x08src_rank\x18\x04 \x01(\x05"-\n\x1aOmniStoreBroadcastResponse\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\x0c"\x1b\n\x19OmniStoreGetStatusRequest",\n\x1aOmniStoreGetStatusResponse\x12\x0e\n\x06status\x18\x01 \x01(\x0c\x32\xe7\x01\n\x16OmniStoreReportService\x12=\n\x06Gather\x12\x17.OmniStoreGatherRequest\x1a\x18.OmniStoreGatherResponse"\x00\x12\x46\n\tBroadcast\x12\x1a.OmniStoreBroadcastRequest\x1a\x1b.OmniStoreBroadcastResponse"\x00\x12\x46\n\tGetStatus\x12\x1a.OmniStoreGetStatusRequest\x1a\x1b.OmniStoreGetStatusResponse"\x00\x62\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "omnistore.utilities.server.report_service_pb2", _globals) +if _descriptor._USE_C_DESCRIPTORS is False: + DESCRIPTOR._options = None + _globals["_OMNISTOREGATHERREQUEST"]._serialized_start = 51 + _globals["_OMNISTOREGATHERREQUEST"]._serialized_end = 140 + _globals["_OMNISTOREGATHERRESPONSE"]._serialized_start = 142 + _globals["_OMNISTOREGATHERRESPONSE"]._serialized_end = 185 + _globals["_OMNISTOREBROADCASTREQUEST"]._serialized_start = 187 + _globals["_OMNISTOREBROADCASTREQUEST"]._serialized_end = 276 + _globals["_OMNISTOREBROADCASTRESPONSE"]._serialized_start = 278 + _globals["_OMNISTOREBROADCASTRESPONSE"]._serialized_end = 323 + _globals["_OMNISTOREGETSTATUSREQUEST"]._serialized_start = 325 + _globals["_OMNISTOREGETSTATUSREQUEST"]._serialized_end = 352 + _globals["_OMNISTOREGETSTATUSRESPONSE"]._serialized_start = 354 + _globals["_OMNISTOREGETSTATUSRESPONSE"]._serialized_end = 398 + _globals["_OMNISTOREREPORTSERVICE"]._serialized_start = 401 + _globals["_OMNISTOREREPORTSERVICE"]._serialized_end = 632 +# @@protoc_insertion_point(module_scope) diff --git a/python/vescale/checkpoint/utilities/server/report_service_pb2.pyi b/python/vescale/checkpoint/utilities/server/report_service_pb2.pyi new file mode 100644 index 0000000..a031c11 --- /dev/null +++ b/python/vescale/checkpoint/utilities/server/report_service_pb2.pyi @@ -0,0 +1,64 @@ +from google.protobuf.internal import containers as _containers +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Iterable as _Iterable, Optional as _Optional + +DESCRIPTOR: _descriptor.FileDescriptor + +class OmniStoreGatherRequest(_message.Message): + __slots__ = ("tag", "rank", "content", "with_result") + TAG_FIELD_NUMBER: _ClassVar[int] + RANK_FIELD_NUMBER: _ClassVar[int] + CONTENT_FIELD_NUMBER: _ClassVar[int] + WITH_RESULT_FIELD_NUMBER: _ClassVar[int] + tag: str + rank: int + content: bytes + with_result: bool + def __init__( + self, + tag: _Optional[str] = ..., + rank: _Optional[int] = ..., + content: _Optional[bytes] = ..., + with_result: bool = ..., + ) -> None: ... + +class OmniStoreGatherResponse(_message.Message): + __slots__ = ("contents",) + CONTENTS_FIELD_NUMBER: _ClassVar[int] + contents: _containers.RepeatedScalarFieldContainer[bytes] + def __init__(self, contents: _Optional[_Iterable[bytes]] = ...) -> None: ... + +class OmniStoreBroadcastRequest(_message.Message): + __slots__ = ("tag", "rank", "content", "src_rank") + TAG_FIELD_NUMBER: _ClassVar[int] + RANK_FIELD_NUMBER: _ClassVar[int] + CONTENT_FIELD_NUMBER: _ClassVar[int] + SRC_RANK_FIELD_NUMBER: _ClassVar[int] + tag: str + rank: int + content: bytes + src_rank: int + def __init__( + self, + tag: _Optional[str] = ..., + rank: _Optional[int] = ..., + content: _Optional[bytes] = ..., + src_rank: _Optional[int] = ..., + ) -> None: ... + +class OmniStoreBroadcastResponse(_message.Message): + __slots__ = ("content",) + CONTENT_FIELD_NUMBER: _ClassVar[int] + content: bytes + def __init__(self, content: _Optional[bytes] = ...) -> None: ... + +class OmniStoreGetStatusRequest(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class OmniStoreGetStatusResponse(_message.Message): + __slots__ = ("status",) + STATUS_FIELD_NUMBER: _ClassVar[int] + status: bytes + def __init__(self, status: _Optional[bytes] = ...) -> None: ... diff --git a/python/vescale/checkpoint/utilities/server/report_service_pb2_grpc.py b/python/vescale/checkpoint/utilities/server/report_service_pb2_grpc.py new file mode 100644 index 0000000..85f55c4 --- /dev/null +++ b/python/vescale/checkpoint/utilities/server/report_service_pb2_grpc.py @@ -0,0 +1,184 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" + +import grpc + +from . import report_service_pb2 as OmniStore_dot_utilities_dot_server_dot_report__service__pb2 + + +class OmniStoreReportServiceStub: + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Gather = channel.unary_unary( + "/OmniStoreReportService/Gather", + request_serializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGatherRequest.SerializeToString, + response_deserializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGatherResponse.FromString, + ) + self.Broadcast = channel.unary_unary( + "/OmniStoreReportService/Broadcast", + request_serializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreBroadcastRequest.SerializeToString, + response_deserializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreBroadcastResponse.FromString, + ) + self.GetStatus = channel.unary_unary( + "/OmniStoreReportService/GetStatus", + request_serializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGetStatusRequest.SerializeToString, + response_deserializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGetStatusResponse.FromString, + ) + + +class OmniStoreReportServiceServicer: + """Missing associated documentation comment in .proto file.""" + + def Gather(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def Broadcast(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def GetStatus(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_OmniStoreReportServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + "Gather": grpc.unary_unary_rpc_method_handler( + servicer.Gather, + request_deserializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGatherRequest.FromString, + response_serializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGatherResponse.SerializeToString, + ), + "Broadcast": grpc.unary_unary_rpc_method_handler( + servicer.Broadcast, + request_deserializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreBroadcastRequest.FromString, + response_serializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreBroadcastResponse.SerializeToString, + ), + "GetStatus": grpc.unary_unary_rpc_method_handler( + servicer.GetStatus, + request_deserializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGetStatusRequest.FromString, + response_serializer=OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGetStatusResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler("OmniStoreReportService", rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + +# This class is part of an EXPERIMENTAL API. +class OmniStoreReportService: + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Gather( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/OmniStoreReportService/Gather", + OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGatherRequest.SerializeToString, + OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGatherResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def Broadcast( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/OmniStoreReportService/Broadcast", + OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreBroadcastRequest.SerializeToString, + OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreBroadcastResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def GetStatus( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/OmniStoreReportService/GetStatus", + OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGetStatusRequest.SerializeToString, + OmniStore_dot_utilities_dot_server_dot_report__service__pb2.OmniStoreGetStatusResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/python/vescale/checkpoint/utilities/server/server_lib.py b/python/vescale/checkpoint/utilities/server/server_lib.py new file mode 100644 index 0000000..ee6830b --- /dev/null +++ b/python/vescale/checkpoint/utilities/server/server_lib.py @@ -0,0 +1,229 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +import grpc +import asyncio +import threading +import dataclasses +import socket +import ipaddress +from typing import DefaultDict, Dict +import pickle +import multiprocessing +import time +import zlib + +from . import report_service_pb2 +from . import report_service_pb2_grpc + + +@dataclasses.dataclass +class Item: + cv: asyncio.Condition = dataclasses.field(default_factory=asyncio.Condition) + contents: dict = dataclasses.field(default_factory=dict) + ranks: set = dataclasses.field(default_factory=set) + + +_GRPC_OPTIONS = [ + ("grpc.max_send_message_length", 1024 * 1024 * 1024), + ("grpc.max_receive_message_length", 1024 * 1024 * 1024), + ("grpc.enable_http_proxy", 0), +] + + +class ReportServicer(report_service_pb2_grpc.OmniStoreReportServiceServicer): + """A servicer that simulate `gather` in sync training. + Using asyncio since we will block all incoming requests + until we gather all. + Usage: + GatherRank: servicer.wait_for(tag) + OtherRanks: stub.ReportAndWait(Request(tag=tag)) + """ + + def __init__(self, world_size: int): + self._l = asyncio.Lock() + self._world_size = world_size + + self._gather_dict = DefaultDict(Item) + self._bc_dict = DefaultDict(Item) + + async def Gather(self, req: report_service_pb2.OmniStoreGatherRequest, ctx: grpc.aio.ServicerContext): + i = await self._record(self._gather_dict, req, ctx) + resp = report_service_pb2.OmniStoreGatherResponse() + if req.with_result: + resp.contents.extend([v for k, v in sorted(i.contents.items(), key=lambda x: x[0])]) + + return resp + + async def Broadcast(self, req: report_service_pb2.OmniStoreBroadcastRequest, ctx: grpc.aio.ServicerContext): + i = await self._record(self._bc_dict, req, ctx) + return report_service_pb2.OmniStoreBroadcastResponse(content=i.contents[req.src_rank]) + + async def _record(self, d: Dict[str, Item], req, ctx: grpc.aio.ServicerContext): + async with self._l: + i = d[req.tag] + async with i.cv: + if req.rank in i.ranks: + ctx.abort( + grpc.StatusCode.INTERNAL, + f"Using the same tag in multiple threads/processes. tag: {req.tag}", + ) + i.ranks.add(req.rank) + if req.content: + i.contents[req.rank] = req.content + if len(i.ranks) == self._world_size: + async with self._l: + del d[req.tag] + i.cv.notify_all() + await i.cv.wait_for(lambda: len(i.ranks) == self._world_size) + return i + + async def GetStatus(self, req: report_service_pb2.OmniStoreGetStatusRequest, ctx: grpc.aio.ServicerContext): + async with self._l: + b = pickle.dumps( + { + "world_size": self._world_size, + "gather_dict": self._gather_dict, + "bc_dict": self._bc_dict, + } + ) + return report_service_pb2.OmniStoreGetStatusResponse(status=b) + + +def _is_ipv6_address(ip: str): + try: + ip_obj = ipaddress.ip_address(ip) + except ValueError: + return False + return ip_obj.version == 6 + + +def _concat_ip_and_port(ip: str, port: int): + if not _is_ipv6_address(ip): + return f"{ip}:{port}" + else: + return f"[{ip}]:{port}" + + +def _get_local_ip(): + try: + return socket.getaddrinfo(socket.gethostname(), None)[0][4][0] + except socket.gaierror: + return socket.getaddrinfo(socket.gethostname(), None, family=socket.AF_INET6)[0][4][0] + + +@dataclasses.dataclass +class _AsyncObj: + e: threading.Event = dataclasses.field(default_factory=threading.Event) + obj: object = None + + +async def async_serve(servicer, async_addr: _AsyncObj): + server: grpc.Server = grpc.aio.server(options=_GRPC_OPTIONS) + report_service_pb2_grpc.add_OmniStoreReportServiceServicer_to_server(servicer, server) + port = server.add_insecure_port("[::]:0") + await server.start() + async_addr.obj = _concat_ip_and_port(_get_local_ip(), port) + async_addr.e.set() + await server.wait_for_termination() + + +def serve(servicer) -> str: + async_addr = _AsyncObj() + th = threading.Thread( + target=lambda servicer=servicer, async_addr=async_addr: asyncio.run(async_serve(servicer, async_addr)), + daemon=True, + ) + th.start() + async_addr.e.wait() + return async_addr.obj + + +def _serve_in_loop(world_size, conn): + servicer = ReportServicer(world_size) + addr = serve(servicer) + conn.send(addr) + conn.close() + while True: + time.sleep(1) + + +def start_server_in_new_process(world_size: int): + parent_conn, child_conn = multiprocessing.Pipe() + p = multiprocessing.get_context("spawn").Process(target=_serve_in_loop, args=(world_size, child_conn), daemon=True) + p.start() + return parent_conn.recv() + + +def get_stub(addr: str): + channel = grpc.insecure_channel(addr, options=_GRPC_OPTIONS) + return report_service_pb2_grpc.OmniStoreReportServiceStub(channel) + + +def _get_tag(): + return "_default_tag" + + +def gather( + stub: report_service_pb2_grpc.OmniStoreReportServiceStub, + gather_rank: int, + rank: int, + obj, + tag: str = None, + timeout=None, +): + tag = tag or _get_tag() + req = report_service_pb2.OmniStoreGatherRequest( + tag=tag, rank=rank, content=pickle.dumps(obj), with_result=(gather_rank == rank) + ) + resp = stub.Gather(req, timeout=timeout) + if gather_rank != rank: + return + return [pickle.loads(content) for content in resp.contents] + + +def broadcast( + stub: report_service_pb2_grpc.OmniStoreReportServiceStub, + src_rank: int, + rank: int, + obj=None, + tag: str = None, + timeout=None, +): + tag = tag or _get_tag() + content = b"" if rank != src_rank else pickle.dumps(obj) + # Since we will transfer this to all machines, compression here is important. + c_content = zlib.compress(content) + resp = stub.Broadcast( + report_service_pb2.OmniStoreBroadcastRequest(tag=tag, rank=rank, content=c_content, src_rank=src_rank), + timeout=timeout, + ) + content = zlib.decompress(resp.content) + return pickle.loads(content) + + +def barrier( + stub: report_service_pb2_grpc.OmniStoreReportServiceStub, + rank: int, + tag: str = None, + timeout=None, +): + gather(stub, 0, rank, tag=tag, obj=None, timeout=timeout) + + +def get_server_status(stub: report_service_pb2_grpc.OmniStoreReportServiceStub): + resp = stub.GetStatus(report_service_pb2.OmniStoreGetStatusRequest()) + return pickle.loads(resp.status) diff --git a/python/vescale/checkpoint/utilities/server/server_status_client.py b/python/vescale/checkpoint/utilities/server/server_status_client.py new file mode 100644 index 0000000..bd2d6d5 --- /dev/null +++ b/python/vescale/checkpoint/utilities/server/server_status_client.py @@ -0,0 +1,26 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +import argparse + +from . import server_lib + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--addr") + args = parser.parse_args() + stub = server_lib.get_stub(args.addr) + print(server_lib.get_server_status(stub)) diff --git a/python/vescale/checkpoint/utilities/sync_queue.py b/python/vescale/checkpoint/utilities/sync_queue.py new file mode 100644 index 0000000..1f1746a --- /dev/null +++ b/python/vescale/checkpoint/utilities/sync_queue.py @@ -0,0 +1,48 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +import threading + + +# SynchronizedQueue is used for communications between training process and checkpoint uploading thread +class SynchronizedQueue: + def __init__(self): + self._task_done = True + self._item = None + self._cond = threading.Condition() + + def put(self, item) -> None: + with self._cond: + self._cond.wait_for(lambda: self._task_done) + self._task_done = False + self._item = item + self._cond.notify_all() + + def get(self): + with self._cond: + self._cond.wait_for(lambda: self._item is not None) + item = self._item + self._item = None + return item + + def task_done(self): + with self._cond: + self._task_done = True + self._cond.notify_all() + + def join(self, timeout=None) -> bool: + with self._cond: + return self._cond.wait_for(lambda: self._task_done, timeout=timeout) diff --git a/python/vescale/checkpoint/version.py b/python/vescale/checkpoint/version.py new file mode 100644 index 0000000..f7d2536 --- /dev/null +++ b/python/vescale/checkpoint/version.py @@ -0,0 +1,17 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +__version__ = "0.1.5" diff --git a/scripts/run_test.sh b/scripts/run_test.sh index 8fd7ec1..78a0430 100755 --- a/scripts/run_test.sh +++ b/scripts/run_test.sh @@ -21,7 +21,7 @@ do pkill -9 python3 || true # ok if nothing to kill pytest -s "${file}" pkill -9 python3 || true -done < <(find . -name 'test_*.py' -print0) +done < <(find . -name 'test_*.py' -not -name 'test_open_llama_*.py' -print0) # return popd diff --git a/test/checkpoint/__init__.py b/test/checkpoint/__init__.py new file mode 100644 index 0000000..1f4b03d --- /dev/null +++ b/test/checkpoint/__init__.py @@ -0,0 +1 @@ +# This file makes the directory a Python package for relative path import diff --git a/test/checkpoint/common_func.py b/test/checkpoint/common_func.py new file mode 100644 index 0000000..9479fdf --- /dev/null +++ b/test/checkpoint/common_func.py @@ -0,0 +1,309 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +# Define functions which are commonly used for nano_gpt checkpointing test +import torch +import math + +from vescale.dtensor.placement_types import Replicate, Shard +from vescale.dtensor.device_mesh import init_device_mesh +from vescale.dmodule.api import parallelize_module +from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP +from vescale.optim.distributed_optimizer import DistributedOptimizer +from transformers import AutoModelForCausalLM +from .nano_gpt import GPT, GPTConfig +import os + + +def flatten_dict(d, parent_key="", sep="_"): + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +fwd_plan = { + "transformer.wte.input": [[Replicate()]], + "transformer.wte.output": [[Replicate()]], + "transformer.wpe.input": [[Replicate()]], + "transformer.wpe.output": [[Replicate()]], + r"transformer.h.\d+.input": [[Shard(1)]], + r"transformer.h.\d+.attn.input": [[Replicate()]], + r"transformer.h.\d+.attn.c_proj.output": [[Replicate()]], + r"transformer.h.\d+.attn.output": [[Shard(1)]], + r"transformer.h.\d+.mlp.c_fc.input": [[Replicate()]], + r"transformer.h.\d+.mlp.c_proj.output": [[Replicate()]], + r"transformer.h.\d+.mlp.output": [[Shard(1)]], + "transformer.ln_f.input": [[Shard(1)]], + "lm_head.input": [[Shard(2)]], + "lm_head.output": [[Replicate()]], +} + +params_plan = { + "transformer.wte.weight": [Shard(1)], + "transformer.wpe.weight": [Shard(1)], + r"transformer.h.\d+.attn.q_proj.weight": [Shard(0)], + r"transformer.h.\d+.attn.q_proj.bias": [Shard(0)], + r"transformer.h.\d+.attn.k_proj.weight": [Shard(0)], + r"transformer.h.\d+.attn.k_proj.bias": [Shard(0)], + r"transformer.h.\d+.attn.v_proj.weight": [Shard(0)], + r"transformer.h.\d+.attn.v_proj.bias": [Shard(0)], + r"transformer.h.\d+.attn.c_proj.weight": [Shard(1)], + r"transformer.h.\d+.attn.c_proj.bias": [Replicate()], + r"transformer.h.\d+.mlp.c_fc.weight": [Shard(0)], + r"transformer.h.\d+.mlp.c_fc.bias": [Shard(0)], + r"transformer.h.\d+.mlp.c_proj.weight": [Shard(1)], + r"transformer.h.\d+.mlp.c_proj.bias": [Replicate()], + "lm_head.weight": [Shard(1)], +} + +nanoGPT_plan = {"parameter": params_plan, "forward": fwd_plan} + + +def build_gpt_model_optimizer_and_dataset(init_method, dp_size=1, tp_size=1): + # ----------------------------------------------------------------------------- + num_iters = 1 + # data + batch_size = 4 + block_size = 8 + vocab_size = 32 + # model + n_layer = 2 + n_head = 4 + n_embd = 16 + dropout = 0.1 # for pretraining 0 is good, for finetuning try 0.1+ + bias = True # do we use bias inside LayerNorm and Linear layers? + # system + torch.use_deterministic_algorithms(True) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + torch.manual_seed(999) + torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul + # ----------------------------------------------------------------------------- + # fake data loader + data_set = [] + for _ in range(num_iters): + idx = torch.randint(0, vocab_size, (batch_size, block_size), dtype=torch.int64).cuda() + target = torch.randint(0, vocab_size, (batch_size, block_size), dtype=torch.int64).cuda() + data_set.append((idx, target)) + + # model config + model_args = dict( + block_size=block_size, + vocab_size=vocab_size, + n_layer=n_layer, + n_head=n_head, + n_embd=n_embd, + dropout=dropout, + bias=bias, + ) + # DP=2 TP=2 + gptconf = GPTConfig(**model_args) + if init_method == "scratch": + gpt = GPT(gptconf).bfloat16() + else: + gpt = GPT.from_pretrained(init_method, dict(dropout=0.0)).bfloat16() + + device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("DP", "TP")) + device_mesh.__enter__() + + # Enable tensor Parallel + tp_gpt = parallelize_module(gpt, device_mesh["TP"], nanoGPT_plan) + + # Enable data Parallel + ddp_gpt = DDP( + tp_gpt, + data_pg_or_device_mesh=device_mesh["DP"], + accumulate_allreduce_grads_in_fp32=True, + overlap_grad_reduce=False, + use_distributed_optimizer=True, + ) + + # Build distributed optimizer + dist_optimizer = DistributedOptimizer( + torch.optim.Adam(ddp_gpt.parameters(), lr=0.01), + clip_grad=0.0, + overlap_param_gather=False, + models=[ddp_gpt], + ) + return ddp_gpt, dist_optimizer, data_set + + +def merge_optimizer_states(states): + merged_kvs = {} + # Use length directly instead of world size + # Because we may merge it on CPU + memory using one process + state_length = len(states) + for s_dict in states: + s_dict[torch.float32] = flatten_dict(s_dict[torch.float32]) + + for s_dict in states: + for k, v in s_dict[torch.float32].items(): + if "step" not in k: + cross_dp = False + for rank in range(state_length): + if k in states[rank][torch.float32] and states[rank][torch.float32][k].dp_ranks_ranges: + cross_dp = True + break + + if not cross_dp: + assert v.dp_ranks_ranges is None + if k not in merged_kvs: + merged_kvs[k] = torch.zeros(v.global_shape, dtype=v.local_tensor.dtype) + + if len(v.global_shape) == 1: + merged_kvs[k][v.global_offset[0] : v.global_offset[0] + v.local_shape[0],] = ( + v.local_tensor.view(v.local_shape) + ) + elif len(v.global_shape) == 2: + merged_kvs[k][ + v.global_offset[0] : v.global_offset[0] + v.local_shape[0], + v.global_offset[1] : v.global_offset[1] + v.local_shape[1], + ] = v.local_tensor.view(v.local_shape) + else: + if k not in merged_kvs: + # Two stage merging: + # Stage 1: merge tensors with different dp and same tp + + # Create tp sharded tensors + # Key: global offset + # Value: tensor after tp sharding + tp_offset_shape = {} + tp_sharded_tensors = {} + for rank in range(state_length): + if k in states[rank][torch.float32]: + state_on_dp = states[rank][torch.float32][k] + range_1d = state_on_dp.dp_ranks_ranges[rank] + + if state_on_dp.global_offset not in tp_sharded_tensors: + tp_sharded_tensors[state_on_dp.global_offset] = torch.zeros( + (math.prod(state_on_dp.local_shape),), dtype=state_on_dp.local_tensor.dtype + ) + tp_offset_shape[state_on_dp.global_offset] = state_on_dp.local_shape + + tp_sharded_tensors[state_on_dp.global_offset][range_1d.start : range_1d.end] = ( + state_on_dp.local_tensor + ) + + # Stage 2: merge tensors with different tp + merged_kvs[k] = torch.zeros(v.global_shape, dtype=v.local_tensor.dtype) + + for offset, tensor in tp_sharded_tensors.items(): + shape = tp_offset_shape[offset] + if len(v.global_shape) == 1: + merged_kvs[k][offset[0] : offset[0] + shape[0]] = tensor.view(shape) + elif len(v.global_shape) == 2: + merged_kvs[k][offset[0] : offset[0] + shape[0], offset[1] : offset[1] + shape[1]] = ( + tensor.view(shape) + ) + + return merged_kvs + + +def get_open_llama_model(layer_number=None): + if layer_number is None: + model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_7b") + else: + model = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_7b", num_hidden_layers=layer_number) + docoder = model.model + return docoder, model.config + + +# forward resharding plan for a single open llama decoder +_decoder_fwd_resharding_plan = { + "input": {"hidden_states": [Shard(1)], "attention_mask": [Replicate()], "position_ids": [Replicate()]}, + # atten + "self_attn.input": {"hidden_states": [Replicate()], "attention_mask": [Replicate()], "position_ids": [Replicate()]}, + "self_attn.o_proj.output": [[Shard(1)]], + "self_attn.output": [[Shard(1)], None, None], + # feedforward(mlp) + "mlp.input": [[Replicate()]], + "mlp.output": [[Shard(1)]], + "output": [[Shard(1)], None], +} + +# parameter sharding plan for a single open llama decoder +_decoder_param_sharding_plan = { + # atten weight, no bias + "self_attn.q_proj.weight": [Shard(0)], + "self_attn.k_proj.weight": [Shard(0)], + "self_attn.v_proj.weight": [Shard(0)], + "self_attn.o_proj.weight": [Shard(1)], + # feedforward(mlp) + "mlp.up_proj.weight": [Shard(0)], + "mlp.gate_proj.weight": [Shard(0)], + "mlp.down_proj.weight": [Shard(1)], +} + +# forward resharding plan for the whole open llama model +model_fwd_resharding_plan = { + ".input": [[Replicate()]], + "embed_tokens.output": [[Shard(1)]], + "norm.input": [[Shard(1)]], + ".output": { + "last_hidden_state": [Replicate()], + }, + **{rf"layers.\d+.{k}": v for k, v in _decoder_fwd_resharding_plan.items()}, +} + +# model parameter sharding plan for the whole open llama model +model_param_sharding_plan = { + "embed_tokens.weight": [Shard(1)], + **{rf"layers.\d+.{k}": v for k, v in _decoder_param_sharding_plan.items()}, +} + +sharding_plan = {"parameter": model_param_sharding_plan, "forward": model_fwd_resharding_plan} + + +def get_open_llama_model_optimizer(dp_size, tp_size, layer_number=None): + device_mesh = init_device_mesh( + "cuda", + ( + dp_size, + tp_size, + ), + mesh_dim_names=("DP", "TP"), + ) + device_mesh.__enter__() + # Set 4 layers to avoid timeout on CI + # Use 32 layers when running on training platform + vescale_decoder, config = get_open_llama_model(layer_number=layer_number) + + vescale_decoder = parallelize_module( + vescale_decoder, + device_mesh["TP"], + sharding_plan, + ) + + ddp_decoder = DDP( + vescale_decoder, + data_pg_or_device_mesh=device_mesh["DP"], + accumulate_allreduce_grads_in_fp32=True, + overlap_grad_reduce=False, + use_distributed_optimizer=True, + ) + + ve_optimizer = DistributedOptimizer( + torch.optim.Adam(ddp_decoder.parameters(), lr=0.01), + clip_grad=0.0, + overlap_param_gather=False, + models=[ddp_decoder], + ) + return ddp_decoder, ve_optimizer, config diff --git a/test/checkpoint/nano_gpt.py b/test/checkpoint/nano_gpt.py new file mode 100644 index 0000000..bbe8cb9 --- /dev/null +++ b/test/checkpoint/nano_gpt.py @@ -0,0 +1,422 @@ +################################################################################ +# MIT License +# +# Copyright (c) 2022 Andrej Karpathy +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. +################################################################################ + +"""Source: https://github.com/karpathy/nanoGPT/blob/master/model.py commit: f08abb4""" + +import math +import inspect +from dataclasses import dataclass + +import torch +import torch.nn as nn +from torch.nn import functional as F + + +class LayerNorm(nn.Module): + """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" + + def __init__(self, ndim, bias): + super().__init__() + self.weight = nn.Parameter(torch.ones(ndim)) + self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None + + def forward(self, input): + return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) + + +class CausalSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # + + + key, query, value projections in separation below + + + + # self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + self.k_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + self.v_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # + + + key, query, value projections in separation above + + + + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 + self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") + if not self.flash: + print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "bias", + torch.tril(torch.ones(config.block_size, config.block_size)).view( + 1, 1, config.block_size, config.block_size + ), + ) + + def forward(self, x): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # + + + calculate query, key, values in separation below + + + + # q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x) + # + + + calculate query, key, values in separation above + + + + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + if self.flash: + # efficient attention using Flash Attention CUDA kernels + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True + ) + else: + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + self.gelu = nn.GELU() + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, x): + x = self.c_fc(x) + x = self.gelu(x) + x = self.c_proj(x) + x = self.dropout(x) + return x + + +class Block(nn.Module): + def __init__(self, config): + super().__init__() + self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) + self.attn = CausalSelfAttention(config) + self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) + self.mlp = MLP(config) + + def forward(self, x): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +@dataclass +class GPTConfig: + block_size: int = 1024 + vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + dropout: float = 0.0 + bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + + +class GPT(nn.Module): + def __init__(self, config): + super().__init__() + assert config.vocab_size is not None + assert config.block_size is not None + self.config = config + + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.vocab_size, config.n_embd), + wpe=nn.Embedding(config.block_size, config.n_embd), + drop=nn.Dropout(config.dropout), + h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]), + ln_f=LayerNorm(config.n_embd, bias=config.bias), + ) + ) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + # with weight tying when using torch.compile() some warnings get generated: + # "UserWarning: functional_call was passed multiple values for tied weights. + # This behavior is deprecated and will be an error in future versions" + # not 100% sure what this is, so far seems to be harmless. TODO investigate + self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying + + # init all weights + self.apply(self._init_weights) + # apply special scaled init to the residual projections, per GPT-2 paper + for pn, p in self.named_parameters(): + if pn.endswith("c_proj.weight"): + torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) + + # report number of parameters + print("number of parameters: %.2fM" % (self.get_num_params() / 1e6)) # noqa: UP031 + + def get_num_params(self, non_embedding=True): + """ + Return the number of parameters in the model. + For non-embedding count (default), the position embeddings get subtracted. + The token embeddings would too, except due to the parameter sharing these + params are actually used as weights in the final layer, so we include them. + """ + n_params = sum(p.numel() for p in self.parameters()) + if non_embedding: + n_params -= self.transformer.wpe.weight.numel() + return n_params + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def forward(self, idx, targets=None): + device = idx.device + b, t = idx.size() + assert ( + t <= self.config.block_size + ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) + + # forward the GPT model itself + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) + x = self.transformer.drop(tok_emb + pos_emb) # (b, t, n_embd) + for block in self.transformer.h: + x = block(x) # (b, t, n_embd) + x = self.transformer.ln_f(x) # (b, t, n_embd) + + if targets is not None: + # if we are given some desired targets also calculate the loss + logits = self.lm_head(x) + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) + else: + # inference-time mini-optimization: only forward the lm_head on the very last position + logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim + loss = None + + return logits, loss + + def crop_block_size(self, block_size): + # model surgery to decrease the block size if necessary + # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) + # but want to use a smaller block size for some smaller, simpler model + assert block_size <= self.config.block_size + self.config.block_size = block_size + self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) + for block in self.transformer.h: + if hasattr(block.attn, "bias"): + block.attn.bias = block.attn.bias[:, :, :block_size, :block_size] + + @classmethod + def from_pretrained(cls, model_type, override_args=None): + assert model_type in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"} + override_args = override_args or {} # default to empty dict + # only dropout can be overridden see more notes below + assert all(k == "dropout" for k in override_args) + from transformers import GPT2LMHeadModel + + print("loading weights from pretrained gpt: %s" % model_type) + + # n_layer, n_head and n_embd are determined from model_type + config_args = { + "gpt2": dict(n_layer=12, n_head=12, n_embd=768), # 124M params + "gpt2-medium": dict(n_layer=24, n_head=16, n_embd=1024), # 350M params + "gpt2-large": dict(n_layer=36, n_head=20, n_embd=1280), # 774M params + "gpt2-xl": dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params + }[model_type] + print("forcing vocab_size=50257, block_size=1024, bias=True") + config_args["vocab_size"] = 50257 # always 50257 for GPT model checkpoints + config_args["block_size"] = 1024 # always 1024 for GPT model checkpoints + config_args["bias"] = True # always True for GPT model checkpoints + # we can override the dropout rate, if desired + if "dropout" in override_args: + print(f"overriding dropout rate to {override_args['dropout']}") + config_args["dropout"] = override_args["dropout"] + # create a from-scratch initialized minGPT model + config = GPTConfig(**config_args) + model = GPT(config) + sd = model.state_dict() + sd_keys = sd.keys() + # We ignore the qkv length validation + # because veScale use Q, K, and Vmartices + # But original nanoGPT use one matrix to represent them + # We shuold process it manually + sd_keys_for_length = [ + k + for k in sd_keys + if not k.endswith(".q_proj.weight") + and not k.endswith(".k_proj.weight") + and not k.endswith(".v_proj.weight") + and not k.endswith(".q_proj.bias") + and not k.endswith(".k_proj.bias") + and not k.endswith(".v_proj.bias") + ] # discard this mask / buffer, not a param + + # init a huggingface/transformers model + model_hf = GPT2LMHeadModel.from_pretrained(model_type) + sd_hf = model_hf.state_dict() + + # copy while ensuring all of the parameters are aligned and match in names and shapes + sd_keys_hf = sd_hf.keys() + sd_keys_hf = [k for k in sd_keys_hf if not k.endswith(".attn.masked_bias")] # ignore these, just a buffer + sd_keys_hf = [k for k in sd_keys_hf if not k.endswith(".attn.bias")] # same, just the mask (buffer) + # We ignore the qkv length validation + # because veScale use Q, K, and V martices + # But original nanoGPT use one matrix to represent them + # We shuold process it manually + sd_keys_hf_for_length = [ + k for k in sd_keys_hf if not k.endswith(".attn.c_attn.weight") and not k.endswith(".attn.c_attn.bias") + ] + transposed = ["attn.c_attn.weight", "attn.c_proj.weight", "mlp.c_fc.weight", "mlp.c_proj.weight"] + + # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear + # this means that we have to transpose these weights when we import them + assert len(sd_keys_hf_for_length) == len( + sd_keys_for_length + ), f"mismatched keys: {len(sd_keys_hf_for_length)} != {len(sd_keys_for_length)}" + + for k in sd_keys_hf: + if any(k.endswith(w) for w in transposed): + if k.endswith("attn.c_attn.weight"): + # Original nanoGPT QKV weight shape [3d, d] + # Huggingface GPT-2 QKV weight shape [d, 3d] + # So the original logic in "else" below simply + # use transpose to set nanoGPT model atten weight + # We should also do transpose then + # add one more step to get splited q, k, v + # [3d, d] -> [3d, d] + tranposed_atten_weight = sd_hf[k].t() + # Get embedding dimension + n_embd = tranposed_atten_weight.shape[1] + + # Q [0:d, ] in [3d, d] + sd[k.split("c_attn.weight")[0] + "q_proj.weight"].copy_(tranposed_atten_weight[:n_embd, :]) + # K [d:2d, ] in [3d, d] + sd[k.split("c_attn.weight")[0] + "k_proj.weight"].copy_( + tranposed_atten_weight[n_embd : 2 * n_embd, :] + ) + # V [2d:3d, ] in [3d, d] + sd[k.split("c_attn.weight")[0] + "k_proj.weight"].copy_(tranposed_atten_weight[2 * n_embd :, :]) + else: + # special treatment for the Conv1D weights we need to transpose + assert sd_hf[k].shape[::-1] == sd[k].shape + with torch.no_grad(): + sd[k].copy_(sd_hf[k].t()) + else: + if k.endswith("attn.c_attn.bias"): + n_embd = sd_hf[k].shape[0] // 3 + # Q [0:d, ] in [3d] + sd[k.split("c_attn.bias")[0] + "q_proj.bias"].copy_(sd_hf[k][:n_embd]) + # K [d:2d] in [3d] + sd[k.split("c_attn.bias")[0] + "k_proj.bias"].copy_(sd_hf[k][n_embd : 2 * n_embd]) + # V [2d:3d, ] in [3d] + sd[k.split("c_attn.bias")[0] + "k_proj.bias"].copy_(sd_hf[k][2 * n_embd :]) + else: + # vanilla copy over the other parameters + assert sd_hf[k].shape == sd[k].shape + with torch.no_grad(): + sd[k].copy_(sd_hf[k]) + + return model + + def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): + # start with all of the candidate parameters + param_dict = {pn: p for pn, p in self.named_parameters()} # noqa: C416 + # filter out those that do not require grad + param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} + # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. + # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. + decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] + nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] + optim_groups = [ + {"params": decay_params, "weight_decay": weight_decay}, + {"params": nodecay_params, "weight_decay": 0.0}, + ] + num_decay_params = sum(p.numel() for p in decay_params) + num_nodecay_params = sum(p.numel() for p in nodecay_params) + print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") + print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") + # Create AdamW optimizer and use the fused version if it is available + fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters + use_fused = fused_available and device_type == "cuda" + extra_args = dict(fused=True) if use_fused else dict() + optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) + print(f"using fused AdamW: {use_fused}") + + return optimizer + + def estimate_mfu(self, fwdbwd_per_iter, dt): + """estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS""" + # first estimate the number of flops we do per iteration. + # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 + N = self.get_num_params() + cfg = self.config + L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd // cfg.n_head, cfg.block_size + flops_per_token = 6 * N + 12 * L * H * Q * T + flops_per_fwdbwd = flops_per_token * T + flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter + # express our flops throughput as ratio of A100 bfloat16 peak flops + flops_achieved = flops_per_iter * (1.0 / dt) # per second + flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS + mfu = flops_achieved / flops_promised + return mfu + + @torch.no_grad() + def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): + """ + Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete + the sequence max_new_tokens times, feeding the predictions back into the model each time. + Most likely you'll want to make sure to be in model.eval() mode of operation for this. + """ + for _ in range(max_new_tokens): + # if the sequence context is growing too long we must crop it at block_size + idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size :] + # forward the model to get the logits for the index in the sequence + logits, _ = self(idx_cond) + # pluck the logits at the final step and scale by desired temperature + logits = logits[:, -1, :] / temperature + # optionally crop the logits to only the top k options + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + logits[logits < v[:, [-1]]] = -float("Inf") + # apply softmax to convert logits to (normalized) probabilities + probs = F.softmax(logits, dim=-1) + # sample from the distribution + idx_next = torch.multinomial(probs, num_samples=1) + # append sampled index to the running sequence and continue + idx = torch.cat((idx, idx_next), dim=1) + + return idx diff --git a/test/checkpoint/nano_gpt/test_nano_gpt_load_save.py b/test/checkpoint/nano_gpt/test_nano_gpt_load_save.py new file mode 100644 index 0000000..70880de --- /dev/null +++ b/test/checkpoint/nano_gpt/test_nano_gpt_load_save.py @@ -0,0 +1,130 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +import torch +import torch.distributed as dist +from common_dtensor import DTensorTestBase, with_comms, skip_unless_torch_gpu +from torch.testing._internal.common_utils import run_tests +from vescale.dtensor.device_mesh import mesh_resources +import vescale +from vescale.dtensor.placement_types import Replicate + + +from checkpoint.common_func import build_gpt_model_optimizer_and_dataset, flatten_dict + +TMP_CKPT_DIR = "open_source_gpt_load_save_checkpoint_dir" + + +class TestNanoGPT1(DTensorTestBase): + @property + def world_size(self): + return 4 + + @property + def init_method(self): + # If the value is "scratch", the GPT is trained from scratch + # It the value is "gpt2", "gpt2-medium", "gpt2-large", or "gpt2-xl" + # the GPT loads pretrained weights from OpenAI GPT2 repository on Huggingface + return "scratch" + + @skip_unless_torch_gpu + @with_comms + def test_save(self): + ddp_gpt, dist_optimizer, data_set = build_gpt_model_optimizer_and_dataset( + self.init_method, dp_size=2, tp_size=2 + ) + device_mesh = mesh_resources.get_current_mesh() + # Do fwd+bwd+step on the first data + for X, Y in data_set[:1]: + input = vescale.distribute_tensor(X, device_mesh["TP"], [Replicate()]) + output = vescale.distribute_tensor(Y, device_mesh["TP"], [Replicate()]) + dist_optimizer.zero_grad() + _, output = ddp_gpt(input, output) + loss = output.mean() + loss.backward() + ddp_gpt.finish_grad_sync() + dist_optimizer.step() + + # Save the model and optimizer before second data foward + + # OmniStore Style API + ckpt_state = {"model": ddp_gpt, "optimizer": dist_optimizer} + vescale.checkpoint.save(TMP_CKPT_DIR, ckpt_state) + + # Dump model state_dict + dumped_model_sd = {} + for k, v in ddp_gpt.state_dict().items(): + dumped_model_sd[k] = v._local_tensor + torch.save(dumped_model_sd, f"gpt_load_save_model_{dist.get_rank()}.pt") + # Dump optimizer state_dict + optimizer_state = dist_optimizer.state_dict() + + dumped_optimizer_sd = {} + for k, v in flatten_dict(optimizer_state[torch.float32]).items(): + if "step" not in k: + dumped_optimizer_sd[k] = v.local_tensor + + torch.save(dumped_optimizer_sd, f"gpt_load_save_optimizer_{dist.get_rank()}.pt") + + +class TestNanoGPT2(DTensorTestBase): + @property + def world_size(self): + return 4 + + @property + def init_method(self): + # If the value is "scratch", the GPT is trained from scratch + # It the value is "gpt2", "gpt2-medium", "gpt2-large", or "gpt2-xl" + # the GPT loads pretrained weights from OpenAI GPT2 repository on Huggingface + return "scratch" + + @skip_unless_torch_gpu + @with_comms + def test_load(self): + ddp_gpt, dist_optimizer, _ = build_gpt_model_optimizer_and_dataset( + self.init_method, dp_size=2, tp_size=2 + ) + + # Load the model and optimizer after first data + + # OmniStore Style API + # One line function, model and optimizer will be loaded automatically + ckpt_state = {"model": ddp_gpt, "optimizer": dist_optimizer} + vescale.checkpoint.load(TMP_CKPT_DIR, ckpt_state) + + # Load model state dict and verify it + dumped_model_sd = torch.load(f"gpt_load_save_model_{dist.get_rank()}.pt") + + current_model_sd = ddp_gpt.state_dict() + for k, v in current_model_sd.items(): + if not torch.allclose(dumped_model_sd[k], v._local_tensor): + print(f"k={k} truth={dumped_model_sd[k]} tensor in current model={v}") + raise AssertionError() + + # Load optimizer state dict and verfify + dumped_optimizer_sd = torch.load(f"gpt_load_save_optimizer_{dist.get_rank()}.pt") + + current_optimizer_sd = dist_optimizer.state_dict() + for k, v in flatten_dict(current_optimizer_sd[torch.float32]).items(): + if "step" not in k: + if not torch.allclose(dumped_optimizer_sd[k], v.local_tensor): + print(f"k={k} truth={dumped_optimizer_sd[k]} tensor in optim={v.local_tensor}") + raise AssertionError() + + +if __name__ == "__main__": + run_tests() diff --git a/test/checkpoint/open_llama/__init__.py b/test/checkpoint/open_llama/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/checkpoint/open_llama/test_open_llama_dp_reshard.py b/test/checkpoint/open_llama/test_open_llama_dp_reshard.py new file mode 100644 index 0000000..9ac7fb3 --- /dev/null +++ b/test/checkpoint/open_llama/test_open_llama_dp_reshard.py @@ -0,0 +1,129 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +import torch +import torch.distributed as dist +from torch.testing._internal.common_utils import run_tests + +from common_dtensor import DTensorTestBase, with_comms +from vescale.dtensor.device_mesh import mesh_resources +import vescale + +from ..common_func import merge_optimizer_states, get_open_llama_model_optimizer + +TMP_CKPT_DIR = "./open_llama_dp_reshard_checkpoint_dir" +NUM_OF_LAYERS = 4 # Limit number of transformer layers to avoid OOM in unit tests + + +class OpenLLaMa2Test1(DTensorTestBase): + @property + def world_size(self): + return 4 + + @with_comms + def test_open_llama2_with_ddp(self): + bsz = 6 + s = 18 + ddp_decoder, ve_optimizer, config = get_open_llama_model_optimizer( + dp_size=2, tp_size=2, layer_number=NUM_OF_LAYERS + ) + input = torch.randint(low=0, high=config.vocab_size, size=(bsz, s)).cuda() + # d_input = distribute_tensor(input.detach(), tp_sub_mesh, [Shard(1)]) + + ve_optimizer.zero_grad() + vescale_output = ddp_decoder(input.detach()).last_hidden_state + # vescale_output = vescale_output.redistribute(placements = [Replicate()]* tp_sub_mesh.ndim) + vescale_loss = vescale_output.mean() + vescale_loss.backward() + ddp_decoder.finish_grad_sync() + ve_optimizer.step() + + ckpt_state = {"model": ddp_decoder, "optimizer": ve_optimizer} + vescale.checkpoint.save(TMP_CKPT_DIR, ckpt_state) + device_mesh = mesh_resources.get_current_mesh() + dp_device_mesh = device_mesh["DP"] + dp_process_group = dp_device_mesh.get_dim_groups(0) + tp_device_mesh = device_mesh["TP"] + tp_process_group = tp_device_mesh.get_dim_groups(0) + # For processes with dp_rank = 0, dump model state_dict + if dist.get_rank(dp_process_group) == 0: + dumped_model_sd = {} + for k, v in ddp_decoder.state_dict().items(): + dumped_model_sd[k] = v._local_tensor + torch.save(dumped_model_sd, f"open_llama_dp_reshard_model_tp_{dist.get_rank(tp_process_group)}.pt") + + # Save merged optimizer state dict + optimizer_state = ve_optimizer.state_dict() + states = [{} for _ in range(dist.get_world_size())] + dist.all_gather_object(states, optimizer_state) + + # Merge optimizer state dictionary + if dist.get_rank() == 0: + merged_kvs = merge_optimizer_states(states) + + torch.save(merged_kvs, "open_llama_dp_reshard_merged_optim_state_dict.pt") + dist.barrier() + + +class OpenLLaMa2Test2(DTensorTestBase): + @property + def world_size(self): + return 8 + + @with_comms + def test_open_llama2_with_ddp(self): + ddp_decoder, ve_optimizer, _ = get_open_llama_model_optimizer(dp_size=4, tp_size=2, layer_number=NUM_OF_LAYERS) + + ckpt_state = {"model": ddp_decoder, "optimizer": ve_optimizer} + vescale.checkpoint.load(TMP_CKPT_DIR, ckpt_state) + device_mesh = mesh_resources.get_current_mesh() + tp_device_mesh = device_mesh["TP"] + tp_process_group = tp_device_mesh.get_dim_groups(0) + # Load model state dict and verify it + dumped_model_sd = torch.load( + f"open_llama_dp_reshard_model_tp_{dist.get_rank(tp_process_group)}.pt", map_location="cpu" + ) + + current_model_sd = ddp_decoder.state_dict() + for k, v in current_model_sd.items(): + if not torch.allclose(dumped_model_sd[k], v._local_tensor.cpu()): + print(f"k={k} truth={dumped_model_sd[k]} tensor in current model={v}") + raise AssertionError() + + # Merge optimizer state dict and verify it + current_optim_sd = ve_optimizer.state_dict() + states = [{} for _ in range(dist.get_world_size())] + dist.all_gather_object(states, current_optim_sd) + + if dist.get_rank() == 0: + dumped_optim_sd = torch.load("open_llama_dp_reshard_merged_optim_state_dict.pt") + dumped_optim_keys = dumped_optim_sd.keys() + current_merged_kvs = merge_optimizer_states(states) + for k, v in current_merged_kvs.items(): + if k not in dumped_optim_keys: + print(f"key={k} in current_merged_kvs is not in dumped_optim_keys") + raise AssertionError() + + if not torch.allclose(dumped_optim_sd[k].cpu(), v.cpu()): + print(f"k={k} truth={dumped_model_sd[k]} tensor in current optim={v}") + raise AssertionError() + + dist.barrier() + + +if __name__ == "__main__": + run_tests() diff --git a/test/checkpoint/open_llama/test_open_llama_load_save.py b/test/checkpoint/open_llama/test_open_llama_load_save.py new file mode 100644 index 0000000..72bb870 --- /dev/null +++ b/test/checkpoint/open_llama/test_open_llama_load_save.py @@ -0,0 +1,107 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +import torch +import torch.distributed as dist +from torch.testing._internal.common_utils import run_tests + +from common_dtensor import DTensorTestBase, with_comms + +import vescale + +from ..common_func import flatten_dict, get_open_llama_model_optimizer + +TMP_CKPT_DIR = "./open_llama_load_save_checkpoint_dir" +NUM_OF_LAYERS = 4 # Limit number of transformer layers to avoid OOM in unit tests + + +class OpenLLaMa2Test1(DTensorTestBase): + @property + def world_size(self): + return 8 + + @with_comms + def test_open_llama2_with_ddp(self): + bsz = 6 + s = 18 + ddp_decoder, ve_optimizer, config = get_open_llama_model_optimizer( + dp_size=2, tp_size=4, layer_number=NUM_OF_LAYERS + ) + input = torch.randint(low=0, high=config.vocab_size, size=(bsz, s)).cuda() + # d_input = distribute_tensor(input.detach(), tp_sub_mesh, [Shard(1)]) + + ve_optimizer.zero_grad() + vescale_output = ddp_decoder(input.detach()).last_hidden_state + # vescale_output = vescale_output.redistribute(placements = [Replicate()]* tp_sub_mesh.ndim) + vescale_loss = vescale_output.mean() + vescale_loss.backward() + ddp_decoder.finish_grad_sync() + ve_optimizer.step() + + ckpt_state = {"model": ddp_decoder, "optimizer": ve_optimizer} + vescale.checkpoint.save(TMP_CKPT_DIR, ckpt_state) + + # Dump model state_dict + dumped_model_sd = {} + for k, v in ddp_decoder.state_dict().items(): + dumped_model_sd[k] = v._local_tensor + torch.save(dumped_model_sd, f"open_llama_load_save_model_{dist.get_rank()}.pt") + # Dump optimizer state_dict + optimizer_state = ve_optimizer.state_dict() + + dumped_optimizer_sd = {} + for k, v in flatten_dict(optimizer_state[torch.float32]).items(): + if "step" not in k: + dumped_optimizer_sd[k] = v.local_tensor + + torch.save(dumped_optimizer_sd, f"open_llama_load_save_optimizer_{dist.get_rank()}.pt") + + +class OpenLLaMa2Test2(DTensorTestBase): + @property + def world_size(self): + return 8 + + @with_comms + def test_open_llama2_with_ddp(self): + ddp_decoder, ve_optimizer, _ = get_open_llama_model_optimizer(dp_size=2, tp_size=4, layer_number=NUM_OF_LAYERS) + + ckpt_state = {"model": ddp_decoder, "optimizer": ve_optimizer} + vescale.checkpoint.load(TMP_CKPT_DIR, ckpt_state) + + # Load model state dict and verify it + dumped_model_sd = torch.load(f"open_llama_load_save_model_{dist.get_rank()}.pt") + + current_model_sd = ddp_decoder.state_dict() + for k, v in current_model_sd.items(): + if not torch.allclose(dumped_model_sd[k], v._local_tensor): + print(f"k={k} truth={dumped_model_sd[k]} tensor in current model={v}") + raise AssertionError() + + # Load optimizer state dict and verify + dumped_optimizer_sd = torch.load(f"open_llama_load_save_optimizer_{dist.get_rank()}.pt") + + current_optimizer_sd = ve_optimizer.state_dict() + for k, v in flatten_dict(current_optimizer_sd[torch.float32]).items(): + if "step" not in k: + if not torch.allclose(dumped_optimizer_sd[k], v.local_tensor): + print(f"k={k} truth={dumped_optimizer_sd[k]} tensor in optim={v.local_tensor}") + raise AssertionError() + + +if __name__ == "__main__": + run_tests() diff --git a/test/checkpoint/open_llama/test_open_llama_tp_reshard.py b/test/checkpoint/open_llama/test_open_llama_tp_reshard.py new file mode 100644 index 0000000..f617ce9 --- /dev/null +++ b/test/checkpoint/open_llama/test_open_llama_tp_reshard.py @@ -0,0 +1,129 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +import torch +import torch.distributed as dist +from torch.testing._internal.common_utils import run_tests + +from common_dtensor import DTensorTestBase, with_comms + + +import vescale + +from ..common_func import merge_optimizer_states, get_open_llama_model_optimizer + +TMP_CKPT_DIR = "./open_llama_tp_reshard_checkpoint_dir" +NUM_OF_LAYERS = 4 # Limit number of transformer layers to avoid OOM in unit tests + + +class OpenLLaMa2Test1(DTensorTestBase): + @property + def world_size(self): + return 4 + + @with_comms + def test_open_llama2_with_ddp(self): + bsz = 6 + s = 18 + ddp_decoder, ve_optimizer, config = get_open_llama_model_optimizer( + dp_size=2, tp_size=2, layer_number=NUM_OF_LAYERS + ) + input = torch.randint(low=0, high=config.vocab_size, size=(bsz, s)).cuda() + # d_input = distribute_tensor(input.detach(), tp_sub_mesh, [Shard(1)]) + + ve_optimizer.zero_grad() + vescale_output = ddp_decoder(input.detach()).last_hidden_state + # vescale_output = vescale_output.redistribute(placements = [Replicate()]* tp_sub_mesh.ndim) + vescale_loss = vescale_output.mean() + vescale_loss.backward() + ddp_decoder.finish_grad_sync() + ve_optimizer.step() + + ckpt_state = {"model": ddp_decoder, "optimizer": ve_optimizer} + vescale.checkpoint.save(TMP_CKPT_DIR, ckpt_state) + + # Merge model state dictionary and save it + # full_tensor contains gather operations + # so it must be called on all ranks + + dumped_model_sd = {} + model_state_dict = ddp_decoder.state_dict() + for k, v in model_state_dict.items(): + dumped_model_sd[k] = v.full_tensor() + if dist.get_rank() == 0: + torch.save(dumped_model_sd, "open_llama_tp_reshard_merged_model.pt") + dist.barrier() + + # Dump optimizer state_dict + optimizer_state = ve_optimizer.state_dict() + states = [{} for _ in range(dist.get_world_size())] + dist.all_gather_object(states, optimizer_state) + + # Merge optimizer state dictionary + if dist.get_rank() == 0: + merged_kvs = merge_optimizer_states(states) + + torch.save(merged_kvs, "open_llama_tp_reshard_merged_optim_state_dict.pt") + dist.barrier() + + +class OpenLLaMa2Test2(DTensorTestBase): + @property + def world_size(self): + return 8 + + @with_comms + def test_open_llama2_with_ddp(self): + ddp_decoder, ve_optimizer, _ = get_open_llama_model_optimizer(dp_size=2, tp_size=4, layer_number=NUM_OF_LAYERS) + + ckpt_state = {"model": ddp_decoder, "optimizer": ve_optimizer} + vescale.checkpoint.load(TMP_CKPT_DIR, ckpt_state) + # Load model state dict and verify it + dumped_model_sd = torch.load("open_llama_tp_reshard_merged_model.pt", map_location="cpu") + + current_model_sd = {} + model_state_dict = ddp_decoder.state_dict() + for k, v in model_state_dict.items(): + current_model_sd[k] = v.full_tensor() + for k, v in current_model_sd.items(): + if not torch.allclose(dumped_model_sd[k], v.cpu()): + print(f"k={k} truth={dumped_model_sd[k]} tensor in current model={v}") + raise AssertionError() + + # Merge optimizer state dict and verify it + current_optim_sd = ve_optimizer.state_dict() + states = [{} for _ in range(dist.get_world_size())] + dist.all_gather_object(states, current_optim_sd) + + if dist.get_rank() == 0: + dumped_optim_sd = torch.load("open_llama_tp_reshard_merged_optim_state_dict.pt") + dumped_optim_keys = dumped_optim_sd.keys() + current_merged_kvs = merge_optimizer_states(states) + for k, v in current_merged_kvs.items(): + if k not in dumped_optim_keys: + print(f"key={k} in current_merged_kvs is not in dumped_optim_keys") + raise AssertionError() + + if not torch.allclose(dumped_optim_sd[k].cpu(), v.cpu()): + print(f"k={k} truth={dumped_model_sd[k]} tensor in current optim={v}") + raise AssertionError() + + dist.barrier() + + +if __name__ == "__main__": + run_tests()