Skip to content

Commit

Permalink
suppress errors in torchrec
Browse files Browse the repository at this point in the history
Differential Revision: D50469064
  • Loading branch information
Pyre Bot Jr authored and facebook-github-bot committed Oct 19, 2023
1 parent 2d595ea commit f3b29d4
Show file tree
Hide file tree
Showing 7 changed files with 0 additions and 21 deletions.
2 changes: 0 additions & 2 deletions torchrec/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,8 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "Batch":
)

def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
self.dense_features.record_stream(stream)
self.sparse_features.record_stream(stream)
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
self.labels.record_stream(stream)

def pin_memory(self) -> "Batch":
Expand Down
1 change: 0 additions & 1 deletion torchrec/distributed/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,6 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
for f in self.input_features:
f.record_stream(stream)
for r in self.reverse_indices:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
r.record_stream(stream)


Expand Down
4 changes: 0 additions & 4 deletions torchrec/distributed/sharding/sequence_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,10 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
if self.features_before_input_dist is not None:
self.features_before_input_dist.record_stream(stream)
if self.sparse_features_recat is not None:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
self.sparse_features_recat.record_stream(stream)
if self.unbucketize_permute_tensor is not None:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
self.unbucketize_permute_tensor.record_stream(stream)
if self.lengths_after_input_dist is not None:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
self.lengths_after_input_dist.record_stream(stream)


Expand All @@ -74,5 +71,4 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
if self.features_before_input_dist is not None:
self.features_before_input_dist.record_stream(stream)
if self.unbucketize_permute_tensor is not None:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
self.unbucketize_permute_tensor.record_stream(stream)
2 changes: 0 additions & 2 deletions torchrec/distributed/test_utils/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,12 +240,10 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput":
)

def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
self.float_features.record_stream(stream)
self.idlist_features.record_stream(stream)
if self.idscore_features is not None:
self.idscore_features.record_stream(stream)
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
self.label.record_stream(stream)


Expand Down
2 changes: 0 additions & 2 deletions torchrec/distributed/tests/test_train_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,7 @@ def to(self, device: torch.device, non_blocking: bool) -> "ModelInputSimple":
)

def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
self.float_features.record_stream(stream)
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
self.label.record_stream(stream)


Expand Down
1 change: 0 additions & 1 deletion torchrec/distributed/train_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,6 @@ def __call__(self, *input, **kwargs) -> Awaitable:
assert isinstance(
data, (torch.Tensor, Multistreamable)
), f"{type(data)} must implement Multistreamable interface"
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
data.record_stream(cur_stream)

ctx = self._context.module_contexts[self._name]
Expand Down
9 changes: 0 additions & 9 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,19 +524,15 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor"

@torch.jit.unused
def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
self._values.record_stream(stream)
weights = self._weights
lengths = self._lengths
offsets = self._offsets
if weights is not None:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
weights.record_stream(stream)
if lengths is not None:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
lengths.record_stream(stream)
if offsets is not None:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
offsets.record_stream(stream)

def __str__(self) -> str:
Expand Down Expand Up @@ -1666,19 +1662,15 @@ def to_dict(self) -> Dict[str, JaggedTensor]:

@torch.jit.unused
def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
self._values.record_stream(stream)
weights = self._weights
lengths = self._lengths
offsets = self._offsets
if weights is not None:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
weights.record_stream(stream)
if lengths is not None:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
lengths.record_stream(stream)
if offsets is not None:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
offsets.record_stream(stream)

def to(
Expand Down Expand Up @@ -2096,7 +2088,6 @@ def regroup_as_dict(

@torch.jit.unused
def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
# pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`.
self._values.record_stream(stream)

def to(self, device: torch.device, non_blocking: bool = False) -> "KeyedTensor":
Expand Down

0 comments on commit f3b29d4

Please sign in to comment.