Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Pytorch] Add support for aten::scaled_dot_product_attention #16143

Merged
merged 5 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 113 additions & 3 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1555,15 +1555,21 @@ def unflatten(self, inputs, input_types):
dim = dim if dim >= 0 else len(dshape) + dim
assert len(dshape) > dim >= 0

assert unflattened_size.count(-1) <= 1
new_unflattened_size = []
for s in unflattened_size:
if isinstance(s, _expr.Constant):
s = s.data.numpy().item()
new_unflattened_size.append(s)

assert new_unflattened_size.count(-1) <= 1

mult = np.multiply.reduce(unflattened_size)
mult = np.multiply.reduce(new_unflattened_size)
if mult < 0:
assert dshape[dim] % mult == 0
else:
assert dshape[dim] == mult

new_shape = dshape[:dim] + unflattened_size + dshape[dim + 1 :]
new_shape = dshape[:dim] + tuple(new_unflattened_size) + dshape[dim + 1 :]
out = _op.reshape(data, new_shape)
return out

Expand Down Expand Up @@ -3891,6 +3897,109 @@ def linalg_vector_norm(self, inputs, input_types):
reci_ord,
)

def scaled_dot_product_attention(self, inputs, input_types):
query = inputs[0]
key = inputs[1]
value = inputs[2]
attn_mask = inputs[3]
dropout_p = inputs[4]
is_causal = inputs[5]

# Explicit scale can be used from torch>=2.1.0
if len(inputs) == 7:
scale = inputs[6]
else:
scale = None

assert (
input_types[0] == input_types[1] == input_types[2]
), "Expected query, key, and value to have the same dtype"

dtype = input_types[0]
assert dtype == "float32" or dtype == "float64", "Data type can be float32 or float64"

query_shape = self.infer_shape_with_prelude(query)
key_shape = self.infer_shape_with_prelude(key)
value_shape = self.infer_shape_with_prelude(value)
assert 3 <= len(query_shape) <= 4, "Only 3D or 4D query supported"
assert 3 <= len(key_shape) <= 4, "Only 3D or 4D key supported"
assert 3 <= len(value_shape) <= 4, "Only 3D or 4D value supported"

assert dropout_p == 0.0, "Only dropout_p==0.0 supported"

L, S = query_shape[-2], key_shape[-2]

if scale is None:
scale_factor = _expr.const(1 / math.sqrt(query_shape[-1]), dtype=dtype)
else:
scale_factor = _expr.const(scale, dtype=dtype)

attn_bias = _op.full(_expr.const(0.0, dtype=dtype), (L, S))

if is_causal:
assert attn_mask is None, "Explicit attn_mask shouldn't be set when is_causal=True"
temp_mask = _op.full(_expr.const(True), [L, S], dtype="bool")
temp_mask = _op.trilu(temp_mask, 0, upper=False)
temp_mask = _op.cast(temp_mask, dtype="bool")
temp_mask = _op.logical_not(temp_mask)
fill_value = _op.cast(_expr.const(float("-inf")), dtype=dtype)
attn_bias = _op.where(temp_mask, fill_value, attn_bias)
attn_bias = _op.cast(attn_bias, dtype)

if attn_mask is not None:
if input_types[3] == "bool":
attn_mask = _op.logical_not(attn_mask)
fill_value = _op.cast(_expr.const(float("-inf")), dtype=dtype)
attn_bias = _op.where(attn_mask, fill_value, attn_bias)
else:
attn_bias = _op.add(attn_bias, attn_mask)

