From 0af7485aa11a0d448fa59ceb3a242d607076e402 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Tue, 28 Nov 2023 20:23:30 +0900 Subject: [PATCH 1/5] add support for `aten::scaled_dot_product_attention` --- python/tvm/relay/frontend/pytorch.py | 99 +++++++++++++++++++ tests/python/frontend/pytorch/test_forward.py | 91 +++++++++++++++++ 2 files changed, 190 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 30711de0a760..c82cf1cffbd7 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -3891,6 +3891,104 @@ def linalg_vector_norm(self, inputs, input_types): reci_ord, ) + def scaled_dot_product_attention(self, inputs, input_types): + query = inputs[0] + key = inputs[1] + value = inputs[2] + attn_mask = inputs[3] + dropout_p = inputs[4] + is_causal = inputs[5] + scale = inputs[6] + + assert ( + input_types[0] == input_types[1] == input_types[2] + ), "Expected query, key, and value to have the same dtype" + + dtype = input_types[0] + assert dtype == "float32" or dtype == "float64", "Data type can be float32 or float64" + + query_shape = self.infer_shape_with_prelude(query) + key_shape = self.infer_shape_with_prelude(key) + value_shape = self.infer_shape_with_prelude(value) + assert 3 <= len(query_shape) <= 4, "Only 3D or 4D query supported" + assert 3 <= len(key_shape) <= 4, "Only 3D or 4D key supported" + assert 3 <= len(value_shape) <= 4, "Only 3D or 4D value supported" + + assert dropout_p == 0.0, "Only dropout_p==0.0 supported" + + L, S = query_shape[-2], key_shape[-2] + + if scale is None: + scale_factor = _expr.const(1 / math.sqrt(query_shape[-1]), dtype=dtype) + else: + scale_factor = _expr.const(scale, dtype=dtype) + + attn_bias = _op.full(_expr.const(0.0, dtype=dtype), (L, S)) + + if is_causal: + assert attn_mask is None, "Explicit attn_mask shouldn't be set when is_causal=True" + temp_mask = _op.full(_expr.const(True), [L, S], dtype="bool") + temp_mask = _op.trilu(temp_mask, 0, upper=False) + temp_mask = _op.cast(temp_mask, dtype="bool") + temp_mask = _op.logical_not(temp_mask) + fill_value = _op.cast(_expr.const(float("-inf")), dtype=dtype) + attn_bias = _op.where(temp_mask, fill_value, attn_bias) + attn_bias = _op.cast(attn_bias, dtype) + + if attn_mask is not None: + if input_types[3] == "bool": + attn_mask = _op.logical_not(attn_mask) + fill_value = _op.cast(_expr.const(float("-inf")), dtype=dtype) + attn_bias = _op.where(attn_mask, fill_value, attn_bias) + else: + attn_bias = _op.add(attn_bias, attn_mask) + + if len(query_shape) < len(key_shape): + batch_size = key_shape[0] + else: + batch_size = query_shape[0] + if len(query_shape) == 4 and len(key_shape) == 4: + query = _op.reshape(query, newshape=[-3, -2]) + key = _op.reshape(key, newshape=[-3, -2]) + if len(query_shape) == 3 and len(key_shape) == 4: + query = _op.broadcast_to(query, shape=(batch_size,) + query_shape) + query = _op.reshape(query, newshape=[-3, -2]) + key = _op.reshape(key, newshape=[-3, -2]) + if len(query_shape) == 4 and len(key_shape) == 3: + query = _op.reshape(query, newshape=[-3, -2]) + key = _op.broadcast_to(key, shape=(batch_size,) + key_shape) + key = _op.reshape(key, newshape=[-3, -2]) + attn_weight = _op.nn.batch_matmul(query, key) + if len(query_shape) == 4 or len(key_shape) == 4: + attn_weight = _op.reshape(attn_weight, newshape=[-4, batch_size, -1, -2]) + attn_weight = _op.squeeze(attn_weight, axis=[]) + + attn_weight = _op.multiply(attn_weight, scale_factor) + attn_weight = _op.add(attn_weight, attn_bias) + attn_weight = _op.nn.softmax(attn_weight) + attn_weight = _op.nn.dropout(attn_weight, rate=dropout_p) + + aw_shape = self.infer_shape_with_prelude(attn_weight) + if len(aw_shape) < len(value_shape): + batch_size = value_shape[0] + else: + batch_size = aw_shape[0] + if len(aw_shape) == 4 and len(value_shape) == 4: + attn_weight = _op.reshape(attn_weight, newshape=[-3, -2]) + value = _op.reshape(value, newshape=[-3, -2]) + if len(aw_shape) == 3 and len(value_shape) == 4: + attn_weight = _op.broadcast_to(attn_weight, shape=(batch_size,) + aw_shape) + attn_weight = _op.reshape(attn_weight, newshape=[-3, -2]) + value = _op.reshape(value, newshape=[-3, -2]) + if len(aw_shape) == 4 and len(value_shape) == 3: + attn_weight = _op.reshape(attn_weight, newshape=[-3, -2]) + value = _op.broadcast_to(value, shape=(batch_size,) + value_shape) + value = _op.reshape(value, newshape=[-3, -2]) + attn_weight = _op.nn.batch_matmul(attn_weight, value, transpose_b=False) + if len(aw_shape) == 4 or len(value_shape) == 4: + attn_weight = _op.reshape(attn_weight, newshape=[-4, batch_size, -1, -2]) + return attn_weight + # Operator mappings def create_convert_map(self): self.convert_map = { @@ -4167,6 +4265,7 @@ def create_convert_map(self): "aten::copy_": self.inplace_copy, "aten::swapaxes": self.transpose, "aten::linalg_vector_norm": self.linalg_vector_norm, + "aten::scaled_dot_product_attention": self.scaled_dot_product_attention, } def update_convert_map(self, custom_map): diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index d9ecbce26587..3860cdd324df 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -5489,6 +5489,97 @@ def test_fn(order): verify_model(test_fn(order=0), input_data=input_data) +@tvm.testing.uses_gpu +def test_scaled_dot_product_attention(): + """test_scaled_dot_product_attention""" + torch.set_grad_enabled(False) + + def test_fn(attn_mask=None, is_causal=False, scale=None): + return lambda query, key, value: torch.nn.functional.scaled_dot_product_attention( + query, key, value, attn_mask=attn_mask, is_causal=is_causal, scale=scale + ) + + L, S, E, Ev = 5, 7, 11, 13 + query_4d = torch.randn(2, 3, L, E) + query_3d = torch.randn(3, L, E) + key_4d = torch.randn(2, 3, S, E) + key_3d = torch.randn(3, S, E) + value_4d = torch.randn(2, 3, S, Ev) + value_3d = torch.randn(3, S, Ev) + + verify_model(test_fn(), [query_4d, key_4d, value_4d]) + verify_model(test_fn(), [query_4d, key_4d, value_3d]) + verify_model(test_fn(), [query_4d, key_3d, value_4d]) + verify_model(test_fn(), [query_4d, key_3d, value_3d]) + verify_model(test_fn(), [query_3d, key_4d, value_4d]) + verify_model(test_fn(), [query_3d, key_4d, value_3d]) + verify_model(test_fn(), [query_3d, key_3d, value_4d]) + verify_model(test_fn(), [query_3d, key_3d, value_3d]) + + verify_model(test_fn(is_causal=True), [query_4d, key_4d, value_4d]) + verify_model(test_fn(is_causal=True), [query_4d, key_4d, value_3d]) + verify_model(test_fn(is_causal=True), [query_4d, key_3d, value_4d]) + verify_model(test_fn(is_causal=True), [query_4d, key_3d, value_3d]) + verify_model(test_fn(is_causal=True), [query_3d, key_4d, value_4d]) + verify_model(test_fn(is_causal=True), [query_3d, key_4d, value_3d]) + verify_model(test_fn(is_causal=True), [query_3d, key_3d, value_4d]) + verify_model(test_fn(is_causal=True), [query_3d, key_3d, value_3d]) + + # Test with explicit attn_mask + attn_mask = torch.ones((L, S), dtype=torch.bool).tril(diagonal=0) + verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_4d, value_4d]) + verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_4d, value_3d]) + verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_3d, value_4d]) + verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_3d, value_3d]) + verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_4d, value_4d]) + verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_4d, value_3d]) + verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_3d, value_4d]) + verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_3d, value_3d]) + + scale = 0.5 + verify_model(test_fn(scale=scale), [query_4d, key_4d, value_4d]) + verify_model(test_fn(scale=scale), [query_4d, key_4d, value_3d]) + verify_model(test_fn(scale=scale), [query_4d, key_3d, value_4d]) + verify_model(test_fn(scale=scale), [query_4d, key_3d, value_3d]) + verify_model(test_fn(scale=scale), [query_3d, key_4d, value_4d]) + verify_model(test_fn(scale=scale), [query_3d, key_4d, value_3d]) + verify_model(test_fn(scale=scale), [query_3d, key_3d, value_4d]) + verify_model(test_fn(scale=scale), [query_3d, key_3d, value_3d]) + + # Test with float64 + query_4d = torch.randn(2, 3, L, E, dtype=torch.float64) + query_3d = torch.randn(3, L, E, dtype=torch.float64) + key_4d = torch.randn(2, 3, S, E, dtype=torch.float64) + key_3d = torch.randn(3, S, E, dtype=torch.float64) + value_4d = torch.randn(2, 3, S, Ev, dtype=torch.float64) + value_3d = torch.randn(3, S, Ev, dtype=torch.float64) + verify_model(test_fn(), [query_4d, key_4d, value_4d]) + verify_model(test_fn(), [query_4d, key_4d, value_3d]) + verify_model(test_fn(), [query_4d, key_3d, value_4d]) + verify_model(test_fn(), [query_4d, key_3d, value_3d]) + verify_model(test_fn(), [query_3d, key_4d, value_4d]) + verify_model(test_fn(), [query_3d, key_4d, value_3d]) + verify_model(test_fn(), [query_3d, key_3d, value_4d]) + verify_model(test_fn(), [query_3d, key_3d, value_3d]) + + # Test with larger tensors + L, S, E, Ev = 128, 128, 64, 64 + query_4d = torch.randn(32, 8, L, E) + query_3d = torch.randn(8, L, E) + key_4d = torch.randn(32, 8, S, E) + key_3d = torch.randn(8, S, E) + value_4d = torch.randn(32, 8, S, Ev) + value_3d = torch.randn(8, S, Ev) + verify_model(test_fn(), [query_4d, key_4d, value_4d]) + verify_model(test_fn(), [query_4d, key_4d, value_3d]) + verify_model(test_fn(), [query_4d, key_3d, value_4d]) + verify_model(test_fn(), [query_4d, key_3d, value_3d]) + verify_model(test_fn(), [query_3d, key_4d, value_4d]) + verify_model(test_fn(), [query_3d, key_4d, value_3d]) + verify_model(test_fn(), [query_3d, key_3d, value_4d]) + verify_model(test_fn(), [query_3d, key_3d, value_3d]) + + class TestSetSpan: """test structural equal between translated / hand-crafted relay IR with span tagged.""" From 5f4d23d9da0c8b056e2df490fe23864c762ba88f Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Tue, 28 Nov 2023 20:23:49 +0900 Subject: [PATCH 2/5] convert `_expr.Constant` into `int` in the unflattened_size --- python/tvm/relay/frontend/pytorch.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index c82cf1cffbd7..a0b4f6d60650 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1555,15 +1555,21 @@ def unflatten(self, inputs, input_types): dim = dim if dim >= 0 else len(dshape) + dim assert len(dshape) > dim >= 0 - assert unflattened_size.count(-1) <= 1 + new_unflattened_size = [] + for s in unflattened_size: + if isinstance(s, _expr.Constant): + s = s.data.numpy().item() + new_unflattened_size.append(s) + + assert new_unflattened_size.count(-1) <= 1 - mult = np.multiply.reduce(unflattened_size) + mult = np.multiply.reduce(new_unflattened_size) if mult < 0: assert dshape[dim] % mult == 0 else: assert dshape[dim] == mult - new_shape = dshape[:dim] + unflattened_size + dshape[dim + 1 :] + new_shape = dshape[:dim] + tuple(new_unflattened_size) + dshape[dim + 1 :] out = _op.reshape(data, new_shape) return out From 215cc721a0919b3a61da1fc7329d17dedf2ed9f4 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Tue, 28 Nov 2023 20:23:59 +0900 Subject: [PATCH 3/5] enable `test_transformer` --- tests/python/frontend/pytorch/test_forward.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 3860cdd324df..2754b26a9776 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -4776,7 +4776,6 @@ def test_fn(x, mask): verify_model(test_fn, [inp.to(torch.float64), inp > 0.5]) -@pytest.mark.skip(reason="unsupported op: 'aten::scaled_dot_product_attention'") def test_transformer(): """test_transformer""" model = torch.nn.Transformer(d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6) From d46a6fcf1a134e5f74e0e4a251068c39dddaf00e Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Wed, 29 Nov 2023 10:29:16 +0900 Subject: [PATCH 4/5] explicit scale not support with `torch==2.0.0` --- python/tvm/relay/frontend/pytorch.py | 7 ++++++- tests/python/frontend/pytorch/test_forward.py | 14 ++------------ 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index a0b4f6d60650..18e628d5659a 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -3904,7 +3904,12 @@ def scaled_dot_product_attention(self, inputs, input_types): attn_mask = inputs[3] dropout_p = inputs[4] is_causal = inputs[5] - scale = inputs[6] + + # Explicit scale can be used from torch>=2.1.0 + if len(inputs) == 7: + scale = inputs[6] + else: + scale = None assert ( input_types[0] == input_types[1] == input_types[2] diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 2754b26a9776..d14359472ddd 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -5493,9 +5493,9 @@ def test_scaled_dot_product_attention(): """test_scaled_dot_product_attention""" torch.set_grad_enabled(False) - def test_fn(attn_mask=None, is_causal=False, scale=None): + def test_fn(attn_mask=None, is_causal=False): return lambda query, key, value: torch.nn.functional.scaled_dot_product_attention( - query, key, value, attn_mask=attn_mask, is_causal=is_causal, scale=scale + query, key, value, attn_mask=attn_mask, is_causal=is_causal ) L, S, E, Ev = 5, 7, 11, 13 @@ -5535,16 +5535,6 @@ def test_fn(attn_mask=None, is_causal=False, scale=None): verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_3d, value_4d]) verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_3d, value_3d]) - scale = 0.5 - verify_model(test_fn(scale=scale), [query_4d, key_4d, value_4d]) - verify_model(test_fn(scale=scale), [query_4d, key_4d, value_3d]) - verify_model(test_fn(scale=scale), [query_4d, key_3d, value_4d]) - verify_model(test_fn(scale=scale), [query_4d, key_3d, value_3d]) - verify_model(test_fn(scale=scale), [query_3d, key_4d, value_4d]) - verify_model(test_fn(scale=scale), [query_3d, key_4d, value_3d]) - verify_model(test_fn(scale=scale), [query_3d, key_3d, value_4d]) - verify_model(test_fn(scale=scale), [query_3d, key_3d, value_3d]) - # Test with float64 query_4d = torch.randn(2, 3, L, E, dtype=torch.float64) query_3d = torch.randn(3, L, E, dtype=torch.float64) From 9e18215d2d7692a4c8b8247679b2e620f397cffd Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Wed, 29 Nov 2023 10:33:11 +0900 Subject: [PATCH 5/5] fix error on gpu --- tests/python/frontend/pytorch/test_forward.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index d14359472ddd..6d4ed4539bc0 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -5526,6 +5526,8 @@ def test_fn(attn_mask=None, is_causal=False): # Test with explicit attn_mask attn_mask = torch.ones((L, S), dtype=torch.bool).tril(diagonal=0) + if torch.cuda.is_available(): + attn_mask = attn_mask.cuda() verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_4d, value_4d]) verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_4d, value_3d]) verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_3d, value_4d])