Commit

progress?
AI-WAIFU committed Oct 17, 2024
1 parent 9e60eec commit 1745ffb
Showing 3 changed files with 23 additions and 14 deletions.
tests/common.py: 15 changes (11 additions & 4 deletions)

@@ -16,6 +16,8 @@
 import time
 import shutil
 import itertools
+import inspect
+import subprocess
 from pathlib import Path
 from abc import ABC, abstractmethod
 from deepspeed.accelerator import get_accelerator
@@ -48,6 +50,14 @@
 DEEPSPEED_UNIT_WORKER_TIMEOUT = 120
 DEEPSPEED_TEST_TIMEOUT = 600
 
+def is_rocm_pytorch():
+    """
+    Check if the current PyTorch installation is using ROCm.
+    Returns:
+        bool: True if PyTorch is using ROCm, False otherwise.
+    """
+    return hasattr(torch.version, 'hip') and torch.version.hip is not None
+
 
 def get_xdist_worker_id():
     xdist_worker = os.environ.get("PYTEST_XDIST_WORKER", None)
@@ -67,7 +77,6 @@ def get_master_port():
 _num_gpus = None
 
-
 def set_accelerator_visible():
     cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
     xdist_worker_id = get_xdist_worker_id()
@@ -428,9 +437,7 @@ def test_2(self, val1, val2, val3, val4):
             assert int(os.environ["WORLD_SIZE"]) == 1
             assert all(val1, val2, val3, val4)
     """
-
-    def __init__(self):
-        self.is_dist_test = True
+    is_dist_test = True
 
     # Temporary directory that is shared among test methods in a class
     @pytest.fixture(autouse=True, scope="class")
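Aside on the new helper: is_rocm_pytorch() keys off torch.version.hip, which PyTorch sets only in ROCm/HIP builds, so it is cheap and safe to call at collection time. A minimal sketch of how a test might consume it; the skipif marker below is illustrative, not something this commit adds:

import pytest
import torch

def is_rocm_pytorch():
    # True when this PyTorch build targets ROCm/HIP rather than CUDA.
    return hasattr(torch.version, 'hip') and torch.version.hip is not None

@pytest.mark.skipif(is_rocm_pytorch(), reason="exercises a CUDA-only code path")
def test_cuda_only():
    # On CUDA builds torch.version.cuda is a version string; on ROCm builds it is None.
    assert torch.version.cuda is not None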
tests/conftest.py: 5 changes (5 additions & 0 deletions)

@@ -68,6 +68,11 @@ def check_environment(pytestconfig):
 @pytest.hookimpl(tryfirst=True)
 def pytest_runtest_call(item):
     # We want to use our own launching function for distributed tests
+    print("-------------------------------------------------------------------------")
+    print(type(item))
+    func_name = item.function.__name__ if hasattr(item, 'function') else None
+    print(f"Function name: {func_name}")
+    print("-------------------------------------------------------------------------")
     if getattr(item.cls, "is_dist_test", False):
         dist_test_class = item.cls()
         dist_test_class(item._request)
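The prints are temporary debug scaffolding; the load-bearing line is the getattr(item.cls, ...) check, and it is why the is_dist_test change in tests/common.py matters: pytest hands this hook the collected item, and item.cls is the test class itself, not an instance, so a flag assigned in __init__ was never visible here. Promoting it to a class attribute lets the hook route DistributedTest subclasses through the custom launcher. A condensed sketch of the interplay, assuming DistributedTest instances are callable launchers and that the hook then suppresses the normal test run (that last step is not visible in this diff):

# Sketch only, not part of the commit.
class DistributedTest:
    is_dist_test = True  # class attribute: readable from item.cls before any instance exists

    def __call__(self, request):
        # Spawn world_size worker processes and run the test method in each one.
        ...

def pytest_runtest_call(item):
    if getattr(item.cls, "is_dist_test", False):  # no instantiation required
        launcher = item.cls()
        launcher(item._request)
        item.runtest = lambda: True  # assumed: keeps pytest from running the test a second time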
tests/model/test_model_generation.py: 17 changes (7 additions & 10 deletions)

@@ -64,17 +64,11 @@
     PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None
 )
 
-
-@pytest.mark.parametrize("param_dict", parameters, ids=names)
-def test_train(param_dict):
-    t1 = run_generate_test_class()
-    t1.run_generate_test(param_dict, param_dict.pop("prompt"))
-
-
-class run_generate_test_class(DistributedTest):
+class TestModelGeneration(DistributedTest):
     world_size = 2
 
-    def run_generate_test(self, param_dict, prompt):
+    @pytest.mark.parametrize("param_dict", parameters, ids=names)
+    def test_generate(self, param_dict, tmpdir):
         from megatron.text_generation_utils import generate_samples_from_prompt
         from megatron.utils import is_mp_rank_0
 
@@ -89,10 +83,10 @@ def run_generate_test(self, param_dict, prompt):
         }
 
         param_dict.update(fixed_params)
-        # TODO: we don't need to reinstantiate the model every time if we're only changing sampling settings - should be a workaround for this
         model, _, _, args_loaded = model_setup(None, param_dict, clear_data=True)
         model.eval()
 
+        prompt = param_dict.pop("prompt")
         prompts = [prompt for _ in range(args_loaded.num_samples)]
         output = generate_samples_from_prompt(
             neox_args=args_loaded,
@@ -111,3 +105,6 @@ def run_generate_test(self, param_dict, prompt):
         for prompt, out in zip(prompts, output):
             assert prompt == out["context"]
             assert len(out["text"]) > 0
+
+        # Clean up
+        del model
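Net effect of this file's changes: the free-function wrapper (test_train plus the run_generate_test_class helper) collapses into a single DistributedTest subclass, @pytest.mark.parametrize moves onto the test method so each sampled param_dict is collected as its own distributed case, and the prompt is popped from param_dict inside the method instead of by the caller. A minimal sketch of the resulting pattern, mirroring the docstring example in tests/common.py (class, method, and parameter names below are illustrative; the import path is assumed from this repo's layout):

import os
import pytest
from tests.common import DistributedTest

class TestSketch(DistributedTest):
    world_size = 2  # the conftest hook launches each case across two ranks

    @pytest.mark.parametrize("val", [1, 2])
    def test_case(self, val):
        # Runs once per rank under the distributed launcher.
        assert int(os.environ["WORLD_SIZE"]) == 2
        assert val in (1, 2)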
