Bump 3rdparty/NeMo from 06e6703 to 06a1491 #538

Open · wants to merge 4 commits into base: main
3rdparty/NeMo: 2 changes (1 addition, 1 deletion)
Submodule NeMo updated 372 files
@@ -21,7 +21,7 @@
 from megatron.core.transformer.transformer_config import TransformerConfig
 from nemo.collections.llm import fn
 from nemo.collections.llm.fn.mixin import FNMixin
-from nemo.collections.llm.peft.lora import AdapterParallelAdd, LoRA
+from nemo.collections.llm.peft.lora import LoRA, LoRALinear
 from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ParallelLinearAdapter
 from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
 from nemo.lightning.megatron_parallel import (
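
The net change in this hunk is an import rename: the updated NeMo revision exposes LoRALinear from nemo.collections.llm.peft.lora, where the previously pinned revision exposed AdapterParallelAdd. For downstream code that has to import from either side of this submodule bump, a guarded import is one way to bridge the two revisions. This is a minimal sketch, not part of the PR; it assumes only that the two names shown in the diff are the old and new spellings of the same wrapper class.

# Illustrative compatibility shim (not part of this PR): newer NeMo exposes
# LoRALinear, while the previously pinned revision exposed AdapterParallelAdd.
try:
    from nemo.collections.llm.peft.lora import LoRA, LoRALinear
except ImportError:  # NeMo pinned before this bump
    from nemo.collections.llm.peft.lora import LoRA, AdapterParallelAdd as LoRALinear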
@@ -271,9 +271,7 @@ def selective_freeze(self, m: nn.Module, name: str | None = None, prefix: str |
             FNMixin.freeze(m)
         return m
 
-    def transform(
-        self, m: nn.Module, name: str | None = None, prefix: str | None = None
-    ) -> nn.Module | AdapterParallelAdd:
+    def transform(self, m: nn.Module, name: str | None = None, prefix: str | None = None) -> nn.Module | LoRALinear:
         """Transforms the input model if the name is in the target modules."""
         tp_size = parallel_state.get_tensor_model_parallel_world_size()
         if name in self.target_modules:
@@ -317,5 +315,5 @@ def transform(
                 model_parallel_config=getattr(m, "config", None),
                 alpha=self.alpha,
             )
-            return AdapterParallelAdd(m, adapter)
+            return LoRALinear(m, adapter)
         return m
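
For reviewers unfamiliar with the renamed class: both AdapterParallelAdd and LoRALinear wrap a targeted linear module together with a ParallelLinearAdapter, so that the adapter's output is added to the frozen base layer's output. The snippet below is a minimal, framework-free sketch of that wrapping pattern in plain PyTorch; LoRALinearSketch and its parameters are illustrative stand-ins, not NeMo's actual implementation.

import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    """Illustrative stand-in for the wrapper returned by transform():
    the forward pass adds a scaled low-rank adapter output to the frozen base output."""

    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():  # base weights stay frozen
            p.requires_grad_(False)
        self.lora_in = nn.Linear(base.in_features, rank, bias=False)
        self.lora_out = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_out.weight)  # adapter starts as a no-op
        self.scale = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scale * self.lora_out(self.lora_in(x))


# Usage: wrap a targeted linear layer; only the adapter weights receive gradients.
layer = LoRALinearSketch(nn.Linear(1024, 1024))
out = layer(torch.randn(2, 1024))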