Skip to content

Commit

Permalink
Fix failling tests (#1301)
Browse files Browse the repository at this point in the history
* fix typo

* fix neoxargs usage test

* skip conversion test due to multiprocessing issue

* precommit

---------

Co-authored-by: Quentin Anthony <[email protected]>
  • Loading branch information
AI-WAIFU and Quentin-Anthony authored Oct 8, 2024
1 parent c8f7b56 commit 3272032
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 4 deletions.
23 changes: 23 additions & 0 deletions configs/neox_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,29 @@ Model Arguments
- **dim_att**: int
Default = None
Total dimension of the attention mechanism for RWKV. If not set, defaults to hidden_size.
- **head_size**: int
Default = None
Size of each attention head for RWKV. Calculated as dim_att // num_attention_heads.
- **ffn_dim**: int
Default = None
Dimension of the feed-forward network for RWKV. If not set, calculated based on hidden_size and expansion_factor.
## NeoXArgsOptimizer
Optimizer Arguments
Expand Down
19 changes: 17 additions & 2 deletions megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from template import NeoXArgsTemplate

try:
from typing import List, Literal, Union, Optional
from typing import List, Literal, Union, Optional, Any
except ImportError:
from typing_extensions import List, Literal, Union, Optional

Expand Down Expand Up @@ -502,6 +502,21 @@ class NeoXArgsModel(NeoXArgsTemplate):
Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column)
"""

dim_att: int = None
"""
Total dimension of the attention mechanism for RWKV. If not set, defaults to hidden_size.
"""

head_size: int = None
"""
Size of each attention head for RWKV. Calculated as dim_att // num_attention_heads.
"""

ffn_dim: int = None
"""
Dimension of the feed-forward network for RWKV. If not set, calculated based on hidden_size and expansion_factor.
"""


@dataclass
class NeoXArgsOptimizer(NeoXArgsTemplate):
Expand Down Expand Up @@ -673,7 +688,7 @@ class NeoXArgsLogging(NeoXArgsTemplate):
Custom metadata to attach to the created Comet Experiment.
"""

comet_experiment = None
comet_experiment: Any = None
"""
Initialized comet experiment object used to log data
"""
Expand Down
2 changes: 1 addition & 1 deletion megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def forward_step(
return model.eval_batch(data_iterator, return_logits=return_logits)

# Get the batch.
if neox_args.memory_profiling and neox_args.it:
if neox_args.memory_profiling and neox_args.iteration:
torch.cuda.nvtx.range_push(f"Get batch")
if timers is not None:
timers("batch generator").start()
Expand Down
4 changes: 3 additions & 1 deletion tests/neox_args/test_neoxargs_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def test_neoxargs_usage():

# find args matches
matches = list(
re.findall(r"(?<=args\.).{2,}?(?=[\s\n(){}+-/*;:,=,[,\]])", file_contents)
re.findall(
r"(?<=neox_args\.).{2,}?(?=[\s\n(){}+-/*;:,=,[,\]])", file_contents
)
)
if len(matches) == 0:
continue
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/test_format_conversion_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
from megatron.neox_arguments.neox_args import NeoXArgsTokenizer


@pytest.mark.skip(
reason="Conversion test is skipped until we fix the CUDA + torch multiprocessing issue."
)
def test_gpt_neox_to_huggingface(monkeypatch, tmpdir, tmp_path):
# Generate random GPT-NEOX model, check we can convert to hf format

model_dir = str(tmpdir)
input_args = ["train.py", "tests/config/test_setup.yml"]
deepspeed_main_args = simulate_deepy_env(monkeypatch, input_args)
Expand Down

0 comments on commit 3272032

Please sign in to comment.