Commit

Merge branch 'main' of github.com:nod-ai/SHARK-Turbine into wip-subgraph-tracing
stellaraccident committed Jan 20, 2024
2 parents 5686f84 + 6b21267 commit 6e85184
Showing 7 changed files with 11 additions and 12 deletions.
4 changes: 2 additions & 2 deletions python/shark_turbine/kernel/compiler/builder.py
@@ -128,8 +128,8 @@ def py_constant_int(self, py_value) -> Value:
         # If coming from a stock 'int' Python type with no idea how to convert it,
         # there isn't much smart we can do. We conservatively treat 'index' as
         # reasonable.
-        attr = IntegerAttr.get(IndexType.get(), py_value)
-        return arith_d.constant(attr)
+        result_type = IndexType.get()
+        return arith_d.constant(result_type, IntegerAttr.get(result_type, py_value))
 
     # Binary index arithmetic.
     def binary_add_index_index(self, lhs: Value, rhs: Value) -> Value:
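The same signature change repeats throughout this commit: newer MLIR Python bindings no longer infer the result type of arith.constant from its value attribute, so the type is now passed explicitly as the first argument. A minimal sketch of the old and new calling conventions, assuming the upstream mlir Python package (module paths may differ in a vendored build):

    from mlir.ir import Context, IndexType, InsertionPoint, IntegerAttr, Location, Module
    from mlir.dialects import arith as arith_d
    from mlir.dialects import func as func_d

    with Context(), Location.unknown():
        module = Module.create()
        with InsertionPoint(module.body):
            f = func_d.FuncOp("demo", ([], [IndexType.get()]))
            with InsertionPoint(f.add_entry_block()):
                result_type = IndexType.get()
                attr = IntegerAttr.get(result_type, 42)
                # Old bindings inferred the type from the attribute:
                #     value = arith_d.constant(attr)
                # New bindings take the result type explicitly:
                value = arith_d.constant(result_type, attr)
                func_d.ReturnOp([value])
        # Prints roughly: func.func @demo() -> index { %c42 = arith.constant 42 : index ... }
        print(module)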
6 changes: 4 additions & 2 deletions python/shark_turbine/kernel/compiler/dispatch_codegen.py
@@ -171,8 +171,9 @@ def abi_type(binding: BindingDesc):
         workgroup_values = list(workgroup_builder.workload)
         while len(workgroup_values) < 3:
             with InsertionPoint(workgroup_builder.entry_block):
+                result_type = IndexType.get()
                 workgroup_values.append(
-                    arith_d.constant(IntegerAttr.get(IndexType.get(), 1))
+                    arith_d.constant(result_type, IntegerAttr.get(result_type, 1))
                 )
         workgroup_builder.terminate(workgroup_values)
 
@@ -226,7 +227,8 @@ def resolve(self, binding: BindingDesc) -> Value:
 
         if binding.binding_type == BindingType.KERNEL_BUFFER:
             # Issue a subspan to get into the memref domain.
-            zero_value = arith_d.constant(IntegerAttr.get(IndexType.get(), 0))
+            result_type = IndexType.get()
+            zero_value = arith_d.constant(result_type, IntegerAttr.get(result_type, 0))
             linear_arg_value = self._abi_value_by_reference[binding.reference]
             # TODO: Need to also look up dynamic symbol values.
             return stream_d.binding_subspan(
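Context on the first hunk above: IREE dispatches carry a three-dimensional (x, y, z) workgroup count, so when the traced workload yields fewer than three values the builder pads the remainder with index-typed constant-1 ops before terminating the workgroup-count region. The same padding rule in plain Python (values hypothetical, no MLIR involved):

    workgroup_values = [64, 32]       # counts derived from the workload so far
    while len(workgroup_values) < 3:  # a dispatch needs a full (x, y, z) triple
        workgroup_values.append(1)    # unused trailing dimensions become 1
    assert workgroup_values == [64, 32, 1]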
6 changes: 3 additions & 3 deletions python/shark_turbine/kernel/compiler/vector_codegen.py
@@ -326,7 +326,7 @@ def _(emitter: ThreadEmitter, node: fx.Node):
     vector_type = VectorType.get(vector_shape, element_type)
     pad_attr = ScalarBuilder.zero_attr(element_type)
     indices = cast_indices(emitter, [s.start for s in sa.slices])
-    pad_value = arith_d.constant(pad_attr)
+    pad_value = arith_d.constant(element_type, pad_attr)
     result = vector_d.transfer_read(
         vector_type,
         kb_src,
@@ -588,7 +588,7 @@ def combiner(element_type: IrType, attrs: NodeAttrs) -> vector_d.CombiningKind:
         # Non-NaN propagating.
         # TODO: Carry a "fastmath" flag on the emitter and choose between this
         # and MAXIMUMF?
-        return vector_d.CombiningKind.MAXF
+        return vector_d.CombiningKind.MAXNUMF
     elif ScalarBuilder.is_integer_type(element_type):
         return (
             vector_d.CombiningKind.MAXUI
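The MAXF change tracks upstream MLIR, which split the old ambiguous floating-point max combining kind into MAXIMUMF (NaN-propagating, IEEE-754 maximum) and MAXNUMF (NaN-dropping, IEEE-754 maxNum); per the TODO, MAXNUMF preserves the previous non-NaN-propagating behavior until a fastmath flag exists to choose between them. A small pure-Python sketch of the two semantics (helper names are illustrative, not MLIR API):

    import math

    def maxnumf(a: float, b: float) -> float:
        # IEEE-754 maxNum: a NaN operand loses to a number.
        if math.isnan(a):
            return b
        if math.isnan(b):
            return a
        return max(a, b)

    def maximumf(a: float, b: float) -> float:
        # IEEE-754 maximum: any NaN operand propagates.
        if math.isnan(a) or math.isnan(b):
            return math.nan
        return max(a, b)

    print(maxnumf(1.0, math.nan))   # 1.0 -- NaN is dropped
    print(maximumf(1.0, math.nan))  # nan -- NaN wins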
@@ -624,7 +624,7 @@ def emit_reduction(
     vector_type = VectorType(input.type)
     element_type = vector_type.element_type
     rank = vector_type.rank
-    zero = arith_d.constant(ScalarBuilder.zero_attr(element_type))
+    zero = arith_d.constant(element_type, ScalarBuilder.zero_attr(element_type))
     combiner = combiner_callback(element_type, attrs)
 
     if len(args) == 1:
@@ -35,7 +35,6 @@ def create_benchmark_vmfb(args):
         "--iree-llvmcpu-target-triple=x86_64-linux-gnu",
         "--iree-stream-resource-index-bits=64",
         "--iree-vm-target-index-bits=64",
-        "--iree-codegen-check-ir-before-llvm-conversion=false",
         "--iree-opt-const-expr-hoisting=False",
     ]
     device = args.device
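Here and in the two files below, the turbine_models entry points drop --iree-codegen-check-ir-before-llvm-conversion=false, presumably because the IREE release newly pinned in requirements.txt renamed or removed that flag and would reject it. A minimal sketch of passing the surviving flag list through the IREE compiler's Python API rather than a subprocess, assuming iree.compiler.compile_str behaves as in released iree-compiler wheels and using a placeholder module string:

    import iree.compiler as ireec

    MLIR_SOURCE = "module {}"  # placeholder; real callers pass the exported module

    vmfb_bytes = ireec.compile_str(
        MLIR_SOURCE,
        target_backends=["llvm-cpu"],
        extra_args=[
            "--iree-llvmcpu-target-triple=x86_64-linux-gnu",
            "--iree-stream-resource-index-bits=64",
            "--iree-vm-target-index-bits=64",
            "--iree-opt-const-expr-hoisting=False",
        ],
    )
    with open("module.vmfb", "wb") as f:  # hypothetical output path
        f.write(vmfb_bytes)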
1 change: 0 additions & 1 deletion python/turbine_models/custom_models/sd_inference/utils.py
@@ -35,7 +35,6 @@ def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name):
         "--iree-llvmcpu-target-triple=x86_64-linux-gnu",
         "--iree-stream-resource-index-bits=64",
         "--iree-vm-target-index-bits=64",
-        "--iree-codegen-check-ir-before-llvm-conversion=false",
         "--iree-opt-const-expr-hoisting=False",
     ]
     if device == "cpu":
1 change: 0 additions & 1 deletion python/turbine_models/custom_models/stateless_llama.py
@@ -321,7 +321,6 @@ def evict_kvcache_space(self):
         "--iree-llvmcpu-target-triple=x86_64-linux-gnu",
         "--iree-stream-resource-index-bits=64",
         "--iree-vm-target-index-bits=64",
-        "--iree-codegen-check-ir-before-llvm-conversion=false",
         "--iree-opt-const-expr-hoisting=False",
     ]
     if device == "cpu" or device == "llvm-cpu":
4 changes: 2 additions & 2 deletions requirements.txt
@@ -7,5 +7,5 @@
 -r pytorch-cpu-requirements.txt
 -r torchvision-requirements.txt
 
-iree-compiler==20231218.742
-iree-runtime==20231218.742
+iree-compiler==20240119.775
+iree-runtime==20240119.775
