diff --git a/torchrec/datasets/utils.py b/torchrec/datasets/utils.py index 333fc49ed..faedc498b 100644 --- a/torchrec/datasets/utils.py +++ b/torchrec/datasets/utils.py @@ -40,10 +40,8 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "Batch": ) def record_stream(self, stream: torch.cuda.streams.Stream) -> None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self.dense_features.record_stream(stream) self.sparse_features.record_stream(stream) - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self.labels.record_stream(stream) def pin_memory(self) -> "Batch": diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 62be9a2c9..8d17c36ce 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -289,7 +289,6 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None: for f in self.input_features: f.record_stream(stream) for r in self.reverse_indices: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. r.record_stream(stream) diff --git a/torchrec/distributed/sharding/sequence_sharding.py b/torchrec/distributed/sharding/sequence_sharding.py index cee1b2179..fe14e3b12 100644 --- a/torchrec/distributed/sharding/sequence_sharding.py +++ b/torchrec/distributed/sharding/sequence_sharding.py @@ -43,13 +43,10 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None: if self.features_before_input_dist is not None: self.features_before_input_dist.record_stream(stream) if self.sparse_features_recat is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self.sparse_features_recat.record_stream(stream) if self.unbucketize_permute_tensor is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self.unbucketize_permute_tensor.record_stream(stream) if self.lengths_after_input_dist is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self.lengths_after_input_dist.record_stream(stream) @@ -74,5 +71,4 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None: if self.features_before_input_dist is not None: self.features_before_input_dist.record_stream(stream) if self.unbucketize_permute_tensor is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self.unbucketize_permute_tensor.record_stream(stream) diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py index 879813e5f..9beeabd9e 100644 --- a/torchrec/distributed/test_utils/test_model.py +++ b/torchrec/distributed/test_utils/test_model.py @@ -240,12 +240,10 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput": ) def record_stream(self, stream: torch.cuda.streams.Stream) -> None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self.float_features.record_stream(stream) self.idlist_features.record_stream(stream) if self.idscore_features is not None: self.idscore_features.record_stream(stream) - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self.label.record_stream(stream) diff --git a/torchrec/distributed/tests/test_train_pipeline.py b/torchrec/distributed/tests/test_train_pipeline.py index 17a696d47..6a4e357c0 100644 --- a/torchrec/distributed/tests/test_train_pipeline.py +++ b/torchrec/distributed/tests/test_train_pipeline.py @@ -107,9 +107,7 @@ def to(self, device: torch.device, non_blocking: bool) -> "ModelInputSimple": ) def record_stream(self, stream: torch.cuda.streams.Stream) -> None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self.float_features.record_stream(stream) - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self.label.record_stream(stream) diff --git a/torchrec/distributed/train_pipeline.py b/torchrec/distributed/train_pipeline.py index c44934b2d..fc6f5b796 100644 --- a/torchrec/distributed/train_pipeline.py +++ b/torchrec/distributed/train_pipeline.py @@ -468,7 +468,6 @@ def __call__(self, *input, **kwargs) -> Awaitable: assert isinstance( data, (torch.Tensor, Multistreamable) ), f"{type(data)} must implement Multistreamable interface" - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. data.record_stream(cur_stream) ctx = self._context.module_contexts[self._name] diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 625df2205..1af1f9046 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -524,19 +524,15 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor" @torch.jit.unused def record_stream(self, stream: torch.cuda.streams.Stream) -> None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self._values.record_stream(stream) weights = self._weights lengths = self._lengths offsets = self._offsets if weights is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. weights.record_stream(stream) if lengths is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. lengths.record_stream(stream) if offsets is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. offsets.record_stream(stream) def __str__(self) -> str: @@ -1666,19 +1662,15 @@ def to_dict(self) -> Dict[str, JaggedTensor]: @torch.jit.unused def record_stream(self, stream: torch.cuda.streams.Stream) -> None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self._values.record_stream(stream) weights = self._weights lengths = self._lengths offsets = self._offsets if weights is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. weights.record_stream(stream) if lengths is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. lengths.record_stream(stream) if offsets is not None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. offsets.record_stream(stream) def to( @@ -2096,7 +2088,6 @@ def regroup_as_dict( @torch.jit.unused def record_stream(self, stream: torch.cuda.streams.Stream) -> None: - # pyre-fixme[6]: For 1st param expected `Stream` but got `Stream`. self._values.record_stream(stream) def to(self, device: torch.device, non_blocking: bool = False) -> "KeyedTensor":