Commit: consistent failure
AI-WAIFU committed Oct 17, 2024
1 parent 1745ffb commit afe2c40
Showing 6 changed files with 76 additions and 45 deletions.
54 changes: 54 additions & 0 deletions megatron/logging.py
@@ -13,6 +13,7 @@
# limitations under the License.

import sys
import os

import torch

@@ -26,6 +27,7 @@
import math


'''
class Tee:
"""Duplicate output to both stdout/err and file"""
@@ -61,6 +63,58 @@ def flush(self) -> None:
self.file.flush()
except OSError:
pass
'''

class Tee:
"""Duplicate output to both stdout/err and file"""

def __init__(self, file, err: bool = False) -> None:
self.err = err
self.std = sys.stderr if err else sys.stdout

if isinstance(file, str):
try:
# Ensure the directory exists if file is a path
os.makedirs(os.path.dirname(file), exist_ok=True)
self.file = open(file, "w")
except IOError as e:
print(f"Warning: Could not open file {file} for writing. {str(e)}", file=self.std)
self.file = None
elif hasattr(file, 'write') and hasattr(file, 'flush'):
# If it's a file-like object, use it directly
self.file = file
else:
raise ValueError("'file' must be either a file path or a file-like object")

if not err:
sys.stdout = self
else:
sys.stderr = self

def __del__(self) -> None:
if not self.err:
sys.stdout = self.std
else:
sys.stderr = self.std

if self.file and hasattr(self.file, 'close'):
self.file.close()

def write(self, data) -> None:
self.std.write(data)
if self.file:
try:
self.file.write(data)
except IOError as e:
print(f"Warning: Could not write to file. {str(e)}", file=self.std)

def flush(self) -> None:
self.std.flush()
if self.file:
try:
self.file.flush()
except IOError as e:
print(f"Warning: Could not flush file. {str(e)}", file=self.std)


def human_readable_flops(num) -> str:
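The new Tee class redirects sys.stdout or sys.stderr to itself on construction, so anything printed afterwards goes both to the console and to the given file. A minimal usage sketch, not part of this commit; the log paths are hypothetical, and they include a parent directory because the constructor calls os.makedirs on os.path.dirname(file):

    from megatron.logging import Tee

    # Mirror stdout and stderr to files while keeping normal console output.
    Tee("logs/stdout.txt", err=False)
    Tee("logs/stderr.txt", err=True)

    print("training started")  # shows on the console and is also written to logs/stdout.txt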
9 changes: 3 additions & 6 deletions tests/common.py
@@ -132,8 +132,6 @@ def set_accelerator_visible():
def count_gpus():
global _num_gpus
if _num_gpus is None:
import subprocess

nvidia_smi = subprocess.check_output(["nvidia-smi", "--list-gpus"])
_num_gpus = len(nvidia_smi.decode("utf-8").strip().split("\n"))
return _num_gpus
@@ -146,8 +144,6 @@ def set_cuda_visibile():
xdist_worker_id = 0
if cuda_visible is None:
# CUDA_VISIBLE_DEVICES is not set, discover it from nvidia-smi instead
import subprocess

nvidia_smi = subprocess.check_output(["nvidia-smi", "--list-gpus"])
num_gpus = len(nvidia_smi.decode("utf-8").strip().split("\n"))
cuda_visible = ",".join(map(str, range(num_gpus)))
@@ -516,10 +512,11 @@ def model_setup(yaml_list=None, param_dict=None, clear_data=True):
args_loaded.build_tokenizer()

initialize_megatron(neox_args=args_loaded)
model, optimizer, lr_scheduler = setup_model_and_optimizer(
print("YAP")
model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer(
neox_args=args_loaded, use_cache=True
)
return model, optimizer, lr_scheduler, args_loaded
return model, optimizer, lr_scheduler, reference_model, args_loaded


def simulate_deepy_env(monkeypatch, input_args):
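Because model_setup now also returns reference_model, call sites that unpacked four values need a fifth slot. A minimal sketch of an adapted caller that ignores the reference model, mirroring the change applied in tests/model/test_model_generation.py below:

    # reference_model is discarded with an underscore when the test does not need it.
    model, _, _, _, args_loaded = model_setup(None, param_dict, clear_data=True)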
5 changes: 0 additions & 5 deletions tests/conftest.py
@@ -68,11 +68,6 @@ def check_environment(pytestconfig):
@pytest.hookimpl(tryfirst=True)
def pytest_runtest_call(item):
# We want to use our own launching function for distributed tests
print("-------------------------------------------------------------------------")
print(type(item))
func_name = item.function.__name__ if hasattr(item, 'function') else None
print(f"Function name: {func_name}")
print("-------------------------------------------------------------------------")
if getattr(item.cls, "is_dist_test", False):
dist_test_class = item.cls()
dist_test_class(item._request)
49 changes: 17 additions & 32 deletions tests/model/test_model_checkpoint.py
@@ -33,7 +33,8 @@
import torch

PARAMS_TO_TEST = {
"pipe_parallel_size,model_parallel_size": [[0, 1], [1, 2], [0, 2], [2, 1]],
"include":["localhost:0,1"],
"pipe_parallel_size,model_parallel_size": [[1, 2], [0, 2], [2, 1]],
"checkpoint_validation_with_forward_pass": [True],
"fp16,fp32_allreduce": [
[
@@ -61,30 +62,22 @@
}

parameters, names = parametrize(
PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None
PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=42
)

class TestModelCheckpoint(DistributedTest):
world_size = 2

@pytest.mark.parametrize("param_dict", parameters, ids=names)
def test_train(param_dict):
import tempfile

d = tempfile.mkdtemp()
param_dict["save"] = d

t1 = test_run_checkpoint_test_class()
t1.run_checkpoint_test(param_dict=param_dict)


class test_run_checkpoint_test_class(DistributedTest):
def run_checkpoint_test(yaml_list=None, param_dict=None):

@pytest.mark.parametrize("param_dict", parameters, ids=names)
def test_checkpoint(self, param_dict, tmpdir):
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
print("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB")

model, optimizer, lr_scheduler, args_loaded = model_setup(
yaml_list, param_dict, clear_data=True
model, optimizer, lr_scheduler, reference_model, args_loaded = model_setup(
yaml_list=None, param_dict=param_dict, clear_data=True
)
print("CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC")

# save model checkpoint
save_checkpoint(
@@ -101,7 +94,7 @@ def run_checkpoint_test(yaml_list=None, param_dict=None):
reloaded_optimizer,
reloaded_lr_scheduler,
args_reloaded,
) = model_setup(yaml_list, param_dict, clear_data=False)
) = model_setup(yaml_list=None, param_dict=param_dict, clear_data=False)
iteration = load_checkpoint(
neox_args=args_reloaded,
model=reloaded_model,
@@ -110,9 +103,7 @@ def run_checkpoint_test(yaml_list=None, param_dict=None):
)

# ensure same checkpoint is loaded
assert (
iteration == 42
), "run_checkpoint_test() iteration loaded from checkpoint correct"
assert iteration == 42, "Iteration loaded from checkpoint is incorrect"

# check all weight groups are the same
for idx, ((n1, p1), (n2, p2)) in enumerate(
@@ -122,14 +113,8 @@ def run_checkpoint_test(yaml_list=None, param_dict=None):
)
):
assert n1 == n2
params_equal = (p1 == p2).all().item()
assert params_equal, "run_checkpoint_test() params equal: " + str(n1)

params_equal = torch.all(p1 == p2).item()
assert params_equal, f"Parameters not equal: {n1}"

if __name__ == "__main__":
params = list(
parametrize(
PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None
)
)
test_train(params[0])
# Clean up
del model, reloaded_model
2 changes: 1 addition & 1 deletion tests/model/test_model_generation.py
@@ -83,7 +83,7 @@ def test_generate(self, param_dict, tmpdir):
}

param_dict.update(fixed_params)
model, _, _, args_loaded = model_setup(None, param_dict, clear_data=True)
model, _, _, _, args_loaded = model_setup(None, param_dict, clear_data=True)
model.eval()

prompt = param_dict.pop("prompt")
2 changes: 1 addition & 1 deletion tests/model/test_model_instantiation.py
@@ -115,7 +115,7 @@ class test_instantiate_optimizers_class(DistributedTest):
def run_test_model_instantiation(yaml_list=None, param_dict=None):
from deepspeed.runtime.pipe.engine import PipelineEngine, DeepSpeedEngine

model, optimizer, lr_scheduler, args_loaded = model_setup(yaml_list, param_dict)
model, optimizer, lr_scheduler, reference_model, args_loaded = model_setup(yaml_list, param_dict)
if args_loaded.pipe_parallel_size < 2:
assert isinstance(
model, DeepSpeedEngine
