Loss disagreement between TP=1 and TP=2 #631

Draft · wants to merge 4 commits into base: main
Conversation

@sichu2023 (Collaborator) commented Jan 22, 2025

Description

Loss values from unreduced_token_loss_fn disagree between TP=1 and TP=2 even though the logits are identical. Interestingly, this introduces a nearly constant offset in the training curve and does not affect grad_norm. The offset slowly shrinks over 100 training steps.

Suspected root cause:
vocab_parallel_cross_entropy takes the logits at the padded vocab dimensions into account when it should ignore them. As training progresses, more probability mass is assigned to the real vocab tokens, so the padded vocab slots slowly become less relevant, which would explain the closing gap between the TP=1 and TP=2 loss curves.

The logits dimension also expanded from 128 to 256 when replacing ColumnParallelLinear with torch.nn.Linear, although the reason is still unclear. If vocab_parallel_cross_entropy takes the padded vocab dimensions into account, this could explain the constant offset observed, as sketched below.
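
To make the suspected mechanism concrete, here is a minimal sketch in plain PyTorch (not the Megatron vocab_parallel_cross_entropy implementation; the vocab sizes 33 and 128 are taken from this thread) of how padding slots in the softmax inflate the loss:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
real_vocab, padded_vocab = 33, 128             # 33 real tokens, padded to 128
logits = torch.randn(8, padded_vocab)          # logits for 8 tokens over the padded vocab
targets = torch.randint(0, real_vocab, (8,))   # targets only ever hit real tokens

# Cross entropy over the full padded vocab: the padding slots inflate the
# softmax normalizer, adding an offset to the loss.
loss_padded = F.cross_entropy(logits, targets)

# Cross entropy restricted to the real vocab: the loss we actually want.
loss_real = F.cross_entropy(logits[:, :real_vocab], targets)

print((loss_padded - loss_real).item())  # positive; shrinks as real logits dominate

As training pushes the padded logits down relative to the real ones, the offset shrinks, matching the slowly closing gap described above.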

On a related note, unreduced_token_loss_fn performs an in-place operation on the token logits:
https://github.com/NVIDIA/bionemo-framework/blob/sichu/loss-curve-tp-shift/sub-packages/bionemo-llm/tests/bionemo/llm/model/test_loss.py#L158
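
One minimal way to catch this kind of mutation (a hypothetical snapshot-and-compare helper, not the assertion in the test linked above) is:

import torch

def assert_no_inplace_mutation(loss_fn, logits, targets):
    # Hypothetical helper: fail if loss_fn mutates its logits argument in place.
    snapshot = logits.detach().clone()
    loss_fn(logits, targets)
    assert torch.equal(logits.detach(), snapshot), "logits were mutated in place"

# e.g. assert_no_inplace_mutation(unreduced_token_loss_fn, logits, targets)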

See Slack channel here.

Type of changes

  • Bug fix (non-breaking change which fixes an issue)

Usage

# Download the ESM-2 sanity-check dataset.
DATADIR=$(download_bionemo_data esm2/testdata_esm2_pretrain:2.0 --source pbss)/2024_03_sanity

# The first positional argument sets the device count; TP size matches it.
NUM_DEVICES=${1}
TP=${NUM_DEVICES}

test_data_flags="""
    --train-cluster-path=${DATADIR}/train_clusters_sanity.parquet \
    --train-database-path=${DATADIR}/train_sanity.db \
    --valid-cluster-path=${DATADIR}/valid_clusters.parquet \
    --valid-database-path=${DATADIR}/validation.db
"""
hparam_flags_8m="""
    --num-layers=6 \
    --hidden-size=320 \
    --num-attention-heads=20 \
    --ffn-hidden-size=1280 \
    --micro-batch-size=4
"""
test_run_flags="""
    --num-nodes=1 \
    --num-gpus=${NUM_DEVICES} \
    --num-steps=100 \
    --limit-val-batches=2
"""

launcher="torchrun --nproc-per-node=${NUM_DEVICES}"
script="sub-packages/bionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py"
cmd="""
${launcher} $script \
    $test_data_flags \
    $hparam_flags_8m \
    $test_run_flags \
    --log-every-n-steps=5 \
    --val-check-interval=5 \
    --tensor-model-parallel-size=${TP}
"""
eval $cmd
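
Assuming the snippet above is saved as repro.sh (a hypothetical filename), the two runs to compare are:

bash repro.sh 1   # TP=1 baseline
bash repro.sh 2   # TP=2, where the loss offset appears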

Pre-submit Checklist

  • I have tested these changes locally
  • I have updated the documentation accordingly
  • I have added/updated tests as needed
  • All existing tests pass successfully

@sichu2023 sichu2023 added the bug Something isn't working label Jan 22, 2025
@sichu2023 sichu2023 self-assigned this Jan 22, 2025
@sichu2023 sichu2023 marked this pull request as draft January 22, 2025 02:44
@sichu2023 (Collaborator, Author) commented:

Loss curves when initializing torch.nn.Linear with torch.nn.init.zeros_ via output_layer_init_method:
[W&B loss-curve charts, captured 1/21/2025 6:40 PM]
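
For reference, that initialization looks roughly like the following sketch (whether output_layer_init_method also zeroes the bias is an assumption here):

import torch

hidden_size, vocab_size = 320, 128  # hidden size from the 8M config; padded vocab
output_layer = torch.nn.Linear(hidden_size, vocab_size, bias=True)
torch.nn.init.zeros_(output_layer.weight)  # zero weights => uniform logits at step 0
torch.nn.init.zeros_(output_layer.bias)    # assumption: bias zeroed as well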

@codecov-commenter commented:
❌ 31 Tests Failed:

Tests completed | Failed | Passed | Skipped
            885 |     31 |    854 |      13
View the top 3 failed tests by shortest run time
../../../../usr/local/lib/python3.12/dist-packages/bionemo/testing/harnesses/stop_and_go.py::test_stop_and_go_consistency[ConsumedSamplesCallback]
Stack Traces | 0.001s run time
cls = <class 'geneformer.test_stop_and_go.TestGeneformerStopAndGo'>

    @classmethod
    def run_stop_and_go(cls):
        """Executes training both continuously and with a checkpoint interruption."""
        # Interrupted model training
>       cls.stop()

.../local/lib/python3.12.../testing/harnesses/stop_and_go.py:314: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
.../local/lib/python3.12.../testing/harnesses/stop_and_go.py:268: in stop
    llm.train(
.../local/lib/python3.12.../collections/llm/api.py:106: in train
    trainer.fit(model, data)
.../local/lib/python3.12.../pytorch/trainer/trainer.py:538: in fit
    call._call_and_handle_interrupt(
.../local/lib/python3.12.../pytorch/trainer/call.py:46: in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
.../local/lib/python3.12.../strategies/launchers/subprocess_script.py:105: in launch
    return function(*args, **kwargs)
.../local/lib/python3.12.../pytorch/trainer/trainer.py:574: in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
.../local/lib/python3.12.../pytorch/trainer/trainer.py:981: in _run
    results = self._run_stage()
.../local/lib/python3.12.../pytorch/trainer/trainer.py:1023: in _run_stage
    self._run_sanity_check()
.../local/lib/python3.12.../pytorch/trainer/trainer.py:1052: in _run_sanity_check
    val_loop.run()
.../local/lib/python3.12.../pytorch/loops/utilities.py:178: in _decorator
    return loop_run(self, *args, **kwargs)
.../local/lib/python3.12.../pytorch/loops/evaluation_loop.py:135: in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
.../local/lib/python3.12.../pytorch/loops/evaluation_loop.py:396: in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
.../local/lib/python3.12.../pytorch/trainer/call.py:319: in _call_strategy_hook
    output = fn(*args, **kwargs)
.../local/lib/python3.12.../pytorch/strategies/megatron_strategy.py:621: in validation_step
    out = self.model.validation_step(dataloader_iter, *args, **kwargs)
.../local/lib/python3.12.../nemo/lightning/megatron_parallel.py:342: in validation_step
    return self._step(
.../local/lib/python3.12.../nemo/lightning/megatron_parallel.py:429: in _step
    return self.forward(
.../local/lib/python3.12.../nemo/lightning/megatron_parallel.py:279: in forward
    microbatch_outputs = step()
.../local/lib/python3.12.../nemo/lightning/megatron_parallel.py:1147: in __call__
    return self.forward_backward_func(
.../local/lib/python3.12.../core/pipeline_parallel/schedules.py:471: in forward_backward_no_pipelining
    output_tensor, num_tokens = forward_step(
.../local/lib/python3.12.../core/pipeline_parallel/schedules.py:275: in forward_step
    output_tensor, loss_func = forward_step_func(data_iterator, model)
.../local/lib/python3.12.../nemo/lightning/megatron_parallel.py:492: in wrapped_forward_step_func
    output_tensor = _forward_step(model, batch)
.../local/lib/python3.12.../nemo/lightning/megatron_parallel.py:758: in wrapped
    return attr(*args)
.../local/lib/python3.12.../bionemo/llm/lightning.py:330: in validation_step
    outputs = self.forward_step(batch)
.../local/lib/python3.12.../bionemo/llm/lightning.py:313: in forward_step
    return self._forward_step(self.module, batch)
.../local/lib/python3.12.../model/biobert/lightning.py:149: in bert_forward_step
    forward_results = model.forward(input_ids=batch["text"], attention_mask=batch["attention_mask"])
.../local/lib/python3.12.../core/distributed/data_parallel_base.py:22: in forward
    return self.module(*inputs, **kwargs)
.../local/lib/python3.12.../nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
.../local/lib/python3.12.../nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
.../local/lib/python3.12.../core/transformer/module.py:178: in forward
    outputs = self.module(*inputs, **kwargs)
.../local/lib/python3.12.../nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
.../local/lib/python3.12.../nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
.../local/lib/python3.12.../model/biobert/model.py:454: in forward
    logits = self.output_layer(hidden_states_after_lm_head)
.../local/lib/python3.12.../nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
.../local/lib/python3.12.../nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = ColumnParallelLinear()
input_ = tensor([[[-0.7891, -0.9961,  1.3750,  ...,  0.7500, -0.2598, -0.3262],
         [-0.6641, -0.9062,  1.4453,  ...,  0.8...
         [-0.0232, -0.3281, -0.8711,  ..., -0.9531, -0.5625, -0.4883]]],
       device='cuda:0', dtype=torch.bfloat16)
weight = None, runtime_gather_output = None

    def forward(
        self,
        input_: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        runtime_gather_output: Optional[bool] = None,
    ):
        """Forward of ColumnParallelLinear
    
        Args:
            input_:
                3D tensor whose order of dimension is [sequence, batch, hidden]
            weight (optional):
                weight tensor to use, compulsory when skip_weight_param_allocation is True.
            runtime_gather_output (bool): Gather output at runtime. Default None means
                `gather_output` arg in the constructor will be used.
    
        Returns:
            - output
            - bias
    
        """
        if weight is None:
            if self.weight is None:
>               raise RuntimeError(
                    "weight was not supplied to ColumnParallelLinear forward pass "
                    "and skip_weight_param_allocation is True."
                )
E               RuntimeError: weight was not supplied to ColumnParallelLinear forward pass and skip_weight_param_allocation is True.

.../local/lib/python3.12.../core/tensor_parallel/layers.py:896: RuntimeError
../../../../usr/local/lib/python3.12/dist-packages/bionemo/testing/harnesses/stop_and_go.py::test_stop_and_go_consistency[TrainOutputCallback]
Stack Traces | 0.001s run time
(Identical stack trace to the ConsumedSamplesCallback failure above, ending in the same RuntimeError: weight was not supplied to ColumnParallelLinear forward pass and skip_weight_param_allocation is True.)
../../../../usr/local/lib/python3.12/dist-packages/bionemo/testing/harnesses/stop_and_go.py::test_stop_and_go_consistency[ValidLossCallback]
Stack Traces | 0.001s run time
(Identical stack trace to the ConsumedSamplesCallback failure above.)


# embedding_activation_buffer=self.embedding_activation_buffer,
# grad_output_buffer=self.grad_output_buffer,
# )
self.output_layer = torch.nn.Linear(config.hidden_size, self.vocab_size, bias=True)
A collaborator commented on the diff:

FYI, this probably has bugs with the bias getting out of sync between TP ranks during training. See https://jirasw.nvidia.com/browse/BIONEMO-668, https://jirasw.nvidia.com/browse/BIONEMO-666, https://nvidia.slack.com/archives/C074Z808N05/p1737508003987919, and https://nvidia.slack.com/archives/C0434FDLPQV/p1733963545314469.

Also, if your concern is that the logit dim is halved when you run TP=2, that may be because ColumnParallelLinear splits the logits along the vocab dimension, and ideally vocab-parallel cross entropy knows how to reduce across this split (see the sketch below).
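
To illustrate the split, here is a schematic sketch (not the Megatron code; the "ranks" are just a Python list and the cross-rank all-reduce is simulated with a plain sum):

import torch

torch.manual_seed(0)
vocab, tp = 128, 2
hidden = torch.randn(8, 320)
full_weight = torch.randn(vocab, 320)

# ColumnParallelLinear: each TP rank owns vocab/tp rows of the output weight,
# so each rank produces logits for only its slice of the vocab.
shards = full_weight.chunk(tp, dim=0)
partial_logits = [hidden @ w.t() for w in shards]  # each is (8, vocab // tp)

# A vocab-parallel cross entropy assembles global quantities across ranks:
global_max = torch.stack([p.max(dim=-1).values for p in partial_logits]).max(dim=0).values
sum_exp = sum(torch.exp(p - global_max[:, None]).sum(dim=-1) for p in partial_logits)

# The sharded log-normalizer matches the unsharded one.
full_logits = hidden @ full_weight.t()
assert torch.allclose(torch.logsumexp(full_logits, dim=-1),
                      global_max + sum_exp.log(), atol=1e-5)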

@sichu2023 (Collaborator, Author) replied:

Yup. torch.nn.Linear is only in place temporarily for debugging; I will revert to ColumnParallelLinear afterwards.

@sichu2023 (Collaborator, Author) commented Jan 22, 2025:

I am less concerned about the halved logits dim (128) and more concerned about torch.nn.Linear giving 256 dims on TP=2; 128 should be the correct dim (33 real tokens plus padding). One possible explanation is sketched below.
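
Assuming Megatron's usual vocab-padding rule (rounding the vocab up to a multiple of make_vocab_size_divisible_by * tensor_model_parallel_size; not verified against this code path), the 128 vs 256 split falls out of the arithmetic:

import math

def padded_vocab_size(vocab_size, make_vocab_size_divisible_by, tp_size):
    # Assumed Megatron-style rule: pad so every TP rank gets an equally
    # sized, divisible shard of the vocab dimension.
    multiple = make_vocab_size_divisible_by * tp_size
    return math.ceil(vocab_size / multiple) * multiple

print(padded_vocab_size(33, 128, 1))  # 128
print(padded_vocab_size(33, 128, 2))  # 256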

Labels: bug (Something isn't working)
3 participants