if len(query_shape) < len(key_shape):
batch_size = key_shape[0]
else:
batch_size = query_shape[0]
if len(query_shape) == 4 and len(key_shape) == 4:
query = _op.reshape(query, newshape=[-3, -2])
key = _op.reshape(key, newshape=[-3, -2])
if len(query_shape) == 3 and len(key_shape) == 4:
query = _op.broadcast_to(query, shape=(batch_size,) + query_shape)
query = _op.reshape(query, newshape=[-3, -2])
key = _op.reshape(key, newshape=[-3, -2])
if len(query_shape) == 4 and len(key_shape) == 3:
query = _op.reshape(query, newshape=[-3, -2])
key = _op.broadcast_to(key, shape=(batch_size,) + key_shape)
key = _op.reshape(key, newshape=[-3, -2])
attn_weight = _op.nn.batch_matmul(query, key)
if len(query_shape) == 4 or len(key_shape) == 4:
attn_weight = _op.reshape(attn_weight, newshape=[-4, batch_size, -1, -2])
attn_weight = _op.squeeze(attn_weight, axis=[])

attn_weight = _op.multiply(attn_weight, scale_factor)
attn_weight = _op.add(attn_weight, attn_bias)
attn_weight = _op.nn.softmax(attn_weight)
attn_weight = _op.nn.dropout(attn_weight, rate=dropout_p)

aw_shape = self.infer_shape_with_prelude(attn_weight)
if len(aw_shape) < len(value_shape):
batch_size = value_shape[0]
else:
batch_size = aw_shape[0]
if len(aw_shape) == 4 and len(value_shape) == 4:
attn_weight = _op.reshape(attn_weight, newshape=[-3, -2])
value = _op.reshape(value, newshape=[-3, -2])
if len(aw_shape) == 3 and len(value_shape) == 4:
attn_weight = _op.broadcast_to(attn_weight, shape=(batch_size,) + aw_shape)
attn_weight = _op.reshape(attn_weight, newshape=[-3, -2])
value = _op.reshape(value, newshape=[-3, -2])
if len(aw_shape) == 4 and len(value_shape) == 3:
attn_weight = _op.reshape(attn_weight, newshape=[-3, -2])
value = _op.broadcast_to(value, shape=(batch_size,) + value_shape)
value = _op.reshape(value, newshape=[-3, -2])
attn_weight = _op.nn.batch_matmul(attn_weight, value, transpose_b=False)
if len(aw_shape) == 4 or len(value_shape) == 4:
attn_weight = _op.reshape(attn_weight, newshape=[-4, batch_size, -1, -2])
return attn_weight

# Operator mappings
def create_convert_map(self):
self.convert_map = {
Expand Down Expand Up @@ -4167,6 +4276,7 @@ def create_convert_map(self):
"aten::copy_": self.inplace_copy,
"aten::swapaxes": self.transpose,
"aten::linalg_vector_norm": self.linalg_vector_norm,
"aten::scaled_dot_product_attention": self.scaled_dot_product_attention,
}

def update_convert_map(self, custom_map):
Expand Down
84 changes: 83 additions & 1 deletion tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4776,7 +4776,6 @@ def test_fn(x, mask):
verify_model(test_fn, [inp.to(torch.float64), inp > 0.5])


@pytest.mark.skip(reason="unsupported op: 'aten::scaled_dot_product_attention'")
def test_transformer():
"""test_transformer"""
model = torch.nn.Transformer(d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6)
Expand Down Expand Up @@ -5489,6 +5488,89 @@ def test_fn(order):
verify_model(test_fn(order=0), input_data=input_data)


@tvm.testing.uses_gpu
def test_scaled_dot_product_attention():
"""test_scaled_dot_product_attention"""
torch.set_grad_enabled(False)

def test_fn(attn_mask=None, is_causal=False):
return lambda query, key, value: torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, is_causal=is_causal
)

L, S, E, Ev = 5, 7, 11, 13
query_4d = torch.randn(2, 3, L, E)
query_3d = torch.randn(3, L, E)
key_4d = torch.randn(2, 3, S, E)
key_3d = torch.randn(3, S, E)
value_4d = torch.randn(2, 3, S, Ev)
value_3d = torch.randn(3, S, Ev)

