Commit: [Relay][Pytorch] Add support for aten::scaled_dot_product_attention (#16143)

* add support for `aten::scaled_dot_product_attention`

* convert `_expr.Constant` into `int` in `unflattened_size`

* enable `test_transformer`

* explicit scale is not supported with `torch==2.0.0`

* fix error on gpu
mshr-h authored Nov 29, 2023
1 parent 3136ff4 commit 97ddd66
Showing 2 changed files with 196 additions and 4 deletions.
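
As context for the diffs below, here is a minimal sketch of how the new converter is exercised end to end: trace a module that calls `torch.nn.functional.scaled_dot_product_attention` and import it through the Relay PyTorch frontend. The module name, shapes, and input names are illustrative and not part of the commit.

import torch
from tvm import relay


class Attention(torch.nn.Module):
    def forward(self, query, key, value):
        return torch.nn.functional.scaled_dot_product_attention(query, key, value)


query = torch.randn(2, 3, 5, 11)
key = torch.randn(2, 3, 7, 11)
value = torch.randn(2, 3, 7, 13)

# Tracing produces an aten::scaled_dot_product_attention node, which the
# new mapping added in this commit converts to Relay ops.
scripted = torch.jit.trace(Attention().eval(), (query, key, value))
input_infos = [
    ("query", list(query.shape)),
    ("key", list(key.shape)),
    ("value", list(value.shape)),
]
mod, params = relay.frontend.from_pytorch(scripted, input_infos)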
116 changes: 113 additions & 3 deletions python/tvm/relay/frontend/pytorch.py
@@ -1555,15 +1555,21 @@ def unflatten(self, inputs, input_types):
        dim = dim if dim >= 0 else len(dshape) + dim
        assert len(dshape) > dim >= 0

-        assert unflattened_size.count(-1) <= 1
+        new_unflattened_size = []
+        for s in unflattened_size:
+            if isinstance(s, _expr.Constant):
+                s = s.data.numpy().item()
+            new_unflattened_size.append(s)
+
+        assert new_unflattened_size.count(-1) <= 1

-        mult = np.multiply.reduce(unflattened_size)
+        mult = np.multiply.reduce(new_unflattened_size)
        if mult < 0:
            assert dshape[dim] % mult == 0
        else:
            assert dshape[dim] == mult

-        new_shape = dshape[:dim] + unflattened_size + dshape[dim + 1 :]
+        new_shape = dshape[:dim] + tuple(new_unflattened_size) + dshape[dim + 1 :]
        out = _op.reshape(data, new_shape)
        return out
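
The conversion loop added above is needed because the sizes passed to `aten::unflatten` can arrive as Relay constants rather than Python ints. A scalar `relay.Constant` (the class referred to as `_expr.Constant` in the frontend) can be turned back into a plain int as in this small sketch, which is illustrative only:

from tvm import relay

c = relay.const(3)  # a scalar Constant, as the frontend may receive from TorchScript
assert isinstance(c, relay.Constant)
size = c.data.numpy().item()  # back to a plain Python int usable in shape arithmetic
assert size == 3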

@@ -3891,6 +3897,109 @@ def linalg_vector_norm(self, inputs, input_types):
            reci_ord,
        )

+    def scaled_dot_product_attention(self, inputs, input_types):
+        query = inputs[0]
+        key = inputs[1]
+        value = inputs[2]
+        attn_mask = inputs[3]
+        dropout_p = inputs[4]
+        is_causal = inputs[5]
+
+        # Explicit scale can be used from torch>=2.1.0
+        if len(inputs) == 7:
+            scale = inputs[6]
+        else:
+            scale = None
+
+        assert (
+            input_types[0] == input_types[1] == input_types[2]
+        ), "Expected query, key, and value to have the same dtype"
+
+        dtype = input_types[0]
+        assert dtype == "float32" or dtype == "float64", "Data type can be float32 or float64"
+
+        query_shape = self.infer_shape_with_prelude(query)
+        key_shape = self.infer_shape_with_prelude(key)
+        value_shape = self.infer_shape_with_prelude(value)
+        assert 3 <= len(query_shape) <= 4, "Only 3D or 4D query supported"
+        assert 3 <= len(key_shape) <= 4, "Only 3D or 4D key supported"
+        assert 3 <= len(value_shape) <= 4, "Only 3D or 4D value supported"
+
+        assert dropout_p == 0.0, "Only dropout_p==0.0 supported"
+
+        L, S = query_shape[-2], key_shape[-2]
+
+        if scale is None:
+            scale_factor = _expr.const(1 / math.sqrt(query_shape[-1]), dtype=dtype)
+        else:
+            scale_factor = _expr.const(scale, dtype=dtype)
+
+        attn_bias = _op.full(_expr.const(0.0, dtype=dtype), (L, S))
+
+        if is_causal:
+            assert attn_mask is None, "Explicit attn_mask shouldn't be set when is_causal=True"
+            temp_mask = _op.full(_expr.const(True), [L, S], dtype="bool")
+            temp_mask = _op.trilu(temp_mask, 0, upper=False)
+            temp_mask = _op.cast(temp_mask, dtype="bool")
+            temp_mask = _op.logical_not(temp_mask)
+            fill_value = _op.cast(_expr.const(float("-inf")), dtype=dtype)
+            attn_bias = _op.where(temp_mask, fill_value, attn_bias)
+            attn_bias = _op.cast(attn_bias, dtype)
+
+        if attn_mask is not None:
+            if input_types[3] == "bool":
+                attn_mask = _op.logical_not(attn_mask)
+                fill_value = _op.cast(_expr.const(float("-inf")), dtype=dtype)
+                attn_bias = _op.where(attn_mask, fill_value, attn_bias)
+            else:
+                attn_bias = _op.add(attn_bias, attn_mask)
+
+        if len(query_shape) < len(key_shape):
+            batch_size = key_shape[0]
+        else:
+            batch_size = query_shape[0]
+        if len(query_shape) == 4 and len(key_shape) == 4:
+            query = _op.reshape(query, newshape=[-3, -2])
+            key = _op.reshape(key, newshape=[-3, -2])
+        if len(query_shape) == 3 and len(key_shape) == 4:
+            query = _op.broadcast_to(query, shape=(batch_size,) + query_shape)
+            query = _op.reshape(query, newshape=[-3, -2])
+            key = _op.reshape(key, newshape=[-3, -2])
+        if len(query_shape) == 4 and len(key_shape) == 3:
+            query = _op.reshape(query, newshape=[-3, -2])
+            key = _op.broadcast_to(key, shape=(batch_size,) + key_shape)
+            key = _op.reshape(key, newshape=[-3, -2])
+        attn_weight = _op.nn.batch_matmul(query, key)
+        if len(query_shape) == 4 or len(key_shape) == 4:
+            attn_weight = _op.reshape(attn_weight, newshape=[-4, batch_size, -1, -2])
+        attn_weight = _op.squeeze(attn_weight, axis=[])
+
+        attn_weight = _op.multiply(attn_weight, scale_factor)
+        attn_weight = _op.add(attn_weight, attn_bias)
+        attn_weight = _op.nn.softmax(attn_weight)
+        attn_weight = _op.nn.dropout(attn_weight, rate=dropout_p)
+
+        aw_shape = self.infer_shape_with_prelude(attn_weight)
+        if len(aw_shape) < len(value_shape):
+            batch_size = value_shape[0]
+        else:
+            batch_size = aw_shape[0]
+        if len(aw_shape) == 4 and len(value_shape) == 4:
+            attn_weight = _op.reshape(attn_weight, newshape=[-3, -2])
+            value = _op.reshape(value, newshape=[-3, -2])
+        if len(aw_shape) == 3 and len(value_shape) == 4:
+            attn_weight = _op.broadcast_to(attn_weight, shape=(batch_size,) + aw_shape)
+            attn_weight = _op.reshape(attn_weight, newshape=[-3, -2])
+            value = _op.reshape(value, newshape=[-3, -2])
+        if len(aw_shape) == 4 and len(value_shape) == 3:
+            attn_weight = _op.reshape(attn_weight, newshape=[-3, -2])
+            value = _op.broadcast_to(value, shape=(batch_size,) + value_shape)
+            value = _op.reshape(value, newshape=[-3, -2])
+        attn_weight = _op.nn.batch_matmul(attn_weight, value, transpose_b=False)
+        if len(aw_shape) == 4 or len(value_shape) == 4:
+            attn_weight = _op.reshape(attn_weight, newshape=[-4, batch_size, -1, -2])
+        return attn_weight
+
    # Operator mappings
    def create_convert_map(self):
        self.convert_map = {
@@ -4167,6 +4276,7 @@ def create_convert_map(self):
            "aten::copy_": self.inplace_copy,
            "aten::swapaxes": self.transpose,
            "aten::linalg_vector_norm": self.linalg_vector_norm,
+            "aten::scaled_dot_product_attention": self.scaled_dot_product_attention,
        }

    def update_convert_map(self, custom_map):
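For reference, the converter above emits the standard scaled dot-product attention computation, softmax(Q·K^T·scale + bias)·V, where the bias carries the causal or user-supplied mask and is otherwise zero. A NumPy sketch of the same math follows; it is illustrative only (3D inputs, no dropout), and the helper name `sdpa_reference` is hypothetical.

import numpy as np


def sdpa_reference(query, key, value, attn_mask=None, is_causal=False, scale=None):
    # query: (B, L, E), key: (B, S, E), value: (B, S, Ev)
    L, S = query.shape[-2], key.shape[-2]
    scale = 1.0 / np.sqrt(query.shape[-1]) if scale is None else scale
    attn_bias = np.zeros((L, S), dtype=query.dtype)
    if is_causal:
        # positions above the diagonal get -inf before softmax
        attn_bias[np.triu(np.ones((L, S), dtype=bool), k=1)] = -np.inf
    if attn_mask is not None:
        if attn_mask.dtype == bool:
            # boolean mask: True means "attend", False becomes -inf
            attn_bias = np.where(attn_mask, attn_bias, -np.inf)
        else:
            attn_bias = attn_bias + attn_mask
    attn_weight = query @ key.transpose(0, 2, 1) * scale + attn_bias
    attn_weight = np.exp(attn_weight - attn_weight.max(axis=-1, keepdims=True))
    attn_weight = attn_weight / attn_weight.sum(axis=-1, keepdims=True)
    return attn_weight @ value
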
84 changes: 83 additions & 1 deletion tests/python/frontend/pytorch/test_forward.py
@@ -4776,7 +4776,6 @@ def test_fn(x, mask):
    verify_model(test_fn, [inp.to(torch.float64), inp > 0.5])


-@pytest.mark.skip(reason="unsupported op: 'aten::scaled_dot_product_attention'")
def test_transformer():
    """test_transformer"""
    model = torch.nn.Transformer(d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6)
@@ -5489,6 +5488,89 @@ def test_fn(order):
    verify_model(test_fn(order=0), input_data=input_data)


+@tvm.testing.uses_gpu
+def test_scaled_dot_product_attention():
+    """test_scaled_dot_product_attention"""
+    torch.set_grad_enabled(False)
+
+    def test_fn(attn_mask=None, is_causal=False):
+        return lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=attn_mask, is_causal=is_causal
+        )
+
+    L, S, E, Ev = 5, 7, 11, 13
+    query_4d = torch.randn(2, 3, L, E)
+    query_3d = torch.randn(3, L, E)
+    key_4d = torch.randn(2, 3, S, E)
+    key_3d = torch.randn(3, S, E)
+    value_4d = torch.randn(2, 3, S, Ev)
+    value_3d = torch.randn(3, S, Ev)
+
+    verify_model(test_fn(), [query_4d, key_4d, value_4d])
+    verify_model(test_fn(), [query_4d, key_4d, value_3d])
+    verify_model(test_fn(), [query_4d, key_3d, value_4d])
+    verify_model(test_fn(), [query_4d, key_3d, value_3d])
+    verify_model(test_fn(), [query_3d, key_4d, value_4d])
+    verify_model(test_fn(), [query_3d, key_4d, value_3d])
+    verify_model(test_fn(), [query_3d, key_3d, value_4d])
+    verify_model(test_fn(), [query_3d, key_3d, value_3d])
+
+    verify_model(test_fn(is_causal=True), [query_4d, key_4d, value_4d])
+    verify_model(test_fn(is_causal=True), [query_4d, key_4d, value_3d])
+    verify_model(test_fn(is_causal=True), [query_4d, key_3d, value_4d])
+    verify_model(test_fn(is_causal=True), [query_4d, key_3d, value_3d])
+    verify_model(test_fn(is_causal=True), [query_3d, key_4d, value_4d])
+    verify_model(test_fn(is_causal=True), [query_3d, key_4d, value_3d])
+    verify_model(test_fn(is_causal=True), [query_3d, key_3d, value_4d])
+    verify_model(test_fn(is_causal=True), [query_3d, key_3d, value_3d])
+
+    # Test with explicit attn_mask
+    attn_mask = torch.ones((L, S), dtype=torch.bool).tril(diagonal=0)
+    if torch.cuda.is_available():
+        attn_mask = attn_mask.cuda()
+    verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_4d, value_4d])
+    verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_4d, value_3d])
+    verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_3d, value_4d])
+    verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_3d, value_3d])
+    verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_4d, value_4d])
+    verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_4d, value_3d])
+    verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_3d, value_4d])
+    verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_3d, value_3d])
+
+    # Test with float64
+    query_4d = torch.randn(2, 3, L, E, dtype=torch.float64)
+    query_3d = torch.randn(3, L, E, dtype=torch.float64)
+    key_4d = torch.randn(2, 3, S, E, dtype=torch.float64)
+    key_3d = torch.randn(3, S, E, dtype=torch.float64)
+    value_4d = torch.randn(2, 3, S, Ev, dtype=torch.float64)
+    value_3d = torch.randn(3, S, Ev, dtype=torch.float64)
+    verify_model(test_fn(), [query_4d, key_4d, value_4d])
+    verify_model(test_fn(), [query_4d, key_4d, value_3d])
+    verify_model(test_fn(), [query_4d, key_3d, value_4d])
+    verify_model(test_fn(), [query_4d, key_3d, value_3d])
+    verify_model(test_fn(), [query_3d, key_4d, value_4d])
+    verify_model(test_fn(), [query_3d, key_4d, value_3d])
+    verify_model(test_fn(), [query_3d, key_3d, value_4d])
+    verify_model(test_fn(), [query_3d, key_3d, value_3d])
+
+    # Test with larger tensors
+    L, S, E, Ev = 128, 128, 64, 64
+    query_4d = torch.randn(32, 8, L, E)
+    query_3d = torch.randn(8, L, E)
+    key_4d = torch.randn(32, 8, S, E)
+    key_3d = torch.randn(8, S, E)
+    value_4d = torch.randn(32, 8, S, Ev)
+    value_3d = torch.randn(8, S, Ev)
+    verify_model(test_fn(), [query_4d, key_4d, value_4d])
+    verify_model(test_fn(), [query_4d, key_4d, value_3d])
+    verify_model(test_fn(), [query_4d, key_3d, value_4d])
+    verify_model(test_fn(), [query_4d, key_3d, value_3d])
+    verify_model(test_fn(), [query_3d, key_4d, value_4d])
+    verify_model(test_fn(), [query_3d, key_4d, value_3d])
+    verify_model(test_fn(), [query_3d, key_3d, value_4d])
+    verify_model(test_fn(), [query_3d, key_3d, value_3d])


class TestSetSpan:
    """test structural equal between translated / hand-crafted relay IR with span tagged."""
