[TK] Add support for ops required for Flash Attention 2 (#385)
Add new ops:

- tkl.exp2 (math)
- tkl.max (reduce max)
- tkl.sum (reduce sum)
- tkl.broadcast (broadcast leading dims)
- tkl.broadcast_in_dim (broadcast specific dimensions)
- tkl.transpose (transpose)
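
Together these cover the online-softmax update at the core of Flash Attention 2. The snippet below is only an illustrative sketch, not code from this PR: the call signatures are inferred from the tracing handlers added in this change, the import alias and the kernel scaffolding (qk, m_prev, d_prev, M, N) are hypothetical, and the rescaling of the previous accumulators by the new row max is omitted for brevity.

import shark_turbine.kernel.lang as tkl  # assumed import alias for the tkl ops

def online_softmax_step(qk, m_prev, d_prev, M, N):
    # qk: (M, N) tile of attention scores; m_prev / d_prev: running row max / row sum.
    m = tkl.max(qk, axis=1, acc=m_prev)          # updated running row max, shape (M,)
    m_b = tkl.broadcast_in_dim(m, (M, N), (0,))  # place dim 0 of m at output dim 0; broadcast dim 1
    p = tkl.exp2(qk - m_b)                       # exp2 rather than exp: log2(e) is folded into the score scale
    d = tkl.sum(p, axis=1, acc=d_prev)           # updated running row denominator, shape (M,)
    return p, m, d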
Groverkss authored Feb 1, 2024
1 parent da57fe3 commit 6f67a97
Showing 11 changed files with 311 additions and 88 deletions.
76 changes: 75 additions & 1 deletion python/shark_turbine/kernel/_support/tracing.py
@@ -264,6 +264,13 @@ def wrapper(f):
### ========================================================================
### Math Operations
### ========================================================================
def handle_exp2(self, op, val):
return self.region_graph.create_proxy(
"call_function",
target=op,
args=(val,),
kwargs={},
)

def handle_vector_constant(
self, op, shape: Tuple[int, ...], dtype, value: int | float
@@ -278,15 +285,82 @@ def handle_vector_constant(
### ========================================================================
### Reduction Operations
### ========================================================================
def handle_vector_max(self, op, vector, axis=None, acc=None):
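# vector: value being reduced; axis: dimension to reduce over; acc: optional
# accumulator combined with the partial result (the same convention applies to
# handle_vector_sum below). Presumably this is what lets a loop carry a running
# max/sum across iterations; inferred from the signature, not stated in this change.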
return self.region_graph.create_proxy(
"call_function",
target=op,
args=(vector, axis, acc),
kwargs={},
)

def handle_vector_sum(self, op, vector, axis=None, acc=None):
return self.region_graph.create_proxy(
"call_function",
target=op,
args=(vector, axis, acc),
kwargs={},
)

-def handle_vector_dot(self, op, lhs, rhs, acc):
+def handle_vector_dot(self, op, lhs, rhs, acc=None):
return self.region_graph.create_proxy(
"call_function",
target=op,
args=(lhs, rhs, acc),
kwargs={},
)

### ========================================================================
### Shape Manipulation Operations
### ========================================================================
def handle_vector_broadcast(self, op, vector, leading_sizes):
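# leading_sizes are prepended to the vector's shape, e.g. broadcasting an (N,)
# vector with leading_sizes (M,) yields an (M, N) value (inferred from the
# commit message's "broadcast leading dims" and the decomposition below).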
return self.region_graph.create_proxy(
"call_function",
target=op,
args=(vector, leading_sizes),
kwargs={},
)

def handle_vector_broadcast_in_dim(self, op, vector, shape, broadcast_dimensions):
# Currently, we do not have a corresponding op in MLIR, so
# we trace this to broadcast + transpose.
# TODO: Add a vector dialect op for this in MLIR.
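# Illustrative example (not in the original source): with shape = (M, N),
# broadcast_dimensions = (0,) and an input vector of shape (M,), the leading
# sizes are (N,), the broadcast produces an (N, M) value, and the permutation
# (1, 0) transposes it back to the requested (M, N).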

# Drop the target dims covered by broadcast_dimensions; what remains are the
# leading sizes to broadcast.
shape_with_leading = tuple(
dim for i, dim in enumerate(shape) if i not in broadcast_dimensions
)

# Broadcast
broadcasted_vector = self.region_graph.create_proxy(
"call_function",
target=ops.vector_broadcast,
args=(vector, shape_with_leading),
kwargs={},
)

# Get the permutation for the transpose.
permutation = tuple(
i for i in range(len(shape)) if i not in broadcast_dimensions
)
permutation = permutation + tuple(broadcast_dimensions)

# Transpose
return self.region_graph.create_proxy(
"call_function",
target=ops.vector_transpose,
args=(broadcasted_vector, permutation),
kwargs={},
)

def handle_vector_transpose(self, op, vector, permutation):
return self.region_graph.create_proxy(
"call_function",
target=op,
args=(vector, permutation),
kwargs={},
)


###############################################################################
# Launch context
31 changes: 29 additions & 2 deletions python/shark_turbine/kernel/compiler/builder.py
@@ -24,6 +24,7 @@
Value,
VectorType,
arith_d,
math_d,
builtin_d,
)

@@ -139,7 +140,7 @@ def binary_arithmetic(

def binary_vector_arithmetic(
self, op: str, lhs: IRProxyValue, rhs: IRProxyValue
-) -> Value:
+) -> IRProxyValue:
lhs_ir = lhs.ir_value
rhs_ir = rhs.ir_value
lhs_element_type = VectorType(lhs_ir.type).element_type
@@ -149,10 +150,33 @@ def binary_vector_arithmetic(
handler = getattr(self, attr_name)
except AttributeError:
raise CodegenError(
f"Cannot perform binary arithmetic operation '{op}' between {lhs.type} and {rhs.type} (tried '{attr_name}')"
f"Cannot perform binary arithmetic operation '{op}' between {lhs_ir.type} and {rhs_ir.type} (tried '{attr_name}')"
)
return handler(lhs, rhs)

def unary_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue:
val_ir_type = val.ir_value.type
attr_name = f"unary_{op}_{val_ir_type}"
try:
handler = getattr(self, attr_name)
except AttributeError:
raise CodegenError(
f"Cannot perform unary arithmetic operation '{op}' on {val_ir_type} (tried '{attr_name}')"
)
return handler(val)

def unary_vector_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue:
val_ir = val.ir_value
val_element_type = VectorType(val_ir.type).element_type
attr_name = f"unary_{op}_{val_element_type}"
try:
handler = getattr(self, attr_name)
except AttributeError:
raise CodegenError(
f"Cannot perform unary arithmetic operation '{op}' on {val_ir.type} (tried '{attr_name}')"
)
return handler(val)

def promote_index_to_f32(self, value: Value, to_type: IrType) -> Value:
i32_type = IntegerType.get_signless(32)
i32 = arith_d.index_cast(i32_type, value)
@@ -215,5 +239,8 @@ def binary_truediv_f32_f32(
) -> IRProxyValue:
return IRProxyValue(arith_d.divf(lhs.ir_value, rhs.ir_value))

def unary_exp2_f32(self, val: IRProxyValue) -> IRProxyValue:
return IRProxyValue(math_d.exp2(val.ir_value))


ScalarBuilder = _ScalarBuilder()