verify_model(test_fn(), [query_4d, key_4d, value_4d])
verify_model(test_fn(), [query_4d, key_4d, value_3d])
verify_model(test_fn(), [query_4d, key_3d, value_4d])
verify_model(test_fn(), [query_4d, key_3d, value_3d])
verify_model(test_fn(), [query_3d, key_4d, value_4d])
verify_model(test_fn(), [query_3d, key_4d, value_3d])
verify_model(test_fn(), [query_3d, key_3d, value_4d])
verify_model(test_fn(), [query_3d, key_3d, value_3d])

verify_model(test_fn(is_causal=True), [query_4d, key_4d, value_4d])
verify_model(test_fn(is_causal=True), [query_4d, key_4d, value_3d])
verify_model(test_fn(is_causal=True), [query_4d, key_3d, value_4d])
verify_model(test_fn(is_causal=True), [query_4d, key_3d, value_3d])
verify_model(test_fn(is_causal=True), [query_3d, key_4d, value_4d])
verify_model(test_fn(is_causal=True), [query_3d, key_4d, value_3d])
verify_model(test_fn(is_causal=True), [query_3d, key_3d, value_4d])
verify_model(test_fn(is_causal=True), [query_3d, key_3d, value_3d])

# Test with explicit attn_mask
attn_mask = torch.ones((L, S), dtype=torch.bool).tril(diagonal=0)
if torch.cuda.is_available():
attn_mask = attn_mask.cuda()
verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_4d, value_4d])
verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_4d, value_3d])
verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_3d, value_4d])
verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_3d, value_3d])
verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_4d, value_4d])
verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_4d, value_3d])
verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_3d, value_4d])
verify_model(test_fn(attn_mask=attn_mask), [query_3d, key_3d, value_3d])

# Test with float64
query_4d = torch.randn(2, 3, L, E, dtype=torch.float64)
query_3d = torch.randn(3, L, E, dtype=torch.float64)
key_4d = torch.randn(2, 3, S, E, dtype=torch.float64)
key_3d = torch.randn(3, S, E, dtype=torch.float64)
value_4d = torch.randn(2, 3, S, Ev, dtype=torch.float64)
value_3d = torch.randn(3, S, Ev, dtype=torch.float64)
verify_model(test_fn(), [query_4d, key_4d, value_4d])
verify_model(test_fn(), [query_4d, key_4d, value_3d])
verify_model(test_fn(), [query_4d, key_3d, value_4d])
verify_model(test_fn(), [query_4d, key_3d, value_3d])
verify_model(test_fn(), [query_3d, key_4d, value_4d])
verify_model(test_fn(), [query_3d, key_4d, value_3d])
verify_model(test_fn(), [query_3d, key_3d, value_4d])
verify_model(test_fn(), [query_3d, key_3d, value_3d])

# Test with larger tensors
L, S, E, Ev = 128, 128, 64, 64
query_4d = torch.randn(32, 8, L, E)
query_3d = torch.randn(8, L, E)
key_4d = torch.randn(32, 8, S, E)
key_3d = torch.randn(8, S, E)
value_4d = torch.randn(32, 8, S, Ev)
value_3d = torch.randn(8, S, Ev)
verify_model(test_fn(), [query_4d, key_4d, value_4d])
verify_model(test_fn(), [query_4d, key_4d, value_3d])
verify_model(test_fn(), [query_4d, key_3d, value_4d])
verify_model(test_fn(), [query_4d, key_3d, value_3d])
verify_model(test_fn(), [query_3d, key_4d, value_4d])
verify_model(test_fn(), [query_3d, key_4d, value_3d])
verify_model(test_fn(), [query_3d, key_3d, value_4d])
verify_model(test_fn(), [query_3d, key_3d, value_3d])


class TestSetSpan:
"""test structural equal between translated / hand-crafted relay IR with span tagged."""

Expand Down
Loading