diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 30711de0a760..18e628d5659a 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1555,15 +1555,21 @@ def unflatten(self, inputs, input_types):
         dim = dim if dim >= 0 else len(dshape) + dim
         assert len(dshape) > dim >= 0
 
-        assert unflattened_size.count(-1) <= 1
+        new_unflattened_size = []
+        for s in unflattened_size:
+            if isinstance(s, _expr.Constant):
+                s = s.data.numpy().item()
+            new_unflattened_size.append(s)
+
+        assert new_unflattened_size.count(-1) <= 1
 
-        mult = np.multiply.reduce(unflattened_size)
+        mult = np.multiply.reduce(new_unflattened_size)
         if mult < 0:
             assert dshape[dim] % mult == 0
         else:
             assert dshape[dim] == mult
 
-        new_shape = dshape[:dim] + unflattened_size + dshape[dim + 1 :]
+        new_shape = dshape[:dim] + tuple(new_unflattened_size) + dshape[dim + 1 :]
         out = _op.reshape(data, new_shape)
         return out
 
@@ -3891,6 +3897,109 @@ def linalg_vector_norm(self, inputs, input_types):
             reci_ord,
         )
 
+    def scaled_dot_product_attention(self, inputs, input_types):
+        query = inputs[0]
+        key = inputs[1]
+        value = inputs[2]
+        attn_mask = inputs[3]
+        dropout_p = inputs[4]
+        is_causal = inputs[5]
+
+        # Explicit scale can be used from torch>=2.1.0
+        if len(inputs) == 7:
+            scale = inputs[6]
+        else:
+            scale = None
+
+        assert (
+            input_types[0] == input_types[1] == input_types[2]
+        ), "Expected query, key, and value to have the same dtype"
+
+        dtype = input_types[0]
+        assert dtype == "float32" or dtype == "float64", "Data type can be float32 or float64"
+
+        query_shape = self.infer_shape_with_prelude(query)
+        key_shape = self.infer_shape_with_prelude(key)
+        value_shape = self.infer_shape_with_prelude(value)
+        assert 3 <= len(query_shape) <= 4, "Only 3D or 4D query supported"
+        assert 3 <= len(key_shape) <= 4, "Only 3D or 4D key supported"
+        assert 3 <= len(value_shape) <= 4, "Only 3D or 4D value supported"
+
+        assert dropout_p == 0.0, "Only dropout_p==0.0 supported"
+
+        L, S = query_shape[-2], key_shape[-2]
+
+        if scale is None:
+            scale_factor = _expr.const(1 / math.sqrt(query_shape[-1]), dtype=dtype)
+        else:
+            scale_factor = _expr.const(scale, dtype=dtype)
+
+        attn_bias = _op.full(_expr.const(0.0, dtype=dtype), (L, S))
+
+        if is_causal:
+            assert attn_mask is None, "Explicit attn_mask shouldn't be set when is_causal=True"
+            temp_mask = _op.full(_expr.const(True), [L, S], dtype="bool")
+            temp_mask = _op.trilu(temp_mask, 0, upper=False)
+            temp_mask = _op.cast(temp_mask, dtype="bool")
+            temp_mask = _op.logical_not(temp_mask)
+            fill_value = _op.cast(_expr.const(float("-inf")), dtype=dtype)
+            attn_bias = _op.where(temp_mask, fill_value, attn_bias)
+            attn_bias = _op.cast(attn_bias, dtype)
+
+        if attn_mask is not None:
+            if input_types[3] == "bool":
+                attn_mask = _op.logical_not(attn_mask)
+                fill_value = _op.cast(_expr.const(float("-inf")), dtype=dtype)
+                attn_bias = _op.where(attn_mask, fill_value, attn_bias)
+            else:
+                attn_bias = _op.add(attn_bias, attn_mask)
+
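+        # nn.batch_matmul expects rank-3 operands: broadcast the lower-rank tensor to the
+        # common batch shape, flatten 4D tensors to 3D, and restore the 4D shape afterwards.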
+        if len(query_shape) < len(key_shape):
+            batch_size = key_shape[0]
+        else:
+            batch_size = query_shape[0]
+        if len(query_shape) == 4 and len(key_shape) == 4:
+            query = _op.reshape(query, newshape=[-3, -2])
+            key = _op.reshape(key, newshape=[-3, -2])
+        if len(query_shape) == 3 and len(key_shape) == 4:
+            query = _op.broadcast_to(query, shape=(batch_size,) + query_shape)
+            query = _op.reshape(query, newshape=[-3, -2])
+            key = _op.reshape(key, newshape=[-3, -2])
+        if len(query_shape) == 4 and len(key_shape) == 3:
+            query = _op.reshape(query, newshape=[-3, -2])
+            key = _op.broadcast_to(key, shape=(batch_size,) + key_shape)
+            key = _op.reshape(key, newshape=[-3, -2])
+        attn_weight = _op.nn.batch_matmul(query, key)
+        if len(query_shape) == 4 or len(key_shape) == 4:
+            attn_weight = _op.reshape(attn_weight, newshape=[-4, batch_size, -1, -2])
+            attn_weight = _op.squeeze(attn_weight, axis=[])
+
+        attn_weight = _op.multiply(attn_weight, scale_factor)
+        attn_weight = _op.add(attn_weight, attn_bias)
+        attn_weight = _op.nn.softmax(attn_weight)
+        attn_weight = _op.nn.dropout(attn_weight, rate=dropout_p)
+
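+        # The same rank handling is applied before the attn_weight @ value product.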
+        aw_shape = self.infer_shape_with_prelude(attn_weight)
+        if len(aw_shape) < len(value_shape):
+            batch_size = value_shape[0]
+        else:
+            batch_size = aw_shape[0]
+        if len(aw_shape) == 4 and len(value_shape) == 4:
+            attn_weight = _op.reshape(attn_weight, newshape=[-3, -2])
+            value = _op.reshape(value, newshape=[-3, -2])
+        if len(aw_shape) == 3 and len(value_shape) == 4:
+            attn_weight = _op.broadcast_to(attn_weight, shape=(batch_size,) + aw_shape)
+            attn_weight = _op.reshape(attn_weight, newshape=[-3, -2])
+            value = _op.reshape(value, newshape=[-3, -2])
+        if len(aw_shape) == 4 and len(value_shape) == 3:
+            attn_weight = _op.reshape(attn_weight, newshape=[-3, -2])
+            value = _op.broadcast_to(value, shape=(batch_size,) + value_shape)
+            value = _op.reshape(value, newshape=[-3, -2])
+        attn_weight = _op.nn.batch_matmul(attn_weight, value, transpose_b=False)
+        if len(aw_shape) == 4 or len(value_shape) == 4:
+            attn_weight = _op.reshape(attn_weight, newshape=[-4, batch_size, -1, -2])
+        return attn_weight
+
     # Operator mappings
     def create_convert_map(self):
         self.convert_map = {
@@ -4167,6 +4276,7 @@ def create_convert_map(self):
             "aten::copy_": self.inplace_copy,
             "aten::swapaxes": self.transpose,
             "aten::linalg_vector_norm": self.linalg_vector_norm,
+            "aten::scaled_dot_product_attention": self.scaled_dot_product_attention,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index d9ecbce26587..6d4ed4539bc0 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -4776,7 +4776,6 @@ def test_fn(x, mask):
     verify_model(test_fn, [inp.to(torch.float64), inp > 0.5])
 
 
-@pytest.mark.skip(reason="unsupported op: 'aten::scaled_dot_product_attention'")
 def test_transformer():
     """test_transformer"""
     model = torch.nn.Transformer(d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6)
@@ -5489,6 +5488,89 @@ def test_fn(order):
     verify_model(test_fn(order=0), input_data=input_data)
 
 
+@tvm.testing.uses_gpu
+def test_scaled_dot_product_attention():
+    """test_scaled_dot_product_attention"""
+    torch.set_grad_enabled(False)
+
+    def test_fn(attn_mask=None, is_causal=False):
+        return lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=attn_mask, is_causal=is_causal
+        )
+
+    L, S, E, Ev = 5, 7, 11, 13
+    query_4d = torch.randn(2, 3, L, E)
+    query_3d = torch.randn(3, L, E)
+    key_4d = torch.randn(2, 3, S, E)
+    key_3d = torch.randn(3, S, E)
+    value_4d = torch.randn(2, 3, S, Ev)
+    value_3d = torch.randn(3, S, Ev)
+
+    verify_model(test_fn(), [query_4d, key_4d, value_4d])
+    verify_model(test_fn(), [query_4d, key_4d, value_3d])
+    verify_model(test_fn(), [query_4d, key_3d, value_4d])
+    verify_model(test_fn(), [query_4d, key_3d, value_3d])
+    verify_model(test_fn(), [query_3d, key_4d, value_4d])
+    verify_model(test_fn(), [query_3d, key_4d, value_3d])
+    verify_model(test_fn(), [query_3d, key_3d, value_4d])
+    verify_model(test_fn(), [query_3d, key_3d, value_3d])
+
+    verify_model(test_fn(is_causal=True), [query_4d, key_4d, value_4d])
+    verify_model(test_fn(is_causal=True), [query_4d, key_4d, value_3d])
+    verify_model(test_fn(is_causal=True), [query_4d, key_3d, value_4d])
+    verify_model(test_fn(is_causal=True), [query_4d, key_3d, value_3d])
+    verify_model(test_fn(is_causal=True), [query_3d, key_4d, value_4d])
+    verify_model(test_fn(is_causal=True), [query_3d, key_4d, value_3d])
+    verify_model(test_fn(is_causal=True), [query_3d, key_3d, value_4d])
+    verify_model(test_fn(is_causal=True), [query_3d, key_3d, value_3d])
+
+    # Test with explicit attn_mask
+    attn_mask = torch.ones((L, S), dtype=torch.bool).tril(diagonal=0)
+    if torch.cuda.is_available():
+        attn_mask = attn_mask.cuda()
+    verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_4d, value_4d])
+    verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_4d, value_3d])
+    verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_3d, value_4d])
+    verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_3d, value_3d])
+    verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_4d, value_4d])
+    verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_4d, value_3d])
+    verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_3d, value_4d])
+    verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_3d, value_3d])
+
+    # Test with float64
+    query_4d = torch.randn(2, 3, L, E, dtype=torch.float64)
+    query_3d = torch.randn(3, L, E, dtype=torch.float64)
+    key_4d = torch.randn(2, 3, S, E, dtype=torch.float64)
+    key_3d = torch.randn(3, S, E, dtype=torch.float64)
+    value_4d = torch.randn(2, 3, S, Ev, dtype=torch.float64)
+    value_3d = torch.randn(3, S, Ev, dtype=torch.float64)
+    verify_model(test_fn(), [query_4d, key_4d, value_4d])
+    verify_model(test_fn(), [query_4d, key_4d, value_3d])
+    verify_model(test_fn(), [query_4d, key_3d, value_4d])
+    verify_model(test_fn(), [query_4d, key_3d, value_3d])
+    verify_model(test_fn(), [query_3d, key_4d, value_4d])
+    verify_model(test_fn(), [query_3d, key_4d, value_3d])
+    verify_model(test_fn(), [query_3d, key_3d, value_4d])
+    verify_model(test_fn(), [query_3d, key_3d, value_3d])
+
+    # Test with larger tensors
+    L, S, E, Ev = 128, 128, 64, 64
+    query_4d = torch.randn(32, 8, L, E)
+    query_3d = torch.randn(8, L, E)
+    key_4d = torch.randn(32, 8, S, E)
+    key_3d = torch.randn(8, S, E)
+    value_4d = torch.randn(32, 8, S, Ev)
+    value_3d = torch.randn(8, S, Ev)
+    verify_model(test_fn(), [query_4d, key_4d, value_4d])
+    verify_model(test_fn(), [query_4d, key_4d, value_3d])
+    verify_model(test_fn(), [query_4d, key_3d, value_4d])
+    verify_model(test_fn(), [query_4d, key_3d, value_3d])
+    verify_model(test_fn(), [query_3d, key_4d, value_4d])
+    verify_model(test_fn(), [query_3d, key_4d, value_3d])
+    verify_model(test_fn(), [query_3d, key_3d, value_4d])
+    verify_model(test_fn(), [query_3d, key_3d, value_3d])
+
+
 class TestSetSpan:
     """test structural equal between translated / hand-crafted relay IR with span tagged."""