From c018d4372bce17e86e7547c7a1f8f95ef863b12d Mon Sep 17 00:00:00 2001
From: Masahiro Hiramori
Date: Mon, 27 Nov 2023 09:52:06 +0900
Subject: [PATCH] add support for `aten::scaled_dot_product_attention`

---
 python/tvm/relay/frontend/pytorch.py          | 72 +++++++++++++++++++
 tests/python/frontend/pytorch/test_forward.py | 65 +++++++++++++++++
 2 files changed, 137 insertions(+)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 9374a24912805..03afcfb9260c1 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -3891,6 +3891,77 @@ def linalg_vector_norm(self, inputs, input_types):
             reci_ord,
         )
 
+    def scaled_dot_product_attention(self, inputs, input_types):
+        query = inputs[0]
+        key = inputs[1]
+        value = inputs[2]
+        attn_mask = inputs[3]
+        dropout_p = inputs[4]
+        is_causal = inputs[5]
+        scale = inputs[6]
+
+        assert dropout_p == 0.0, "Only dropout_p==0.0 supported"
+        assert (
+            input_types[0] == input_types[1] == input_types[2]
+        ), "Expected query, key, and value to have the same dtype"
+
+        query_shape = self.infer_shape_with_prelude(query)
+        key_shape = self.infer_shape_with_prelude(key)
+        value_shape = self.infer_shape_with_prelude(value)
+        assert (
+            len(query_shape) == len(key_shape) == len(value_shape)
+        ), "query, key and value should be the same number of dim at the moment"
+
+        L = query_shape[-2]
+        S = key_shape[-2]
+        attn_bias = _op.full(_expr.const(0.0, dtype=input_types[0]), (L, S))
+
+        if scale is None:
+            scale_factor = _expr.const(1 / math.sqrt(query_shape[-1]), dtype=input_types[0])
+        else:
+            scale_factor = _expr.const(scale, dtype=input_types[0])
+
+        if is_causal:
+            assert attn_mask is None
+            temp_mask = _op.full(_expr.const(True), [L, S], dtype="bool")
+            temp_mask = _op.trilu(temp_mask, 0, upper=False)
+            temp_mask = _op.cast(temp_mask, dtype="bool")
+            temp_mask = _op.logical_not(temp_mask)
+            fill_value = _op.cast(_expr.const(float("-inf")), dtype=input_types[0])
+            attn_bias = _op.where(temp_mask, fill_value, attn_bias)
+            attn_bias = _op.cast(attn_bias, input_types[0])
+
+        if attn_mask is not None:
+            if input_types[3] == "bool":
+                attn_mask = _op.logical_not(attn_mask)
+                fill_value = _op.cast(_expr.const(float("-inf")), dtype=input_types[0])
+                attn_bias = _op.where(attn_mask, fill_value, attn_bias)
+            else:
+                attn_bias = _op.add(attn_bias, attn_mask)
+
+        if len(query_shape) == 4:
+            query = _op.reshape(query, (-3, -2))
+            key = _op.reshape(key, (-3, -2))
+        else:
+            query = _op.reshape(query, (-1, -2))
+            key = _op.reshape(key, (-1, -2))
+
+        attn_weight = _op.nn.batch_matmul(query, key)
+        if len(query_shape) == 4:
+            attn_weight = _op.reshape(attn_weight, (-4, query_shape[0], -1, -2))
+        attn_weight = _op.multiply(attn_weight, scale_factor)
+        attn_weight = _op.add(attn_weight, attn_bias)
+        attn_weight = _op.nn.softmax(attn_weight)
+
+        if len(query_shape) == 4:
+            attn_weight = _op.reshape(attn_weight, (-3, -2))
+            value = _op.reshape(value, (-3, -2))
+
+        attn_weight = _op.nn.batch_matmul(attn_weight, value, transpose_b=False)
+        if len(query_shape) == 4:
+            attn_weight = _op.reshape(attn_weight, (-4, key_shape[0], -1, -2))
+        return attn_weight
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -4167,6 +4238,7 @@ def create_convert_map(self):
             "aten::copy_": self.inplace_copy,
             "aten::swapaxes": self.transpose,
             "aten::linalg_vector_norm": self.linalg_vector_norm,
+            "aten::scaled_dot_product_attention": self.scaled_dot_product_attention,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index d9ecbce265875..632710d8b7206 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -5489,6 +5489,71 @@ def test_fn(order):
         verify_model(test_fn(order=0), input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_scaled_dot_product_attention():
+    """test_scaled_dot_product_attention"""
+    torch.set_grad_enabled(False)
+
+    class SDPA(torch.nn.Module):
+        def forward(self, query, key, value):
+            return torch.nn.functional.scaled_dot_product_attention(query, key, value)
+
+    class SDPAScale(torch.nn.Module):
+        def forward(self, query, key, value):
+            return torch.nn.functional.scaled_dot_product_attention(query, key, value, scale=0.5)
+
+    class SDPACausalAttention(torch.nn.Module):
+        def forward(self, query, key, value):
+            return torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, is_causal=True
+            )
+
+    class SDPABoolAttnMask(torch.nn.Module):
+        def forward(self, query, key, value):
+            L, S = query.shape[-2], key.shape[-2]
+            attn_mask = torch.ones((L, S), dtype=torch.bool).tril(diagonal=0)
+            return torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=attn_mask
+            )
+
+    class SDPAFloatAttnMask(torch.nn.Module):
+        def forward(self, query, key, value):
+            L, S = query.shape[-2], key.shape[-2]
+            attn_mask = torch.ones((L, S), dtype=torch.float32)
+            return torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=attn_mask
+            )
+
+    # Test with 4D input
+    L, S = 128, 128
+    query = torch.rand(32, 8, L, 64, dtype=torch.float32)
+    key = torch.rand(32, 8, S, 64, dtype=torch.float32)
+    value = torch.rand(32, 8, 128, 64, dtype=torch.float32)
+    verify_model(SDPA().float().eval(), [query, key, value])
+    verify_model(SDPAScale().float().eval(), [query, key, value])
+    verify_model(SDPACausalAttention().float().eval(), [query, key, value])
+    verify_model(SDPABoolAttnMask().float().eval(), [query, key, value])
+    verify_model(SDPAFloatAttnMask().float().eval(), [query, key, value])
+
+    # Test with 3D input
+    L, S = 3, 3
+    query = torch.randn(2, L, 8, dtype=torch.float32)
+    key = torch.randn(2, S, 8, dtype=torch.float32)
+    value = torch.randn(2, 3, 8, dtype=torch.float32)
+    verify_model(SDPA().float().eval(), [query, key, value])
+    verify_model(SDPAScale().float().eval(), [query, key, value])
+    verify_model(SDPACausalAttention().float().eval(), [query, key, value])
+    verify_model(SDPABoolAttnMask().float().eval(), [query, key, value])
+    verify_model(SDPAFloatAttnMask().float().eval(), [query, key, value])
+
+    # Test with double type
+    L, S = 3, 3
+    query = torch.randn(2, L, 8, dtype=torch.float64)
+    key = torch.randn(2, S, 8, dtype=torch.float64)
+    value = torch.randn(2, 3, 8, dtype=torch.float64)
+    verify_model(SDPA().double().eval(), [query, key, value])
+
+
 class TestSetSpan:
     """test structural equal between translated / hand-crafted relay IR with span tagged."""
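For reviewers who want to exercise the new converter outside the test suite, here is a minimal sketch (not part of the patch) that traces a module calling torch.nn.functional.scaled_dot_product_attention and imports it through relay.frontend.from_pytorch, the path the new operator mapping hooks into. The input names, shapes, tolerances, and the llvm target are arbitrary choices for this example, and a PyTorch 2.x build is assumed so that the traced graph contains aten::scaled_dot_product_attention rather than a decomposed form.

import numpy as np
import torch
import tvm
from tvm import relay
from tvm.contrib import graph_executor


class SDPA(torch.nn.Module):
    def forward(self, query, key, value):
        return torch.nn.functional.scaled_dot_product_attention(query, key, value)


# Example 4D inputs: (batch, heads, sequence, head_dim). Shapes are arbitrary.
query = torch.rand(2, 4, 16, 8)
key = torch.rand(2, 4, 16, 8)
value = torch.rand(2, 4, 16, 8)

# Trace to TorchScript so the frontend sees the aten op.
scripted = torch.jit.trace(SDPA().eval(), (query, key, value))

# Convert to Relay; input names here are placeholders chosen for this example.
shape_list = [
    ("query", list(query.shape)),
    ("key", list(key.shape)),
    ("value", list(value.shape)),
]
mod, params = relay.frontend.from_pytorch(scripted, shape_list)

# Build for CPU and run through the graph executor.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)
m = graph_executor.GraphModule(lib["default"](tvm.cpu()))
m.set_input("query", query.numpy())
m.set_input("key", key.numpy())
m.set_input("value", value.numpy())
m.run()

# Compare the TVM result against PyTorch's reference output.
tvm_out = m.get_output(0).numpy()
torch_out = SDPA()(query, key, value).numpy()
np.testing.assert_allclose(tvm_out, torch_out, rtol=1e-5, atol=1e-5)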