Add fused_params to fp-ebc (pytorch#1429)
Summary:

This is needed in the planner, since we rely on the sharder to pass in the global cache_load_factor.

mc-ebc has the same problem, but it is less concerning.
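
A minimal sketch (not a torchrec API) of how a planner-side helper might read that global cache_load_factor from a sharder's fused_params; the helper name below is illustrative:

# Illustrative sketch only: this helper is an assumption, not part of torchrec;
# "cache_load_factor" is the fused_params key the summary refers to.
from typing import Any, Dict, Optional


def sharder_cache_load_factor(
    fused_params: Optional[Dict[str, Any]],
) -> Optional[float]:
    # Return the global cache_load_factor carried by a sharder, if one was set.
    if fused_params is None:
        return None
    return fused_params.get("cache_load_factor")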

Reviewed By: YLGH

Differential Revision: D49105069
hlhtsang authored and facebook-github-bot committed Oct 9, 2023
1 parent ee359ee commit 2e0e188
Showing 1 changed file with 6 additions and 1 deletion.
torchrec/distributed/fp_embeddingbag.py (6 additions, 1 deletion)
@@ -6,7 +6,7 @@
 # LICENSE file in the root directory of this source tree.

 from functools import partial
-from typing import Dict, Iterator, List, Optional, Type, Union
+from typing import Any, Dict, Iterator, List, Optional, Type, Union

 import torch
 from torch import nn
@@ -190,6 +190,11 @@ def shard(
             device=device,
         )

+    @property
+    def fused_params(self) -> Optional[Dict[str, Any]]:
+        # TODO: to be deprecated once the planner gets cache_load_factor from ParameterConstraints
+        return self._ebc_sharder.fused_params
+
     def shardable_parameters(
         self, module: FeatureProcessedEmbeddingBagCollection
     ) -> Dict[str, torch.nn.Parameter]:
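
With the property in place, the fp-ebc sharder simply forwards the wrapped EmbeddingBagCollectionSharder's fused_params. A hedged usage sketch follows; the ebc_sharder argument name is inferred from the diff above, and the constructor signatures are assumptions, not verified against a released torchrec version:

# Sketch only: constructor keyword arguments here are assumptions inferred
# from the diff, not verified API.
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.fp_embeddingbag import (
    FeatureProcessedEmbeddingBagCollectionSharder,
)

ebc_sharder = EmbeddingBagCollectionSharder(fused_params={"cache_load_factor": 0.2})
fp_sharder = FeatureProcessedEmbeddingBagCollectionSharder(ebc_sharder=ebc_sharder)

# With this commit, the planner can read the global cache_load_factor back
# from the fp-ebc sharder instead of getting nothing.
assert fp_sharder.fused_params is not None
print(fp_sharder.fused_params.get("cache_load_factor"))  # 0.2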
