Commit

add support for aten::scaled_dot_product_attention
mshr-h committed Nov 27, 2023
1 parent 3fd3a63 commit c018d43
Showing 2 changed files with 137 additions and 0 deletions.
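For context, aten::scaled_dot_product_attention computes softmax(Q · Kᵀ · scale + attn_bias) · V, where scale defaults to 1/sqrt(E) for head dimension E and attn_bias encodes the causal or user-supplied attention mask; as the converter below asserts, only the dropout-free case (dropout_p == 0.0) is supported.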
python/tvm/relay/frontend/pytorch.py: 72 additions, 0 deletions
@@ -3891,6 +3891,77 @@ def linalg_vector_norm(self, inputs, input_types):
reci_ord,
)

def scaled_dot_product_attention(self, inputs, input_types):
query = inputs[0]
key = inputs[1]
value = inputs[2]
attn_mask = inputs[3]
dropout_p = inputs[4]
is_causal = inputs[5]
scale = inputs[6]

assert dropout_p == 0.0, "Only dropout_p==0.0 supported"
assert (
input_types[0] == input_types[1] == input_types[2]
), "Expected query, key, and value to have the same dtype"

query_shape = self.infer_shape_with_prelude(query)
key_shape = self.infer_shape_with_prelude(key)
value_shape = self.infer_shape_with_prelude(value)
assert (
len(query_shape) == len(key_shape) == len(value_shape)
), "query, key and value should be the same number of dim at the moment"

L = query_shape[-2]
S = key_shape[-2]
attn_bias = _op.full(_expr.const(0.0, dtype=input_types[0]), (L, S))

if scale is None:
scale_factor = _expr.const(1 / math.sqrt(query_shape[-1]), dtype=input_types[0])
else:
scale_factor = _expr.const(scale, dtype=input_types[0])

if is_causal:
assert attn_mask is None
temp_mask = _op.full(_expr.const(True), [L, S], dtype="bool")
temp_mask = _op.trilu(temp_mask, 0, upper=False)
temp_mask = _op.cast(temp_mask, dtype="bool")
temp_mask = _op.logical_not(temp_mask)
fill_value = _op.cast(_expr.const(float("-inf")), dtype=input_types[0])
attn_bias = _op.where(temp_mask, fill_value, attn_bias)
attn_bias = _op.cast(attn_bias, input_types[0])

if attn_mask is not None:
if input_types[3] == "bool":
attn_mask = _op.logical_not(attn_mask)
fill_value = _op.cast(_expr.const(float("-inf")), dtype=input_types[0])
attn_bias = _op.where(attn_mask, fill_value, attn_bias)
else:
attn_bias = _op.add(attn_bias, attn_mask)

if len(query_shape) == 4:
query = _op.reshape(query, (-3, -2))
key = _op.reshape(key, (-3, -2))
else:
query = _op.reshape(query, (-1, -2))
key = _op.reshape(key, (-1, -2))

attn_weight = _op.nn.batch_matmul(query, key)
if len(query_shape) == 4:
attn_weight = _op.reshape(attn_weight, (-4, query_shape[0], -1, -2))
attn_weight = _op.multiply(attn_weight, scale_factor)
attn_weight = _op.add(attn_weight, attn_bias)
attn_weight = _op.nn.softmax(attn_weight)

if len(query_shape) == 4:
attn_weight = _op.reshape(attn_weight, (-3, -2))
value = _op.reshape(value, (-3, -2))

attn_weight = _op.nn.batch_matmul(attn_weight, value, transpose_b=False)
if len(query_shape) == 4:
attn_weight = _op.reshape(attn_weight, (-4, key_shape[0], -1, -2))
return attn_weight

# Operator mappings
def create_convert_map(self):
self.convert_map = {
@@ -4167,6 +4238,7 @@ def create_convert_map(self):
"aten::copy_": self.inplace_copy,
"aten::swapaxes": self.transpose,
"aten::linalg_vector_norm": self.linalg_vector_norm,
"aten::scaled_dot_product_attention": self.scaled_dot_product_attention,
}

def update_convert_map(self, custom_map):
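For reference, a minimal NumPy sketch of the computation the converter above emits, assuming no dropout, matching query/key/value dtypes, and a dense (L, S) additive bias, exactly as the converter asserts; the name sdpa_reference is illustrative and not part of the commit.

import math
import numpy as np

def sdpa_reference(query, key, value, attn_mask=None, is_causal=False, scale=None):
    # Scale defaults to 1/sqrt(head_dim), as in the converter.
    L, S = query.shape[-2], key.shape[-2]
    scale_factor = scale if scale is not None else 1.0 / math.sqrt(query.shape[-1])
    attn_bias = np.zeros((L, S), dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        # Mask out positions above the diagonal so queries cannot attend to the future.
        attn_bias[~np.tril(np.ones((L, S), dtype=bool))] = -np.inf
    if attn_mask is not None:
        if attn_mask.dtype == np.bool_:
            attn_bias[~attn_mask] = -np.inf
        else:
            attn_bias = attn_bias + attn_mask
    # (..., L, E) x (..., S, E)^T -> (..., L, S), then softmax over S.
    attn_weight = query @ np.swapaxes(key, -1, -2) * scale_factor + attn_bias
    attn_weight = np.exp(attn_weight - attn_weight.max(axis=-1, keepdims=True))
    attn_weight /= attn_weight.sum(axis=-1, keepdims=True)
    # (..., L, S) x (..., S, E) -> (..., L, E)
    return attn_weight @ value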
tests/python/frontend/pytorch/test_forward.py: 65 additions, 0 deletions
@@ -5489,6 +5489,71 @@ def test_fn(order):
verify_model(test_fn(order=0), input_data=input_data)


@tvm.testing.uses_gpu
def test_scaled_dot_product_attention():
"""test_scaled_dot_product_attention"""
torch.set_grad_enabled(False)

class SDPA(torch.nn.Module):
def forward(self, query, key, value):
return torch.nn.functional.scaled_dot_product_attention(query, key, value)

class SDPAScale(torch.nn.Module):
def forward(self, query, key, value):
return torch.nn.functional.scaled_dot_product_attention(query, key, value, scale=0.5)

class SDPACausalAttention(torch.nn.Module):
def forward(self, query, key, value):
return torch.nn.functional.scaled_dot_product_attention(
query, key, value, is_causal=True
)

class SDPABoolAttnMask(torch.nn.Module):
def forward(self, query, key, value):
L, S = query.shape[-2], key.shape[-2]
attn_mask = torch.ones((L, S), dtype=torch.bool).tril(diagonal=0)
return torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask
)

class SDPAFloatAttnMask(torch.nn.Module):
def forward(self, query, key, value):
L, S = query.shape[-2], key.shape[-2]
attn_mask = torch.ones((L, S), dtype=torch.float32)
return torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask
)

# Test with 4D input
L, S = 128, 128
query = torch.rand(32, 8, L, 64, dtype=torch.float32)
key = torch.rand(32, 8, S, 64, dtype=torch.float32)
value = torch.rand(32, 8, 128, 64, dtype=torch.float32)
verify_model(SDPA().float().eval(), [query, key, value])
verify_model(SDPAScale().float().eval(), [query, key, value])
verify_model(SDPACausalAttention().float().eval(), [query, key, value])
verify_model(SDPABoolAttnMask().float().eval(), [query, key, value])
verify_model(SDPAFloatAttnMask().float().eval(), [query, key, value])

# Test with 3D input
L, S = 3, 3
query = torch.randn(2, L, 8, dtype=torch.float32)
key = torch.randn(2, S, 8, dtype=torch.float32)
value = torch.randn(2, 3, 8, dtype=torch.float32)
verify_model(SDPA().float().eval(), [query, key, value])
verify_model(SDPAScale().float().eval(), [query, key, value])
verify_model(SDPACausalAttention().float().eval(), [query, key, value])
verify_model(SDPABoolAttnMask().float().eval(), [query, key, value])
verify_model(SDPAFloatAttnMask().float().eval(), [query, key, value])

# Test with double type
L, S = 3, 3
query = torch.randn(2, L, 8, dtype=torch.float64)
key = torch.randn(2, S, 8, dtype=torch.float64)
value = torch.randn(2, 3, 8, dtype=torch.float64)
verify_model(SDPA().double().eval(), [query, key, value])


class TestSetSpan:
"""test structural equal between translated / hand-crafted relay IR with span tagged."""


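End to end, the new mapping is exercised by tracing a module that calls torch.nn.functional.scaled_dot_product_attention and importing the traced graph with relay.frontend.from_pytorch, which is roughly what verify_model does internally. A hedged sketch with illustrative shapes and target:

import torch
import tvm
from tvm import relay
from tvm.contrib import graph_executor

class SDPA(torch.nn.Module):
    def forward(self, query, key, value):
        return torch.nn.functional.scaled_dot_product_attention(query, key, value)

# Illustrative (batch, heads, seq_len, head_dim) inputs.
query = torch.rand(2, 4, 16, 8)
key = torch.rand(2, 4, 16, 8)
value = torch.rand(2, 4, 16, 8)

traced = torch.jit.trace(SDPA().eval(), (query, key, value))
shape_list = [("query", tuple(query.shape)), ("key", tuple(key.shape)), ("value", tuple(value.shape))]
mod, params = relay.frontend.from_pytorch(traced, shape_list)

# Build for CPU and run through the graph executor.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)
dev = tvm.cpu()
rt = graph_executor.GraphModule(lib["default"](dev))
rt.set_input("query", query.numpy())
rt.set_input("key", key.numpy())
rt.set_input("value", value.numpy())
rt.run()
out = rt.get_output(0).numpy()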