Skip to content

Commit

Permalink
[compiler] strip execution context affinities in const eval
Browse files Browse the repository at this point in the history
During compile-time constant evaluation in pass
iree-consteval-jit-globals it does not make sense to assign
device/queue affinities. We will be compiling and executing it on the
compilation host.

The JITed IR is first stripped of all attributes that assign affinities
and all flow.tensor.transfer ops.

Signed-off-by: Boian Petkantchin <[email protected]>
  • Loading branch information
sogartar committed Oct 2, 2024
1 parent 20a7638 commit 865a8bd
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 0 deletions.
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/ConstEval/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ iree_compiler_cc_library(
":PassHeaders",
":PassesIncGen",
":Runtime",
"//compiler/src/iree/compiler/Dialect/Flow/IR",
"//compiler/src/iree/compiler/Dialect/HAL/Target",
"//compiler/src/iree/compiler/Dialect/Stream/IR",
"//compiler/src/iree/compiler/Dialect/Util/Analysis/Constant",
"//compiler/src/iree/compiler/Dialect/Util/IR",
"//compiler/src/iree/compiler/Pipelines",
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/ConstEval/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ iree_cc_library(
MLIRFunctionInterfaces
MLIRIR
MLIRPass
iree::compiler::Dialect::Flow::IR
iree::compiler::Dialect::HAL::Target
iree::compiler::Dialect::Stream::IR
iree::compiler::Dialect::Util::Analysis::Constant
iree::compiler::Dialect::Util::IR
iree::compiler::Pipelines
Expand Down
55 changes: 55 additions & 0 deletions compiler/src/iree/compiler/ConstEval/JitGlobals.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

#include "iree/compiler/ConstEval/Passes.h"
#include "iree/compiler/ConstEval/Runtime.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
#include "iree/compiler/Dialect/HAL/Target/TargetOptions.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h"
#include "iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
Expand All @@ -21,7 +23,9 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <cstdlib>

Expand Down Expand Up @@ -448,6 +452,53 @@ static LogicalResult cloneUsedObjects(FunctionOpInterface funcOp,
return success();
}

struct StripFlowTensorTransferPattern
: public OpRewritePattern<IREE::Flow::TensorTransferOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(IREE::Flow::TensorTransferOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceAllUsesWith(op.getResult(), op.getOperand());
rewriter.eraseOp(op);
return success();
}
};

// If an op that implements AffinityOpInterface has an optional stream affinity
// attribute, remove it.
struct StripStreamAffinityOptionalAttributePattern
: public OpInterfaceRewritePattern<IREE::Stream::AffinityOpInterface> {
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

LogicalResult matchAndRewrite(IREE::Stream::AffinityOpInterface op,
PatternRewriter &rewriter) const override {
// Shouldn't we reject ops for which `op.requiresAffinity() == true`?
// For example there are a lot of ops in the Flow dialect that
// have this property, but do they really require an affinity?
// See
// compiler/src/iree/compiler/ExternalInterfaces/StreamExternalModels.cpp
if (op.getAffinityAttr() == nullptr) {
return failure();
}
op.setAffinityAttr(nullptr);
return success();
}
};

// Remove device/queue affinities for the IR.
// E.g. remove `flow.tensor.transfer` ops.
static LogicalResult stripExecutionContextAffinities(ModuleOp moduleOp) {
RewritePatternSet patterns(moduleOp->getContext());
patterns.add<StripFlowTensorTransferPattern,
StripStreamAffinityOptionalAttributePattern>(
moduleOp.getContext());
if (failed(applyPatternsAndFoldGreedily(moduleOp, std::move(patterns)))) {
return emitError(moduleOp->getLoc())
<< "Stripping execution context affinities failed";
}
return success();
}

class ProgramBuilder {
public:
ProgramBuilder(ModuleOp sourceModuleOp,
Expand Down Expand Up @@ -831,6 +882,10 @@ class JitGlobalsPass final : public impl::JitGlobalsPassBase<JitGlobalsPass> {
programBuilder.getTargetModule()->erase();
return;
}
if (failed(stripExecutionContextAffinities(
programBuilder.getTargetModule()))) {
return signalPassFailure();
}

std::optional<llvm::Timer> compileTimer;
if (debugEnabled) {
Expand Down
40 changes: 40 additions & 0 deletions compiler/src/iree/compiler/ConstEval/test/jit_globals.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -439,3 +439,43 @@ module @dispatch_executable {
util.return
}
}

// -----

// CHECK-LABEL: @strip_flow_tensor_transfer
// CHECK: util.global private @[[EVALED:.+]] = dense<2.000000e+02> : tensor<2xf16>
module @strip_flow_tensor_transfer {
util.global private @hoisted : tensor<2xf16>
// CHECK-NOT: util.initializer
util.initializer {
%cst = arith.constant dense<2.0e+2> : tensor<2xf16>
%cst_transfered = flow.tensor.transfer %cst : tensor<2xf16> to #hal.device.promise<@dev_a>
util.global.store %cst_transfered, @hoisted : tensor<2xf16>
util.return
}
util.func public @main() -> tensor<2xf16> {
// CHECK: util.global.load @[[EVALED]]
%hoisted = util.global.load @hoisted : tensor<2xf16>
util.return %hoisted : tensor<2xf16>
}
}

// -----

// CHECK-LABEL: @strip_optional_stream_affinity_attribute
// CHECK: util.global private @[[EVALED:.+]] = dense<1> : tensor<2xi32>
module @strip_optional_stream_affinity_attribute {
util.global private @hoisted : tensor<2xi32>
// CHECK-NOT: util.initializer
util.initializer {
%c1 = arith.constant 1 : i32
%tensor = flow.tensor.splat %c1 : tensor<2xi32> attributes { stream.affinity = #hal.device.promise<@dev_a> }
util.global.store %tensor , @hoisted : tensor<2xi32>
util.return
}
util.func public @main() -> tensor<2xi32> {
// CHECK: util.global.load @[[EVALED]]
%hoisted = util.global.load @hoisted : tensor<2xi32>
util.return %hoisted : tensor<2xi32>
}
}

0 comments on commit 865a8bd

Please sign in to comment.