Add flushing and reset_cache_states to pre-hook and hook of state_dict and load_state_dict (pytorch#1674)

Summary:

The reason for doing both at the same time is to also enable the unit test.

What this diff does:
* call flush() before state_dict()
* call reset_cache_states() after load_state_dict()

The problem previously was that calling sharded_ebc.state_dict() would not recursively call lookup.state_dict(), so flush() was never invoked.

Differential Revision: D53199744
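For orientation, here is a minimal, self-contained sketch (not TorchRec code) of the save-side mechanism this diff relies on: nn.Module.register_state_dict_pre_hook runs its hook before state_dict() reads any tensors, so a flush of a caching kernel lands in time. ToyCachedEmbedding, its flush(), and the flushed flag are hypothetical stand-ins; the sketch assumes a PyTorch version that exposes register_state_dict_pre_hook (the same API the diff below uses).

# Sketch only; Toy* names are made up for illustration.
import torch
import torch.nn as nn


class ToyCachedEmbedding(nn.Module):
    """Stands in for a UVM-caching embedding kernel whose cache must be flushed."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(4, 2))
        self.flushed = False

    def flush(self) -> None:
        # The real kernel would write cached rows back to the backing table here.
        self.flushed = True


def pre_state_dict_hook(module: ToyCachedEmbedding, prefix: str, keep_vars: bool) -> None:
    # Called before state_dict() collects tensors, mirroring the hooks added below.
    module.flush()


emb = ToyCachedEmbedding()
emb.register_state_dict_pre_hook(pre_state_dict_hook)
_ = emb.state_dict()
assert emb.flushed  # flush ran before the weights were read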
henrylhtsang authored and facebook-github-bot committed Feb 2, 2024
1 parent cda48a6 commit 0164a5b
Showing 5 changed files with 170 additions and 3 deletions.
12 changes: 12 additions & 0 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -529,6 +529,9 @@ def config(self) -> GroupedEmbeddingConfig:
    def flush(self) -> None:
        pass

    def purge(self) -> None:
        pass

    def named_split_embedding_weights(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, torch.Tensor]]:
@@ -649,6 +652,9 @@ def named_parameters(
    def flush(self) -> None:
        self._emb_module.flush()

    def purge(self) -> None:
        self._emb_module.reset_cache_states()


class BatchedDenseEmbedding(BaseBatchedEmbedding[torch.Tensor]):
    def __init__(
@@ -810,6 +816,9 @@ def config(self) -> GroupedEmbeddingConfig:
    def flush(self) -> None:
        pass

    def purge(self) -> None:
        pass

    def named_split_embedding_weights(
        self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, torch.Tensor]]:
@@ -935,6 +944,9 @@ def named_parameters(
    def flush(self) -> None:
        self._emb_module.flush()

    def purge(self) -> None:
        self._emb_module.reset_cache_states()


class BatchedDenseEmbeddingBag(BaseBatchedEmbeddingBag[torch.Tensor]):
    def __init__(
17 changes: 17 additions & 0 deletions torchrec/distributed/embedding.py
@@ -421,6 +421,17 @@ def __init__(
        if module.device != torch.device("meta"):
            self.load_state_dict(module.state_dict())

    @staticmethod
    def _pre_state_dict_hook(
        self: "ShardedEmbeddingCollection",
        prefix: str = "",
        keep_vars: bool = False,
    ) -> None:
        for lookup in self._lookups:
            while isinstance(lookup, DistributedDataParallel):
                lookup = lookup.module
            lookup.flush()

    @staticmethod
    def _pre_load_state_dict_hook(
        self: "ShardedEmbeddingCollection",
@@ -475,6 +486,11 @@ def _pre_load_state_dict_hook(
                else torch.cat(local_shards, dim=0)
            )

        for lookup in self._lookups:
            while isinstance(lookup, DistributedDataParallel):
                lookup = lookup.module
            lookup.purge()

    def _initialize_torch_state(self) -> None:  # noqa
        """
        This provides consistency between this class and the EmbeddingCollection's
@@ -562,6 +578,7 @@ def post_state_dict_hook(
                destination_key = f"{prefix}embeddings.{table_name}.weight"
                destination[destination_key] = sharded_t

        self.register_state_dict_pre_hook(self._pre_state_dict_hook)
        self._register_state_dict_hook(post_state_dict_hook)
        self._register_load_state_dict_pre_hook(
            self._pre_load_state_dict_hook, with_module=True
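For the load side, a similarly minimal sketch (again not TorchRec code) of the counterpart used above: a pre-load hook registered with the private nn.Module._register_load_state_dict_pre_hook(..., with_module=True) receives the module as its first argument, so once the incoming weights have been handled it can invalidate the cache, which is what purge()/reset_cache_states() does in the real kernels. ToyLookup, ToyCollection, purge(), and the purged flag are hypothetical.

# Sketch only; assumes the private _register_load_state_dict_pre_hook API used in this diff.
from typing import Any, Dict, List

import torch
import torch.nn as nn


class ToyLookup(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(4, 2))
        self.purged = False

    def purge(self) -> None:
        # The real kernel calls reset_cache_states() so stale cached rows are dropped.
        self.purged = True


class ToyCollection(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lookup = ToyLookup()
        self._register_load_state_dict_pre_hook(self._pre_load_hook, with_module=True)

    @staticmethod
    def _pre_load_hook(
        module: "ToyCollection",
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        # Runs as load_state_dict() reaches this module; after the (real) hook has
        # copied incoming shards into place, purge the cache so lookups see new weights.
        module.lookup.purge()


m = ToyCollection()
m.load_state_dict(m.state_dict())
assert m.lookup.purged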
32 changes: 32 additions & 0 deletions torchrec/distributed/embedding_lookup.py
@@ -276,6 +276,14 @@ def named_parameters_by_table(
            ) in embedding_kernel.named_parameters_by_table():
                yield (table_name, tbe_slice)

    def flush(self) -> None:
        for emb_module in self._emb_modules:
            emb_module.flush()

    def purge(self) -> None:
        for emb_module in self._emb_modules:
            emb_module.purge()


class CommOpGradientScaling(torch.autograd.Function):
    @staticmethod
@@ -503,6 +511,14 @@ def named_parameters_by_table(
            ) in embedding_kernel.named_parameters_by_table():
                yield (table_name, tbe_slice)

    def flush(self) -> None:
        for emb_module in self._emb_modules:
            emb_module.flush()

    def purge(self) -> None:
        for emb_module in self._emb_modules:
            emb_module.purge()


class MetaInferGroupedEmbeddingsLookup(
    BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor], TBEToRegisterMixIn
@@ -627,6 +643,14 @@ def named_buffers(
        for emb_module in self._emb_modules:
            yield from emb_module.named_buffers(prefix, recurse)

    def flush(self) -> None:
        # not implemented
        pass

    def purge(self) -> None:
        # not implemented
        pass


class MetaInferGroupedPooledEmbeddingsLookup(
    BaseEmbeddingLookup[KeyedJaggedTensor, torch.Tensor], TBEToRegisterMixIn
@@ -771,6 +795,14 @@ def named_buffers(
        for emb_module in self._emb_modules:
            yield from emb_module.named_buffers(prefix, recurse)

    def flush(self) -> None:
        # not implemented
        pass

    def purge(self) -> None:
        # not implemented
        pass


class InferGroupedLookupMixin(ABC):
    def forward(
17 changes: 17 additions & 0 deletions torchrec/distributed/embeddingbag.py
@@ -515,6 +515,17 @@ def __init__(
        ]:
            self.load_state_dict(module.state_dict(), strict=False)

    @staticmethod
    def _pre_state_dict_hook(
        self: "ShardedEmbeddingBagCollection",
        prefix: str = "",
        keep_vars: bool = False,
    ) -> None:
        for lookup in self._lookups:
            while isinstance(lookup, DistributedDataParallel):
                lookup = lookup.module
            lookup.flush()

    @staticmethod
    def _pre_load_state_dict_hook(
        self: "ShardedEmbeddingBagCollection",
@@ -571,6 +582,11 @@ def _pre_load_state_dict_hook(
f"Unexpected state_dict key type {type(state_dict[key])} found for {key}"
)

        for lookup in self._lookups:
            while isinstance(lookup, DistributedDataParallel):
                lookup = lookup.module
            lookup.purge()

    def _initialize_torch_state(self) -> None:  # noqa
        """
        This provides consistency between this class and the EmbeddingBagCollection's
@@ -657,6 +673,7 @@ def post_state_dict_hook(
                destination_key = f"{prefix}embedding_bags.{table_name}.weight"
                destination[destination_key] = sharded_t

        self.register_state_dict_pre_hook(self._pre_state_dict_hook)
        self._register_state_dict_hook(post_state_dict_hook)
        self._register_load_state_dict_pre_hook(
            self._pre_load_state_dict_hook, with_module=True
95 changes: 92 additions & 3 deletions torchrec/distributed/test_utils/test_model_parallel_base.py
@@ -462,14 +462,103 @@ def test_meta_device_dmp_state_dict(self) -> None:

    # pyre-ignore[56]
    @given(
        sharders=st.sampled_from(
        sharder_type=st.sampled_from(
            [
                [EmbeddingBagCollectionSharder()],
                SharderType.EMBEDDING_BAG_COLLECTION.value,
            ]
        ),
        sharding_type=st.sampled_from(
            [
                ShardingType.COLUMN_WISE.value,
                ShardingType.DATA_PARALLEL.value,
            ]
        ),
        kernel_type=st.sampled_from(
            [
                EmbeddingComputeKernel.FUSED.value,
                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
            ]
        ),
    )
    @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None)
    def test_load_state_dict(
        self, sharder_type: str, sharding_type: str, kernel_type: str
    ) -> None:
        if (
            self.device == torch.device("cpu")
            and kernel_type != EmbeddingComputeKernel.FUSED.value
        ):
            self.skipTest("CPU does not support uvm.")

        sharders = [
            cast(
                ModuleSharder[nn.Module],
                create_test_sharder(
                    sharder_type,
                    sharding_type,
                    kernel_type,
                ),
            ),
        ]
        models, batch = self._generate_dmps_and_batch(sharders)
        m1, m2 = models

        # load the second's (m2's) with the first (m1's) state_dict
        m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict()))

        # validate the models are equivalent
        with torch.no_grad():
            loss1, pred1 = m1(batch)
            loss2, pred2 = m2(batch)
            self.assertTrue(torch.equal(loss1, loss2))
            self.assertTrue(torch.equal(pred1, pred2))
        sd1 = m1.state_dict()
        for key, value in m2.state_dict().items():
            v2 = sd1[key]
            if isinstance(value, ShardedTensor):
                assert len(value.local_shards()) == 1
                dst = value.local_shards()[0].tensor
            else:
                dst = value
            if isinstance(v2, ShardedTensor):
                assert len(v2.local_shards()) == 1
                src = v2.local_shards()[0].tensor
            else:
                src = v2
            self.assertTrue(torch.equal(src, dst))

    # pyre-ignore[56]
    @given(
        sharder_type=st.sampled_from(
            [
                SharderType.EMBEDDING_BAG_COLLECTION.value,
            ]
        ),
        sharding_type=st.sampled_from(
            [
                ShardingType.DATA_PARALLEL.value,
            ]
        ),
        kernel_type=st.sampled_from(
            [
                EmbeddingComputeKernel.DENSE.value,
            ]
        ),
    )
    @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
    def test_load_state_dict(self, sharders: List[ModuleSharder[nn.Module]]) -> None:
    def test_load_state_dict_dp(
        self, sharder_type: str, sharding_type: str, kernel_type: str
    ) -> None:
        sharders = [
            cast(
                ModuleSharder[nn.Module],
                create_test_sharder(
                    sharder_type,
                    sharding_type,
                    kernel_type,
                ),
            ),
        ]
        models, batch = self._generate_dmps_and_batch(sharders)
        m1, m2 = models

