Merge branch 'microsoft:master' into features/fp6_compile_err
JamesTheZ authored Mar 29, 2024
2 parents 833dc58 + ffb53c2 commit fcc5653
Showing 18 changed files with 386 additions and 107 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/hpu-gaudi2.yml
@@ -7,6 +7,8 @@ on:
pull_request:
paths:
- ".github/workflows/hpu-gaudi2.yml"
- "accelerator/hpu_accelerator.py"


concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
1 change: 0 additions & 1 deletion deepspeed/__init__.py
@@ -42,7 +42,6 @@
from .comm.comm import init_distributed

from .runtime import zero
from .runtime import DeepSpeedOptimizer, ZeROOptimizer
from .runtime.compiler import is_compile_supported

from .pipe import PipelineModule
1 change: 1 addition & 0 deletions deepspeed/checkpoint/constants.py
@@ -16,6 +16,7 @@
BASE_OPTIMIZER_STATE = 'base_optimizer_state'
BASE_OPTIMIZER_STATE_STEP = 'base_optimizer_state_step'
SINGLE_PARTITION_OF_FP32_GROUPS = "single_partition_of_fp32_groups"
PARAM_GROUPS = 'param_groups'
GROUP_PADDINGS = 'group_paddings'
PARTITION_COUNT = 'partition_count'
ZERO_STAGE = 'zero_stage'
54 changes: 41 additions & 13 deletions deepspeed/checkpoint/ds_to_universal.py
@@ -22,6 +22,7 @@
OPTIMIZER_STATE_DICT,
BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS,
PARAM_GROUPS,
PARAM_SLICE_MAPPINGS,
PARAM_SHAPES,
PARAM,
@@ -110,6 +111,9 @@ def extract_zero_shards(dir, ds_checkpoint, indices_3D):
fp32=fp32_groups[param_group_id],
)

if "step" in state_groups[param_group_id]:
flat_state["step"] = state_groups[param_group_id]["step"]

for name, fragment_mapping in param_slice_mappings[param_group_id].items():
if pp_index > 0 and any(re.match(pattern, name) for pattern in pipeline_replicated_params):
# Skip tied weights that are replicated in first and last pp stages
@@ -138,17 +142,28 @@ def dump_param_fragment(dir, tp_index, dp_index, state_name, state_flat_tensor,

#print(f"{param_name}: {offset}: {numel} => {path}")

t = state_flat_tensor.narrow(0, offset, numel).clone()
_save_checkpoint(path, t)
# State might be a python int or a tensor
if state_name != "step" and torch.is_tensor(state_flat_tensor):
state_flat_tensor = state_flat_tensor.narrow(0, offset, numel).clone()
_save_checkpoint(path, state_flat_tensor)


def _merge_zero_shards(param_base_path, state, tp_degree, slice_shape):
slices = []
for tp_index in range(tp_degree):
prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
paths = sorted(list(glob.glob(f"{prefix_path}.*")))
if len(paths) == 0:
continue

shards = [torch.load(p) for p in paths]
slice = torch.cat(shards, dim=0).reshape(slice_shape)

if state == "step":
assert all(v == shards[0] for v in shards), "All shards must have the same step value"
slice = shards[0]
else:
slice = torch.cat(shards, dim=0).reshape(slice_shape)

slices.append(slice)
return slices

@@ -177,6 +192,10 @@ def get_matched_pattern(patterns_, name_):
return pattern_
return None

step_merged = _merge_zero_shards(slice_base_path, "step", tp_degree, shape)
if step_merged:
_save_checkpoint(os.path.join(param_base_path, f"step.pt"), step_merged[0])

for state in ("fp32", "exp_avg", "exp_avg_sq"):
slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
final_path = os.path.join(param_base_path, f"{state}.pt")
@@ -227,13 +246,21 @@ def _get_chunks(l, n):


def _do_parallel_work(do_work, work_chunks, num_workers):
pool = multiprocessing.Pool(num_workers)
results = []
for batch in tqdm.tqdm(work_chunks):
res = pool.map(do_work, batch)
results.extend(res)
pool.close()
pool.join()
if num_workers > 1:
pool = multiprocessing.Pool(num_workers)
results = []
for batch in tqdm.tqdm(work_chunks):
res = pool.map(do_work, batch)
results.extend(res)
pool.close()
pool.join()
else:
# No parallel pass for unit testing
# We can't create child processes in tests
results = []
for batch in tqdm.tqdm(work_chunks):
res = [do_work(x) for x in batch]
results.extend(res)
return results


@@ -273,6 +300,7 @@ def _save_optimizer_state(args, ds_checkpoint):

optim_sd = sd[OPTIMIZER_STATE_DICT]
output_sd = {k: v for k, v in optim_sd.items() if k not in sharded_states}
output_sd[PARAM_GROUPS] = optim_sd[BASE_OPTIMIZER_STATE][PARAM_GROUPS]
zero_output_folder = os.path.join(args.output_folder, "zero")
output_file_path = os.path.join(zero_output_folder, f"optimizer_state.pt")
_save_checkpoint(output_file_path, output_sd)
@@ -283,10 +311,9 @@ def _check_for_required_state(ds_checkpoint):
assert universal_checkpoint_info is not None, f'Required {UNIVERSAL_CHECKPOINT_INFO} state is missing in checkpoint. Verify that client creates this state.'


def main():
def main(args):
print(f'Convert DeepSpeed Checkpoint to Universal Checkpoint')

args = parse_arguments()
print(f'Converting DeepSpeed checkpoint in {args.input_folder} to Universal checkpoint in {args.output_folder}')

ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)
@@ -332,4 +359,5 @@ def main():


if __name__ == "__main__":
main()
args = parse_arguments()
main(args)
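
Note on the entry-point change above: parse_arguments() now runs under the __main__ guard and main() takes the parsed namespace, so the converter can also be driven programmatically (which pairs with the new single-worker branch in _do_parallel_work for tests that cannot spawn child processes). A minimal sketch of such a call, assuming the namespace supplies whatever fields the real parser defines; only input_folder and output_folder appear in the hunks above, and the remaining field names here are illustrative guesses:

from types import SimpleNamespace

from deepspeed.checkpoint.ds_to_universal import main

# Stand-in for parse_arguments(); field names other than input_folder and
# output_folder are assumptions made for this sketch.
args = SimpleNamespace(
    input_folder="/path/to/deepspeed_checkpoint",
    output_folder="/path/to/universal_checkpoint",
    num_extract_workers=1,   # assumed flag; 1 exercises the no-multiprocessing path
    num_merge_workers=1,     # assumed flag
    keep_temp_folder=False,  # assumed flag
)
main(args)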
8 changes: 8 additions & 0 deletions deepspeed/checkpoint/universal_checkpoint.py
@@ -22,9 +22,15 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
if match:
hp_keys.append(match.group(1))

step = None
for key in hp_keys:
ckpt_file = os.path.join(folder, f"{key}.pt")
ckpt_dict = torch.load(ckpt_file)

if key == "step":
step = ckpt_dict
continue

full_hp_param = ckpt_dict[PARAM]

# need to deal with slices that were averaged.
@@ -103,6 +109,8 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):

hp_mapping.optim_fragment[key] = tp_hp_fragment.clone().detach()

return step


def enable_universal_checkpoint(param_list):
for param in param_list:
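For context on the hunk above: the converter now writes a step.pt next to the per-parameter tensor states, and load_hp_checkpoint_state returns that step to the caller instead of mapping it into an optimizer fragment. A rough sketch of the same split over a converted parameter folder; the helper below is hypothetical and only illustrates the file layout implied by this commit:

import glob
import os

import torch

def peek_param_folder(folder):
    # Hypothetical helper: loads every <state>.pt in a universal-checkpoint
    # parameter folder, keeping the scalar "step" separate from tensor states
    # such as fp32, exp_avg and exp_avg_sq.
    step = None
    states = {}
    for path in glob.glob(os.path.join(folder, "*.pt")):
        key = os.path.splitext(os.path.basename(path))[0]
        obj = torch.load(path)
        if key == "step":
            step = obj  # plain scalar, not narrowed or reshaped
        else:
            states[key] = obj
    return step, states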
4 changes: 3 additions & 1 deletion deepspeed/checkpoint/zero_checkpoint.py
@@ -105,9 +105,11 @@ def _strip_tensor_paddings(self, sd):
if group_paddings[key] == 0:
continue
for state_name, state_value in group_state.items():
if torch.is_tensor(state_value):
if state_name != "step" and torch.is_tensor(state_value):
raw_length = state_value.numel() - group_paddings[key]
group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone()
else:
group_state[state_name] = state_value

def _clear_group_paddings(self, sd):
group_paddings = self._get_optimizer_state(sd, GROUP_PADDINGS)
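The _strip_tensor_paddings change above narrows only tensor states and passes the scalar "step" entry through untouched. A toy illustration of that rule, with made-up sizes and padding:

import torch

group_state = {"exp_avg": torch.zeros(10), "exp_avg_sq": torch.zeros(10), "step": 5}
padding = 2
for state_name, state_value in group_state.items():
    if state_name != "step" and torch.is_tensor(state_value):
        raw_length = state_value.numel() - padding
        group_state[state_name] = torch.narrow(state_value, 0, 0, raw_length).clone()
    else:
        group_state[state_name] = state_value  # "step" is kept as-is

assert group_state["exp_avg"].numel() == 8
assert group_state["step"] == 5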
26 changes: 25 additions & 1 deletion deepspeed/ops/transformer/inference/triton/matmul_ext.py
@@ -13,12 +13,36 @@
import deepspeed
from pathlib import Path
import atexit
import subprocess


# -----------------------------------------------------------------------------
# util class/functions for triton
def is_nfs_path(path):
# Normalize the path to get the absolute path
path = os.path.abspath(path)

# Use the 'df' command to find the file system type for the given path
try:
output = subprocess.check_output(['df', '-T', path], encoding='utf-8')
except subprocess.CalledProcessError:
return False # Command failed

# Process the output of 'df -T' to check for 'nfs' in the filesystem type column
lines = output.strip().split('\n')
if len(lines) > 1: # The first line is headers
fs_type = lines[1].split()[1].lower() # File system type is the second column
return 'nfs' in fs_type
return False


def _default_cache_dir():
return os.path.join(Path.home(), ".triton", "autotune")
tmp_path = os.path.join(Path.home(), ".triton", "autotune")
if is_nfs_path(tmp_path):
print(
f"Warning: The default cache directory for DeepSpeed Triton autotune, {tmp_path}, appears to be on an NFS system. While this is generally acceptable, if you experience slowdowns or hanging when DeepSpeed exits, it is recommended to set the TRITON_CACHE_DIR environment variable to a non-NFS path."
)
return tmp_path


def bias_add_activation(C, bias=None, activation=""):
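The NFS check above only prints a warning; the actual workaround it recommends is pointing the Triton autotune cache at non-NFS storage via TRITON_CACHE_DIR. A minimal sketch of that override, assuming it is applied before deepspeed (and hence this module) is imported; the path is illustrative:

import os

# Illustrative local directory; any non-NFS path works.
os.environ.setdefault("TRITON_CACHE_DIR", "/tmp/triton_autotune")

import deepspeed  # imported after the override so the cache directory takes effect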
8 changes: 0 additions & 8 deletions deepspeed/runtime/__init__.py
@@ -2,11 +2,3 @@
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team


class DeepSpeedOptimizer(object):
pass


class ZeROOptimizer(DeepSpeedOptimizer):
pass
63 changes: 63 additions & 0 deletions deepspeed/runtime/base_optimizer.py
@@ -0,0 +1,63 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import os
import torch

from deepspeed.utils import logger
from deepspeed.utils.tensor_fragment import map_to_flat_opt_states
from deepspeed.runtime.utils import bwc_tensor_model_parallel_rank


class DeepSpeedOptimizer(object):
pass


class ZeROOptimizer(DeepSpeedOptimizer):

def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, checkpoint_dir: str) -> None:
checkpoint_dir = os.path.join(checkpoint_dir, "zero")
optim_state_path = os.path.join(checkpoint_dir, "optimizer_state.pt")
assert os.path.isfile(
optim_state_path), f'{optim_state_path} containing optimizer global state is missing! Cannot proceed.'
optim_sd = torch.load(optim_state_path)

self._load_global_state(optim_sd)

tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
if self.mpu is None:
logger.warn("MPU is not provided, setting tp size to 1 in checkpoint loading.")
tp_world_size = 1
else:
tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
else self.mpu.get_tensor_model_parallel_world_size()

for i, (param_group,
loaded_param_group) in enumerate(zip(self.optimizer.param_groups, optim_sd['param_groups'])):
# We have an assumption that all params in the same param_group have the same keys
opt_keys = set()
steps = []

lp_groups = getattr(self, lp_groups_name)
for lp in lp_groups[i]:
if lp._hp_mapping is not None:
#print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
step = lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
tp_world_size)
for key in lp._hp_mapping.get_optim_state_keys():
opt_keys.add(key)
steps.append(step)

hp_param = param_group['params'][0]
assert all(step == steps[0] for step in steps), f"Steps {steps} are not equal"
if steps[0] is not None:
self.optimizer.state[hp_param]['step'] = steps[0]

map_to_flat_opt_states(hp_param, lp_groups[i], self.optimizer.state, opt_keys)

for key, value in loaded_param_group.items():
if key == 'params':
continue
param_group[key] = value
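
The new base class centralizes universal-checkpoint loading: a concrete optimizer passes the name of the attribute holding its low-precision parameter groups, as the bf16 optimizer does below with "bf16_groups". A hedged sketch of the contract a subclass is expected to meet; the class and attribute names here are illustrative, not part of the commit:

from deepspeed.runtime.base_optimizer import ZeROOptimizer

class MyShardedOptimizer(ZeROOptimizer):
    # Assumed to be initialized elsewhere: self.optimizer, self.mpu,
    # self.param_names, and self.my_lp_groups (low-precision params that
    # carry _hp_mapping links).

    def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
        self.load_hp_checkpoint_state_from_checkpoint_dir("my_lp_groups", checkpoint_folder)

    def _load_global_state(self, sd):
        # Nothing beyond the per-parameter state to restore in this sketch.
        pass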
29 changes: 7 additions & 22 deletions deepspeed/runtime/bf16_optimizer.py
@@ -6,19 +6,18 @@
from collections import OrderedDict
import torch
import sys
import os
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from deepspeed import comm as dist
from deepspeed.runtime.constants import PIPE_REPLICATED
from deepspeed.runtime import ZeROOptimizer
from deepspeed.runtime.base_optimizer import ZeROOptimizer
from packaging import version as pkg_version
from deepspeed.git_version_info import version
from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank,
is_model_parallel_parameter, see_memory_usage, graph_process,
get_norm_with_moe_layers)
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups
from deepspeed.moe.utils import is_moe_param, is_moe_param_group
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups, map_to_flat_opt_states
from deepspeed.checkpoint import enable_universal_checkpoint
from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE,
SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS,
@@ -493,6 +492,7 @@ def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True, l
self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad)

if load_optimizer_states:
print(f"_load_legacy_checkpoint current_rank_sd[BASE_OPTIMIZER_STATE]")
self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])

if load_from_fp32_weights:
@@ -505,31 +505,16 @@ def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, l
self._link_all_hp_params()

def _load_universal_checkpoint(self, checkpoint_folder, load_optimizer_states, load_from_fp32_weights):
self._load_hp_checkpoint_state(checkpoint_folder)
self.load_hp_checkpoint_state_from_checkpoint_dir("bf16_groups", checkpoint_folder)

def _load_global_state(self, sd):
pass

@property
def param_groups(self):
"""Forward the wrapped optimizer's parameters."""
return self.optimizer.param_groups

def _load_hp_checkpoint_state(self, checkpoint_dir):
checkpoint_dir = os.path.join(checkpoint_dir, "zero")
tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
tp_world_size = self.mpu.get_slice_parallel_world_size()

for i, param_group in enumerate(self.optimizer.param_groups):
# We have an assumption that all params in the same param_group have the same keys
opt_keys = set()

for lp in self.bf16_groups[i]:
if lp._hp_mapping is not None:
#print(f"Loading {self.param_names[lp]} {tp_rank=} {tp_world_size=}")
lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
tp_world_size)
for key in lp._hp_mapping.get_optim_state_keys():
opt_keys.add(key)
map_to_flat_opt_states(param_group['params'][0], self.bf16_groups[i], self.optimizer.state, opt_keys)

def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx):
assert self.immediate_grad_update
self._update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=False)
(diffs for the remaining 8 of the 18 changed files are not shown)
