diff --git a/fx2ait/fx2ait/converters/ait_converters.py b/fx2ait/fx2ait/converters/ait_converters.py index 3c21094ba..b04448c34 100644 --- a/fx2ait/fx2ait/converters/ait_converters.py +++ b/fx2ait/fx2ait/converters/ait_converters.py @@ -668,6 +668,13 @@ def acc_ops_getitem( isinstance(idx, Sequence) and any(isinstance(x, slice) for x in idx) ): return acc_ops_slice(target, args, kwargs, name) + + if isinstance(idx, AITTensor) and idx.dtype() == "bool": + # TODO: could do something similar to acc_ops_masked_select + raise NotImplementedError( + "AIT does not support tensor[boolean_tensor] masking yet" + ) + if isinstance(input_val, AITTensor): return acc_ops_slice(target, args, kwargs, name) diff --git a/fx2ait/fx2ait/test/converters/test_ait_binary_op.py b/fx2ait/fx2ait/test/converters/test_ait_binary_op.py index 1da7b11e1..fbedbc1e6 100644 --- a/fx2ait/fx2ait/test/converters/test_ait_binary_op.py +++ b/fx2ait/fx2ait/test/converters/test_ait_binary_op.py @@ -154,6 +154,25 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: expected_ops={acc_op}, ) + def test_getitem_boolean_index(self) -> None: + """Verify that NotImplementedError is thrown when encountering + tensor[boolean_mask_tensor] + """ + + class TestModule(torch.nn.Module): + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + return x[mask] + + mod = TestModule().cuda() + x = torch.rand(10, 4).half().cuda() + mask = (torch.rand((10,)) > 0.5).cuda() + mod(x, mask) + + self.assertRaises( + NotImplementedError, + lambda: self.run_test(mod, [x, mask], expected_ops={}), + ) + # This is a common binary op combo usage for ads models. def test_binary_op_combo(self) -> None: class TestModule(torch.nn.Module):