Skip to content

Commit

Permalink
Implement neox_args processing when OMPI_COMM_WORLD_SIZE>1
Browse files Browse the repository at this point in the history
  • Loading branch information
kyuheejang committed Nov 7, 2023
1 parent 90aa131 commit 7445c3a
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions tools/ckpts/convert_hf_to_sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import argparse
import torch
import base64
import json

import numpy as np

Expand Down Expand Up @@ -395,9 +397,17 @@ def consume_neox_args2(args_parsed, overwrite_values=None):
We then instantiate a new NeoXArgs from the dictionary (`.from_dict`). This should ensure args are never inconsistent across machines.
"""
def decode_base64_and_load_json(base64_string):
try:
decoded_string = base64.b64decode(base64_string).decode('utf-8')

with open(args_parsed.megatron_config) as jsonfile:
megatron_config = json.load(jsonfile)
json_data = json.loads(decoded_string)
return json_data
except Exception as e:
print(f"An error occurred: {e}")
return None

megatron_config = decode_base64_and_load_json(args_parsed.megatron_config)
if args_parsed.deepspeed_config is not None:
overwrite_values = NeoXArgs.set_up_autotuning(
args_parsed.deepspeed_config, overwrite_values
Expand Down Expand Up @@ -496,7 +506,7 @@ def get_non_existing_dir(tmp_dir):
# time.sleep(5)

if int(os.environ.get("OMPI_COMM_WORLD_SIZE", 1)) > 1:
neox_args = consume_neox_args2(args2)
neox_args = consume_neox_args2(args)
else:
neox_args = NeoXArgs.from_ymls(args.config)
neox_args.configure_distributed_args()
Expand Down

0 comments on commit 7445c3a

Please sign in to comment.