Fix torchrec rowwise adagrad implementation
Summary: Apparently we flipped the order of squaring and taking the mean: the per-row accumulator should add the row-wise mean of the squared gradient, not the square of the row-wise mean gradient.

Differential Revision: D57307271
henrylhtsang authored and facebook-github-bot committed May 17, 2024
1 parent 17d6895 commit f38c6f3
Showing 2 changed files with 102 additions and 10 deletions.
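
Before the diffs, a quick sketch of what the fix changes. This is illustrative code only (not the torchrec implementation): row-wise Adagrad should accumulate the per-row mean of the squared gradient, whereas the old code effectively accumulated the square of the per-row mean gradient.

import torch

# Illustrative only, not the torchrec code: compare the two orders of
# operations for the per-row accumulator update.
grad = torch.randn(3, 4)  # gradient for an embedding table with 3 rows, dim 4

old_update = grad.mean(dim=1, keepdim=True).pow(2)  # square of the row mean (buggy order)
new_update = grad.pow(2).mean(dim=1, keepdim=True)  # row mean of the squares (fixed order)

# mean(x)**2 <= mean(x**2), so the old order under-estimates the accumulator,
# which in turn inflates the effective step size.
assert torch.all(old_update <= new_update + 1e-6)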
91 changes: 87 additions & 4 deletions torchrec/distributed/test_utils/test_model_parallel_base.py
@@ -60,6 +60,7 @@
 )
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.modules.fused_embedding_modules import FusedEmbeddingBagCollection
+from torchrec.optim.rowwise_adagrad import RowWiseAdagrad
 from torchrec.test_utils import get_free_port, seed_and_log


@@ -334,10 +335,10 @@ class ModelParallelStateDictBase(ModelParallelSingleRankBase):
     def setUp(self, backend: str = "nccl") -> None:
         super().setUp(backend=backend)
 
-        num_features = 4
-        num_weighted_features = 2
-        self.batch_size = 20
-        self.num_float_features = 10
+        num_features = 1
+        num_weighted_features = 0
+        self.batch_size = 1
+        self.num_float_features = 1
 
         self.tables = [
             EmbeddingBagConfig(
@@ -953,3 +954,85 @@ def test_numerical_equivalence_between_kernel_types(
         self._compare_models(
             fused_model, model, is_deterministic=not stochastic_rounding
         )
+
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    # pyre-ignore[56]
+    @given(
+        sharder_type=st.sampled_from(
+            [
+                SharderType.EMBEDDING_BAG_COLLECTION.value,
+            ]
+        ),
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+            ]
+        ),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None)
+    def test_rowwise_adagrad_numerical_equivalence(
+        self,
+        sharder_type: str,
+        sharding_type: str,
+        kernel_type: str,
+    ) -> None:
+        learning_rate = 0.1
+        fused_params = {
+            "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
+            "learning_rate": learning_rate,
+        }
+
+        fused_sharders = [
+            cast(
+                ModuleSharder[nn.Module],
+                create_test_sharder(
+                    sharder_type,
+                    sharding_type,
+                    EmbeddingComputeKernel.FUSED.value,
+                    fused_params=fused_params,
+                ),
+            ),
+        ]
+        dense_sharders = [
+            cast(
+                ModuleSharder[nn.Module],
+                create_test_sharder(
+                    sharder_type,
+                    sharding_type,
+                    EmbeddingComputeKernel.DENSE.value,
+                    fused_params=fused_params,
+                ),
+            ),
+        ]
+        (fused_model, _), _ = self._generate_dmps_and_batch(fused_sharders)
+        (dense_model, _), batch = self._generate_dmps_and_batch(dense_sharders)
+
+        dense_opt = RowWiseAdagrad(
+            dense_model.module.sparse.parameters(),
+            lr=learning_rate,
+            eps=1e-8,  # TBE has default eps 1e-8
+        )
+
+        # load the baseline model's state_dict onto the new model
+        dense_model.load_state_dict(
+            cast("OrderedDict[str, torch.Tensor]", fused_model.state_dict())
+        )
+
+        for _ in range(4):
+            dense_opt.zero_grad()
+            loss1, pred1 = fused_model(batch)
+            loss2, pred2 = dense_model(batch)
+            loss1.backward()
+            loss2.backward()
+            dense_opt.step()
+
+        self._eval_models(fused_model, dense_model, batch, is_deterministic=False)
+        self._compare_models(fused_model, dense_model, is_deterministic=False)
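
For reference, the dense baseline used in the test above can also be exercised on its own. A minimal standalone sketch of driving RowWiseAdagrad directly (the EmbeddingBag module and the toy inputs here are assumptions for illustration, not part of this commit):

import torch
import torch.nn as nn
from torchrec.optim.rowwise_adagrad import RowWiseAdagrad

# Assumed toy setup: a single EmbeddingBag whose 2D weight plays the role of an
# embedding table; RowWiseAdagrad keeps one accumulator value per row.
table = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode="sum")
opt = RowWiseAdagrad(table.parameters(), lr=0.1, eps=1e-8)  # eps chosen to match TBE's default, as in the test

ids = torch.tensor([1, 3, 3, 7])
offsets = torch.tensor([0, 2])
for _ in range(4):
    opt.zero_grad()
    loss = table(ids, offsets).pow(2).sum()
    loss.backward()
    opt.step()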
21 changes: 15 additions & 6 deletions torchrec/optim/rowwise_adagrad.py
Expand Up @@ -9,13 +9,16 @@

#!/usr/bin/env python3

import logging
from typing import Any, Dict, Iterable, List

import torch
from torch import Tensor

from torch.optim.optimizer import Optimizer

logger: logging.Logger = logging.getLogger(__name__)


class RowWiseAdagrad(Optimizer):
r"""Implements Row wise Adagrad algorithm. This is an extension of the Adagrad algorithm
@@ -66,6 +69,14 @@ def __init__(
         if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {}".format(eps))
 
+        if weight_decay > 0:
+            logger.warning(
+                "Note that the weight decay mode of this optimizer may produce "
+                "different results compared to the one by FBGEMM TBE. This is "
+                "due to FBGEMM TBE rowwise adagrad is sparse, and will only "
+                "update the optimizer states if that row has nonzero gradients."
+            )
+
         defaults = dict(
             lr=lr,
             lr_decay=lr_decay,
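
To make the warning above concrete, here is a toy illustration (not torchrec or FBGEMM code, and ignoring the Adagrad scaling of the decay term): a dense optimizer applies weight decay to every row on every step, while FBGEMM TBE's sparse rowwise Adagrad only updates rows that received a gradient.

import torch

# Toy illustration only. Four embedding rows; only row 1 is looked up this
# step, so only row 1 has a nonzero gradient.
param = torch.ones(4, 2)
grad = torch.zeros(4, 2)
grad[1] = 0.5
lr, weight_decay = 0.1, 0.01

# Dense behaviour (this optimizer): decay touches every row.
dense = param - lr * weight_decay * param

# Sparse behaviour (TBE-style): rows with zero gradient are left untouched.
sparse = param.clone()
hit = grad.abs().sum(dim=1) > 0
sparse[hit] = sparse[hit] - lr * weight_decay * param[hit]

print((dense != sparse).any(dim=1))  # rows 0, 2 and 3 differ after one step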
@@ -211,14 +222,12 @@ def _single_tensor_adagrad(
         step = step_t.item()
         grad = grad if not maximize else -grad
 
-        row_wise_grad = grad.mean(axis=1).view(-1, 1)
         if weight_decay != 0:
-
             grad = grad.add(param, alpha=weight_decay)
-            row_wise_grad = grad.add(param, alpha=weight_decay)
 
+        state_sum += grad.pow(2).mean(axis=1).view(-1, 1)
+        std = state_sum.sqrt().add_(eps)
+
         clr = lr / (1 + (step - 1) * lr_decay)
 
-        state_sum.addcmul_(row_wise_grad, row_wise_grad, value=1)
-        std = state_sum.sqrt().add_(eps)
-        param.addcdiv_(row_wise_grad, std, value=-clr)
+        param.addcdiv_(grad, std, value=-clr)
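
Putting the fixed hunk together, the per-parameter update now reads roughly as follows. This is a sketch assembled from the diff above, with names and the surrounding parameter loop simplified, not a verbatim copy of _single_tensor_adagrad:

import torch

def rowwise_adagrad_update(
    param: torch.Tensor,      # [rows, dim] embedding table, updated in place
    grad: torch.Tensor,       # [rows, dim] dense gradient
    state_sum: torch.Tensor,  # [rows, 1] per-row accumulator, updated in place
    step: int,
    lr: float,
    lr_decay: float,
    weight_decay: float,
    eps: float,
) -> None:
    if weight_decay != 0:
        grad = grad.add(param, alpha=weight_decay)

    # Fixed order: square the gradient first, then average within each row.
    state_sum += grad.pow(2).mean(axis=1).view(-1, 1)
    std = state_sum.sqrt().add_(eps)

    clr = lr / (1 + (step - 1) * lr_decay)
    param.addcdiv_(grad, std, value=-clr)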
