Verify no in-place modification in StypeEncoder (#299)
Next step (after `v0.2.0` release): Avoid unnecessary `clone()` in stype
encoder and still pass this test. Ref: #254
weihua916 authored Dec 15, 2023
1 parent 7fca13d commit 2b59649
Showing 3 changed files with 41 additions and 22 deletions.
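
The new test follows a simple pattern: deep-copy the input TensorFrame before the forward pass, then compare against the copy afterwards. Below is a minimal, self-contained sketch of that pattern; the toy NanImputingEncoder is hypothetical and only stands in for a StypeEncoder, it is not torch_frame API.

import copy

import torch

class NanImputingEncoder(torch.nn.Module):
    # Toy stand-in for a StypeEncoder: imputes NaNs, but clones first so
    # the caller's tensor is left untouched.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.clone()  # dropping this clone would mutate the input in place
        return x.nan_to_num_(0.0)

x = torch.tensor([1.0, float("nan"), 3.0])
x_original = copy.deepcopy(x)  # snapshot before encoding
out = NanImputingEncoder()(x)

# Passes because the encoder cloned before the in-place fill; a NaN-aware
# comparison is needed since the snapshot still contains NaN.
assert torch.allclose(x, x_original, equal_nan=True)

The commit message points at the follow-up: the `clone()` is exactly what this test protects, so a later change can try to drop unnecessary clones while keeping the test green.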
2 changes: 1 addition & 1 deletion docs/source/handling_advanced_stypes/handle_text.rst
@@ -345,7 +345,7 @@ Next we need to specify the text model embedding with `LoRA <https://arxiv.org/a
 As mentioned above, we store text model inputs in the format of dictionary of
-:obj:`~torch_frame.data.MultiNestedTensor`s.
+:obj:`~torch_frame.data.MultiNestedTensor`.
 During the :meth:`forward`, we first transform each
 :obj:`~torch_frame.data.MultiNestedTensor` back to padded PyTorch Tensor by using
 :meth:`~torch_frame.data.MultiNestedTensor.to_dense` with the padding value
59 changes: 39 additions & 20 deletions test/nn/encoder/test_stypewise_encoder.py
@@ -1,6 +1,8 @@
+import copy
+
 import pytest

-from torch_frame import stype
+from torch_frame import NAStrategy, stype
 from torch_frame.config import ModelConfig
 from torch_frame.config.text_embedder import TextEmbedderConfig
 from torch_frame.config.text_tokenizer import TextTokenizerConfig
@@ -24,34 +26,46 @@
 )


-@pytest.mark.parametrize('encoder_cat_cls_kwargs', [(EmbeddingEncoder, {})])
-@pytest.mark.parametrize('encoder_num_cls_kwargs', [
-    (LinearEncoder, {}),
-    (LinearBucketEncoder, {}),
+@pytest.mark.parametrize("encoder_cat_cls_kwargs",
+                         [(EmbeddingEncoder, {
+                             "na_strategy": NAStrategy.MOST_FREQUENT,
+                         })])
+@pytest.mark.parametrize("encoder_num_cls_kwargs", [
+    (LinearEncoder, {
+        "na_strategy": NAStrategy.MEAN,
+    }),
+    (LinearBucketEncoder, {
+        "na_strategy": NAStrategy.MEAN,
+    }),
     (LinearPeriodicEncoder, {
-        'n_bins': 4
+        "n_bins": 4,
+        "na_strategy": NAStrategy.MEAN,
     }),
 ])
-@pytest.mark.parametrize('encoder_multicategorical_cls_kwargs', [
-    (MultiCategoricalEmbeddingEncoder, {}),
+@pytest.mark.parametrize("encoder_multicategorical_cls_kwargs", [
+    (MultiCategoricalEmbeddingEncoder, {
+        "na_strategy": NAStrategy.ZEROS
+    }),
 ])
-@pytest.mark.parametrize('encoder_timestamp_cls_kwargs', [
-    (TimestampEncoder, {}),
+@pytest.mark.parametrize("encoder_timestamp_cls_kwargs", [
+    (TimestampEncoder, {
+        "na_strategy": NAStrategy.MEDIAN_TIMESTAMP
+    }),
 ])
-@pytest.mark.parametrize('encoder_text_embedded_cls_kwargs', [
+@pytest.mark.parametrize("encoder_text_embedded_cls_kwargs", [
     (LinearEmbeddingEncoder, {}),
 ])
-@pytest.mark.parametrize('encoder_text_tokenized_cls_kwargs', [
+@pytest.mark.parametrize("encoder_text_tokenized_cls_kwargs", [
     (LinearModelEncoder, {
-        'col_to_model_cfg': {
-            'text_tokenized_1':
+        "col_to_model_cfg": {
+            "text_tokenized_1":
             ModelConfig(model=RandomTextModel(12), out_channels=12),
-            'text_tokenized_2':
+            "text_tokenized_2":
             ModelConfig(model=RandomTextModel(6), out_channels=6)
         },
     }),
 ])
-@pytest.mark.parametrize('encoder_embedding_cls_kwargs', [
+@pytest.mark.parametrize("encoder_embedding_cls_kwargs", [
     (LinearEmbeddingEncoder, {}),
 ])
 def test_stypewise_feature_encoder(
@@ -66,7 +80,7 @@ def test_stypewise_feature_encoder(
     num_rows = 10
     dataset: Dataset = FakeDataset(
         num_rows=num_rows,
-        with_nan=False,
+        with_nan=True,
         stypes=[
             stype.categorical,
             stype.numerical,
@@ -124,7 +138,12 @@ def test_stypewise_feature_encoder(
         col_names_dict=tensor_frame.col_names_dict,
         stype_encoder_dict=stype_encoder_dict,
     )
+    tensor_frame_original = copy.deepcopy(tensor_frame)
     x, col_names = encoder(tensor_frame)
+
+    # Test no in-place operation in encoder
+    assert tensor_frame_original == tensor_frame
+
     assert x.shape == (num_rows, tensor_frame.num_cols, out_channels)
     assert col_names == [
         "num_1",
@@ -138,9 +157,9 @@ def test_stypewise_feature_encoder(
"multicat_2",
"multicat_3",
"multicat_4",
'timestamp_0',
'timestamp_1',
'timestamp_2',
"timestamp_0",
"timestamp_1",
"timestamp_2",
"emb_1",
"emb_2",
"text_embedded_1",
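
Every encoder in the parametrization above now receives an explicit na_strategy, since `with_nan=True` makes FakeDataset emit NaN entries that the encoders must impute rather than propagate. A hedged sketch of what those kwargs amount to when building a stype_encoder_dict by hand (import paths assumed; the full StypeWiseFeatureEncoder setup is abbreviated):

from torch_frame import NAStrategy, stype
from torch_frame.nn import EmbeddingEncoder, LinearEncoder

# Categorical NaNs fall back to the most frequent category; numerical NaNs
# to the column mean, mirroring the kwargs in the diff above.
stype_encoder_dict = {
    stype.categorical: EmbeddingEncoder(na_strategy=NAStrategy.MOST_FREQUENT),
    stype.numerical: LinearEncoder(na_strategy=NAStrategy.MEAN),
}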
2 changes: 1 addition & 1 deletion torch_frame/data/tensor_frame.py
@@ -223,7 +223,7 @@ def __eq__(self, other: Any) -> bool:
                 return False
             if self_feat.shape != other_feat.shape:
                 return False
-            if not torch.allclose(self_feat, other_feat):
+            if not torch.allclose(self_feat, other_feat, equal_nan=True):
                 return False
         elif isinstance(self_feat, MultiNestedTensor):
             if not isinstance(other_feat, MultiNestedTensor):
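
The `equal_nan=True` tweak is what lets the equality check in the test survive `with_nan=True`: by IEEE semantics NaN != NaN, so two identical frames containing NaNs would otherwise compare unequal. A self-contained illustration:

import torch

a = torch.tensor([1.0, float("nan")])
b = torch.tensor([1.0, float("nan")])

assert not torch.allclose(a, b)              # NaN != NaN by default
assert torch.allclose(a, b, equal_nan=True)  # co-located NaNs now match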
