Add fused_params to mch sharders (pytorch#1649)
Summary:
Pull Request resolved: pytorch#1649

Adding fused_params to the zch sharders. The main reason is so that a cache load factor passed through the zch sharder's fused_params can be represented in planner stats.

Reviewed By: dstaay-fb

Differential Revision: D52921362

fbshipit-source-id: 94001c20e8eb308eaa2d4a68923b132eddb1ff26
henrylhtsang authored and facebook-github-bot committed Feb 6, 2024
1 parent 4f234a2 commit 0c2462f
Showing 1 changed file with 6 additions and 1 deletion: torchrec/distributed/mc_embedding_modules.py
@@ -6,7 +6,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from typing import Dict, Iterator, List, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Dict, Iterator, List, Optional, Tuple, TypeVar, Union
 
 import torch
 from torch.autograd.profiler import record_function
@@ -276,3 +276,8 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
                 set(self._mc_sharder.sharding_types(compute_device_type)),
             )
         )
+
+    @property
+    def fused_params(self) -> Optional[Dict[str, Any]]:
+        # TODO: to be deprecated once the planner gets cache_load_factor from ParameterConstraints
+        return self._e_sharder.fused_params
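
For context, a minimal sketch of how a cache load factor might flow through the managed-collision sharder's fused_params and why forwarding it matters for planner stats. The constructor usage of ManagedCollisionEmbeddingBagCollectionSharder and EmbeddingBagCollectionSharder below is an assumption based on this diff and the commit summary, not something taken from this commit:

# Hypothetical usage sketch (not part of this commit): pass a cache load
# factor via fused_params on the inner embedding sharder and read it back
# through the managed-collision wrapper, which is what lets the planner
# surface it in its stats.
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.mc_embeddingbag import (
    ManagedCollisionEmbeddingBagCollectionSharder,
)

# Fraction of embedding rows kept in the UVM cache (illustrative value).
fused_params = {"cache_load_factor": 0.2}

# The managed-collision sharder wraps an embedding sharder (self._e_sharder
# in the diff above); fused_params is supplied on that inner sharder.
sharder = ManagedCollisionEmbeddingBagCollectionSharder(
    EmbeddingBagCollectionSharder(fused_params=fused_params),
)

# With the property added in this commit, the wrapper forwards the inner
# sharder's fused_params, so planner stats can pick up cache_load_factor.
assert sharder.fused_params == fused_params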
