
Commit f54e0e4: fix bugs
AI-WAIFU committed Oct 14, 2024
1 parent: f7eee21

Showing 3 changed files with 7 additions and 3 deletions.
megatron/neox_arguments/arguments.py (5 changes: 4 additions & 1 deletion)

@@ -59,6 +59,7 @@
 OKAY = f"{GREEN}[OKAY]{END}"
 WARNING = f"{YELLOW}[WARNING]{END}"
 FAIL = f"{RED}[FAIL]{END}"
+ERROR = f"{RED}[ERROR]{END}"
 INFO = "[INFO]"

 # ZERO defaults by deespeed
@@ -875,16 +876,17 @@ def calculate_derived(self):
         """
         Derives additional configuration values necessary for training from the current config
         """

-        # number of gpus
+        # Get number of GPUs param or hostfile to determine train_batch_size
         global_num_gpus = getattr(self, "global_num_gpus", None)
         if global_num_gpus is None:
             if self.hostfile is not None or os.path.exists(DLTS_HOSTFILE):
                 hostfile_path = self.hostfile or DLTS_HOSTFILE
                 print(hostfile_path, self.include, self.exclude)
                 resources = obtain_resource_pool(
                     hostfile_path, self.include or "", self.exclude or ""
                 )
                 print(resources)
                 if self.num_nodes is not None and self.num_nodes > 0:
                     resources = {
                         k: resources[k]
@@ -896,6 +898,7 @@ def calculate_derived(self):
             else:
                 global_num_gpus = torch.cuda.device_count()
             self.update_value("global_num_gpus", global_num_gpus)
+

         logging.info(
             self.__class__.__name__
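
The first hunk adds an [ERROR] severity tag alongside the existing OKAY/WARNING/FAIL/INFO ones. For context, constants like RED and END are ANSI escape sequences, along these lines (a sketch; the exact escape values in NeoX are an assumption, since they sit above the hunk shown here):

RED = "\033[91m"  # assumed ANSI code for red
END = "\033[0m"   # reset formatting
ERROR = f"{RED}[ERROR]{END}"
print(f"{ERROR} invalid configuration")  # renders [ERROR] in red on ANSI terminals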
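
The other two hunks touch the GPU-count derivation that the new comment documents: if a hostfile is supplied (or the DeepSpeed default hostfile exists), the resource pool is read, optionally truncated to the first num_nodes hosts, and the GPU slots are summed; otherwise the count falls back to the devices visible locally. A minimal self-contained sketch of that logic (the function and argument names are illustrative, not the NeoX API):

from typing import Dict, List, Optional

import torch

def derive_global_num_gpus(
    resources: Optional[Dict[str, List[int]]],  # host -> GPU slots, as parsed from a hostfile
    num_nodes: Optional[int],
) -> int:
    if resources is not None:
        if num_nodes is not None and num_nodes > 0:
            # keep only the first num_nodes hosts, mirroring the dict comprehension in the hunk
            resources = {k: resources[k] for k in list(resources)[:num_nodes]}
        return sum(len(slots) for slots in resources.values())
    # no hostfile: assume a single node and count local CUDA devices
    return torch.cuda.device_count()

# e.g. derive_global_num_gpus({"node1": [0, 1], "node2": [0, 1]}, num_nodes=1) == 2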
tests/common.py (2 changes: 1 addition & 1 deletion)

@@ -476,7 +476,7 @@ def get_test_path(filename):
 def model_setup(yaml_list=None, param_dict=None, clear_data=True):
     from megatron.neox_arguments import NeoXArgs
     from megatron.mpu import destroy_model_parallel
-    from megatron import initialize_megatron
+    from megatron.initialize import initialize_megatron
     from megatron.training import setup_model_and_optimizer

     destroy_model_parallel()  # mpu model parallel contains remaining global vars
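
This one-line change fixes an ImportError rather than behavior: "from megatron import initialize_megatron" only succeeds if the package's __init__.py binds that name, while importing from the module that defines it always works. A toy reproduction under an assumed layout (pkg/__init__.py empty, pkg/initialize.py defining initialize_megatron; pkg is hypothetical, not the real megatron tree):

try:
    from pkg import initialize_megatron  # ImportError: __init__.py never re-exports the name
except ImportError:
    from pkg.initialize import initialize_megatron  # works: import from the defining module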
tests/model/test_model_generation.py (3 changes: 2 additions & 1 deletion)

@@ -25,6 +25,7 @@
 from tests.common import DistributedTest, model_setup, parametrize

 PARAMS_TO_TEST = {
+    "include":["localhost:0,1"],
     "pipe_parallel_size,model_parallel_size,world_size": [
         [0, 1, 1],
         [0, 1, 2],
@@ -73,7 +74,7 @@ def test_train(param_dict):
 class run_generate_test_class(DistributedTest):
     world_size = 2

-    def run_generate_test(param_dict, prompt):
+    def run_generate_test(self, param_dict, prompt):
         from megatron.text_generation_utils import generate_samples_from_prompt
         from megatron.utils import is_mp_rank_0
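
The "include" entry added in the first hunk pins these tests to two local GPUs: "localhost:0,1" uses the DeepSpeed-style host:slot syntax (host name, then device indices) and matches the world_size = 2 declared on the test class in the second hunk.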
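
The second hunk fixes a classic Python mistake: an instance method defined without self receives the instance as its first positional argument, so every call on an instance fails with a TypeError. A toy reproduction (class and method names are illustrative):

class Broken:
    def run(param_dict, prompt):  # missing self: param_dict silently receives the instance
        return prompt

class Fixed:
    def run(self, param_dict, prompt):
        return prompt

# Broken().run({}, "hi")  raises TypeError: takes 2 positional arguments but 3 were given
print(Fixed().run({}, "hi"))  # -> hi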
