From 1e47e3aedc2e02fff602f1e725506f1655dfeee2 Mon Sep 17 00:00:00 2001 From: Weihua Hu Date: Fri, 1 Mar 2024 21:08:22 -0800 Subject: [PATCH] Fix handling of empty `MultiNestedTensor` (#369) --- CHANGELOG.md | 2 ++ test/data/test_multi_nested_tensor.py | 6 +++++- torch_frame/data/multi_embedding_tensor.py | 11 +++-------- torch_frame/data/multi_nested_tensor.py | 18 ++++-------------- torch_frame/datasets/amphibians.py | 2 -- 5 files changed, 14 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 63eb4c45..06d76e64 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed bug in empty `MultiNestedTensor` handling ([#369](https://github.com/pyg-team/pytorch-frame/pull/369)) + - Fixed the split of `DataFrameTextBenchmark` ([#358](https://github.com/pyg-team/pytorch-frame/pull/358)) - Fixed empty `MultiNestedTensor` col indexing ([#355](https://github.com/pyg-team/pytorch-frame/pull/355)) diff --git a/test/data/test_multi_nested_tensor.py b/test/data/test_multi_nested_tensor.py index a53fcfac..62594ac7 100644 --- a/test/data/test_multi_nested_tensor.py +++ b/test/data/test_multi_nested_tensor.py @@ -180,6 +180,7 @@ def test_multi_nested_tensor_basics(device): # Test multi_nested_tensor[List[int]] indexing for index in [[4], [2, 2], [-4, 1, 7], [3, -7, 1, 0], []]: multi_nested_tensor_indexed = multi_nested_tensor[index] + assert multi_nested_tensor_indexed.dtype == torch.long assert multi_nested_tensor_indexed.shape[0] == len(index) assert multi_nested_tensor_indexed.shape[1] == num_cols for i, idx in enumerate(index): @@ -208,8 +209,10 @@ def test_multi_nested_tensor_basics(device): # Test column List[int] indexing for index in [[4], [2, 2], [-4, 1, 7], [3, -7, 1, 0], []]: + multi_nested_tensor_indexed = multi_nested_tensor[:, index] assert_equal(column_select(tensor_mat, index), - multi_nested_tensor[:, index]) + multi_nested_tensor_indexed) + assert multi_nested_tensor_indexed.dtype == torch.long # Test column-wise Boolean masking for index in [[4], [2, 3], [0, 1, 7], []]: @@ -245,6 +248,7 @@ def test_multi_nested_tensor_basics(device): empty_multi_nested_tensor = multi_nested_tensor[:, 5:3] assert empty_multi_nested_tensor.shape[0] == num_rows assert empty_multi_nested_tensor.shape[1] == 0 + assert empty_multi_nested_tensor.dtype == torch.long # Test column narrow assert_equal(column_select(tensor_mat, slice(3, 3 + 2)), diff --git a/torch_frame/data/multi_embedding_tensor.py b/torch_frame/data/multi_embedding_tensor.py index 7cc82c4e..b2031b18 100644 --- a/torch_frame/data/multi_embedding_tensor.py +++ b/torch_frame/data/multi_embedding_tensor.py @@ -150,12 +150,7 @@ def _col_index_select(self, index: Tensor) -> MultiEmbeddingTensor: :meth:`MultiEmbeddingTensor.index_select`. """ if index.numel() == 0: - return MultiEmbeddingTensor( - num_rows=self.num_rows, - num_cols=0, - values=torch.tensor([], device=self.device), - offset=torch.tensor([0], device=self.device), - ) + return self._empty(dim=1) offset = torch.zeros( index.size(0) + 1, dtype=torch.long, @@ -228,8 +223,8 @@ def _empty(self, dim: int) -> MultiEmbeddingTensor: return MultiEmbeddingTensor( num_rows=0 if dim == 0 else self.num_rows, num_cols=0 if dim == 1 else self.num_cols, - values=torch.tensor([], device=self.device), - offset=torch.tensor([0], device=self.device) + values=torch.tensor([], device=self.device, dtype=self.dtype), + offset=torch.tensor([0], device=self.device, dtype=torch.long) if dim == 1 else self.offset, ) diff --git a/torch_frame/data/multi_nested_tensor.py b/torch_frame/data/multi_nested_tensor.py index 48d2d305..aa2789e6 100644 --- a/torch_frame/data/multi_nested_tensor.py +++ b/torch_frame/data/multi_nested_tensor.py @@ -180,12 +180,7 @@ def _row_index_select(self, index: Tensor) -> MultiNestedTensor: r"""Helper function called by :obj:`index_select`.""" # Calculate values if index.numel() == 0: - return MultiNestedTensor( - num_rows=0, - num_cols=self.num_cols, - values=torch.tensor([], device=self.device), - offset=torch.tensor([0], device=self.device), - ) + return self._empty(dim=0) index_right = (index + 1) * self.num_cols index_left = index * self.num_cols diff = self.offset[index_right] - self.offset[index_left] @@ -218,12 +213,7 @@ def _row_index_select(self, index: Tensor) -> MultiNestedTensor: def _col_index_select(self, index: Tensor) -> MultiNestedTensor: r"""Helper function called by :obj:`index_select`.""" if index.numel() == 0: - return MultiNestedTensor( - num_rows=self.num_rows, - num_cols=0, - values=torch.tensor([], device=self.device), - offset=torch.tensor([0], device=self.device), - ) + return self._empty(dim=1) start_idx = (index + torch.arange( 0, self.num_rows * self.num_cols, @@ -320,13 +310,13 @@ def to_dense(self, fill_value: int | float) -> Tensor: return dense def _empty(self, dim: int) -> MultiNestedTensor: - r"""Creates an empty :class:`MultiEmbeddingTensor`. + r"""Creates an empty :class:`MultiNestedTensor`. Args: dim (int): The dimension to empty. Returns: - MultiEmbeddingTensor: An empty :class:`MultiEmbeddingTensor`. + MultiNestedTensor: An empty :class:`MultiNestedTensor`. """ values = torch.tensor([], device=self.device, dtype=self.dtype) offset = torch.zeros(1, device=self.device, dtype=torch.long) diff --git a/torch_frame/datasets/amphibians.py b/torch_frame/datasets/amphibians.py index d18af777..48b2a46b 100644 --- a/torch_frame/datasets/amphibians.py +++ b/torch_frame/datasets/amphibians.py @@ -54,8 +54,6 @@ def __init__(self, root: str): lambda row: [col for col in target_cols if row[col] == '1'], axis=1) df = df.drop(target_cols, axis=1) - import pdb - pdb.set_trace() # Infer the pandas dataframe automatically path = osp.join(root, 'amphibians_posprocess.csv')