diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index 491bd876ecb1..748483349f98 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp @@ -995,6 +995,49 @@ static void insertSerializationBarriers(Location loc, Block &block, } } +// Checks if |executeOp| contains only a single transfer operation and returns +// it. Non-transfer/dispatch operations like cache control will be ignored. +// +// Intended to match things like: +// stream.cmd.execute ... { +// stream.cmd.invalidate +// stream.cmd.fill <----- returned +// stream.cmd.flush +// } +// And not: +// stream.cmd.execute ... { +// stream.cmd.invalidate +// stream.cmd.fill +// stream.cmd.flush +// stream.cmd.dispatch +// } +static Operation *matchSingleTransferOp(IREE::Stream::CmdExecuteOp executeOp) { + Operation *foundOp = nullptr; + for (auto &block : executeOp.getBodyRegion()) { + for (auto &op : block) { + if (!TypeSwitch(&op) + // Ignore non-transfer/dispatch ops. + .Case( + [&](auto metaOp) { return true; }) + .Case( + [&](auto transferOp) { + if (!foundOp) { + foundOp = &op; // first found + return true; + } else { + return false; // more than one transfer op + } + }) + // Dispatch/collective/etc fail the search. + .Default([&](auto otherOp) { return false; })) { + return nullptr; + } + } + } + return foundOp; +} + struct CmdExecuteOpPattern : public StreamConversionPattern { using StreamConversionPattern::StreamConversionPattern; @@ -1005,6 +1048,52 @@ struct CmdExecuteOpPattern auto [device, queueAffinity] = lookupDeviceAndQueueAffinityFor(executeOp, rewriter); + // If the command buffer only contains a single transfer command we may be + // able to convert it to a queue operation instead. This will have + // significantly less overhead than a command buffer especially if we are + // not able to memoize it. + if (auto *singleTransferOp = matchSingleTransferOp(executeOp)) { + // Gather wait/signal fence, which are optional. + Value waitFence = + getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter); + Value signalFence = getOrCreateSignalFence( + loc, device, executeOp.getResultTimepoint(), rewriter); + + // Replace the op with the queue operation. + // Note that since we are matching an op nested within the region we have + // to get the corresponding externally captured operand and lookup the + // remapped value from the conversion state. + // + // Example: + // stream.cmd.execute ... with(%operand as %capture: !stream.resource) + // stream.cmd.fill ... %capture + // -> + // hal.device.queue.fill ... target(%operand : !hal.buffer) + if (auto fillOp = dyn_cast(*singleTransferOp)) { + auto fillTargetBuffer = rewriter.getRemappedValue( + executeOp.getClosureCapturedValue(fillOp.getTarget())); + rewriter.create( + loc, device, queueAffinity, waitFence, signalFence, + fillTargetBuffer, fillOp.getTargetOffset(), + fillOp.getTargetLength(), fillOp.getValue(), + /*flags=*/0); + } else if (auto copyOp = + dyn_cast(*singleTransferOp)) { + auto copySourceBuffer = rewriter.getRemappedValue( + executeOp.getClosureCapturedValue(copyOp.getSource())); + auto copyTargetBuffer = rewriter.getRemappedValue( + executeOp.getClosureCapturedValue(copyOp.getTarget())); + rewriter.create( + loc, device, queueAffinity, waitFence, signalFence, + copySourceBuffer, copyOp.getSourceOffset(), copyTargetBuffer, + copyOp.getTargetOffset(), copyOp.getLength(), + /*flags=*/0); + } + + rewriter.replaceOp(executeOp, signalFence); + return success(); + } + // Until uniform buffers are implemented we can't reuse command buffers that // contain non-constant uniform values (i32, index, etc). We'll have a pass // that runs prior to conversion that creates new stream resources and diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir index 571f7291e631..e2e904770881 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir @@ -24,44 +24,114 @@ util.func public @cmdMemoryControl(%arg0: !stream.resource, %arg1: in // ----- +// Tests that an execution region with a fill and any other op is converted to +// a command buffer instead of a queue operation. + util.global private @device : !hal.device // CHECK-LABEL: @cmdFill -util.func public @cmdFill(%arg0: !stream.resource, %arg1: index) -> !stream.timepoint { +// CHECK-SAME: (%[[TARGET:.+]]: !hal.buffer, %[[TARGET_SIZE:.+]]: index) +util.func public @cmdFill(%target: !stream.resource, %target_size: index) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index %c255_i32 = arith.constant 255 : i32 // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute once on(#hal.device.affinity<@device>) with(%arg0 as %arg2: !stream.resource{%arg1}) { + %signal = stream.cmd.execute once on(#hal.device.affinity<@device>) with(%target as %target_capture: !stream.resource{%target_size}) { // CHECK-NEXT: hal.command_buffer.fill_buffer<%[[CMD]] : !hal.command_buffer> - // CHECK-SAME: target(%arg0 : !hal.buffer)[%c0, %c128] + // CHECK-SAME: target(%[[TARGET]] : !hal.buffer)[%c0, %c128] // CHECK-SAME: pattern(%c255_i32 : i32) - stream.cmd.fill %c255_i32, %arg2[%c0 for %c128] : i32 -> !stream.resource{%arg1} // CHECK-NEXT: hal.command_buffer.execution_barrier<%[[CMD]] + stream.cmd.fill %c255_i32, %target_capture[%c0 for %c128] : i32 -> !stream.resource{%target_size} + // CHECK-NEXT: hal.command_buffer.fill_buffer + // CHECK-NEXT: hal.command_buffer.execution_barrier<%[[CMD]] + stream.cmd.fill %c255_i32, %target_capture[%c0 for %target_size] : i32 -> !stream.resource{%target_size} } => !stream.timepoint // CHECK-NEXT: hal.command_buffer.finalize<%[[CMD]] - util.return %0 : !stream.timepoint + util.return %signal : !stream.timepoint } // ----- +// Tests that an execution region with a single fill is converted to a queue +// operation instead of a command buffer. The extra flush is ignored as queue +// operations have implicit flushes (today). + +util.global private @device : !hal.device + +// CHECK-LABEL: @cmdFillOnQueue +// CHECK-SAME: (%[[TARGET:.+]]: !hal.buffer, %[[TARGET_SIZE:.+]]: index, %[[WAIT:.+]]: !hal.fence) +util.func public @cmdFillOnQueue(%target: !stream.resource, %target_size: index, %wait: !stream.timepoint) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c255_i32 = arith.constant 255 : i32 + // CHECK: %[[SIGNAL:.+]] = hal.fence.create + // CHECK-NOT: hal.command_buffer.create + %signal = stream.cmd.execute once on(#hal.device.affinity<@device>) await(%wait) => with(%target as %target_capture: !stream.resource{%target_size}) { + // CHECK: hal.device.queue.fill<%{{.+}} : !hal.device> + // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]]) + // CHECK-SAME: target(%[[TARGET]] : !hal.buffer)[%c0] length(%c128) + // CHECK-SAME: pattern(%c255_i32 : i32) + stream.cmd.fill %c255_i32, %target_capture[%c0 for %c128] : i32 -> !stream.resource{%target_size} + stream.cmd.flush %target_capture[%c0 for %c128] : !stream.resource{%target_size} + } => !stream.timepoint + // CHECK: util.return %[[SIGNAL]] + util.return %signal : !stream.timepoint +} + +// ----- + +// Tests that an execution region with a copy and any other op is converted to +// a command buffer instead of a queue operation. + util.global private @device : !hal.device // CHECK-LABEL: @cmdCopy -util.func public @cmdCopy(%arg0: !stream.resource, %arg1: index, %arg2: !stream.resource, %arg3: index) -> !stream.timepoint { +// CHECK-SAME: (%[[SOURCE:.+]]: !hal.buffer, %[[SOURCE_SIZE:.+]]: index, %[[TARGET:.+]]: !hal.buffer, %[[TARGET_SIZE:.+]]: index) +util.func public @cmdCopy(%source: !stream.resource, %source_size: index, %target: !stream.resource, %target_size: index) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute once on(#hal.device.affinity<@device>) with(%arg0 as %arg4: !stream.resource{%arg1}, %arg2 as %arg5: !stream.resource{%arg3}) { + %signal = stream.cmd.execute once on(#hal.device.affinity<@device>) with(%source as %source_capture: !stream.resource{%source_size}, %target as %target_capture: !stream.resource{%target_size}) { // CHECK-NEXT: hal.command_buffer.copy_buffer<%[[CMD]] : !hal.command_buffer> - // CHECK-SAME: source(%arg0 : !hal.buffer)[%c0] - // CHECK-SAME: target(%arg2 : !hal.buffer)[%c0] + // CHECK-SAME: source(%[[SOURCE]] : !hal.buffer)[%c0] + // CHECK-SAME: target(%[[TARGET]] : !hal.buffer)[%c0] // CHECK-SAME: length(%c128) - stream.cmd.copy %arg4[%c0], %arg5[%c0], %c128 : !stream.resource{%arg1} -> !stream.resource{%arg3} // CHECK-NEXT: hal.command_buffer.execution_barrier<%[[CMD]] + stream.cmd.copy %source_capture[%c0], %target_capture[%c0], %c128 : !stream.resource{%source_size} -> !stream.resource{%target_size} + // CHECK-NEXT: hal.command_buffer.copy_buffer + // CHECK-NEXT: hal.command_buffer.execution_barrier<%[[CMD]] + stream.cmd.copy %source_capture[%c0], %target_capture[%c0], %target_size : !stream.resource{%source_size} -> !stream.resource{%target_size} } => !stream.timepoint // CHECK-NEXT: hal.command_buffer.finalize<%[[CMD]] - util.return %0 : !stream.timepoint + util.return %signal : !stream.timepoint +} + +// ----- + +// Tests that an execution region with a single copy is converted to a queue +// operation instead of a command buffer. The extra flush is ignored as queue +// operations have implicit flushes (today). + +util.global private @device : !hal.device + +// CHECK-LABEL: @cmdCopyOnQueue +// CHECK-SAME: (%[[SOURCE:.+]]: !hal.buffer, %[[SOURCE_SIZE:.+]]: index, %[[TARGET:.+]]: !hal.buffer, %[[TARGET_SIZE:.+]]: index, %[[WAIT:.+]]: !hal.fence) +util.func public @cmdCopyOnQueue(%source: !stream.resource, %source_size: index, %target: !stream.resource, %target_size: index, %wait: !stream.timepoint) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + // CHECK-NOT: hal.command_buffer.create + // CHECK: %[[SIGNAL:.+]] = hal.fence.create + %signal = stream.cmd.execute once on(#hal.device.affinity<@device>) await(%wait) => with(%source as %source_capture: !stream.resource{%source_size}, %target as %target_capture: !stream.resource{%target_size}) { + // CHECK: hal.device.queue.copy<%{{.+}} : !hal.device> + // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]]) + // CHECK-SAME: source(%[[SOURCE]] : !hal.buffer)[%c0] + // CHECK-SAME: target(%[[TARGET]] : !hal.buffer)[%c0] + // CHECK-SAME: length(%c128) + stream.cmd.copy %source_capture[%c0], %target_capture[%c0], %c128 : !stream.resource{%source_size} -> !stream.resource{%target_size} + stream.cmd.flush %target_capture[%c0 for %c128] : !stream.resource{%target_size} + } => !stream.timepoint + // CHECK: util.return %[[SIGNAL]] + util.return %signal : !stream.timepoint } // ----- @@ -364,12 +434,10 @@ util.global private @device : !hal.device util.func public @cmdExecuteAffinities(%arg0: !stream.resource, %arg1: index, %arg2: !stream.resource, %arg3: index, %arg4: !stream.timepoint) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index - // CHECK: %[[CMD:.+]] = hal.command_buffer.create + // CHECK: hal.device.queue.copy + // CHECK-SAME: affinity(%c3_i64) %0 = stream.cmd.execute once on(#hal.device.affinity<@device, [0, 1]>) await(%arg4) => with(%arg0 as %arg5: !stream.resource{%arg1}, %arg2 as %arg6: !stream.resource{%arg3}) { stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource{%arg1} -> !stream.resource{%arg3} } => !stream.timepoint - // CHECK: hal.device.queue.execute - // CHECK-SAME: affinity(%c3_i64) - // CHECK-SAME: commands([%[[CMD]]]) util.return %0 : !stream.timepoint } diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td index f0b541096b21..6bf6a4c72bdc 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td @@ -50,6 +50,24 @@ def Util_ClosureOpInterface : OpInterface<"ClosureOpInterface"> { /*methodName=*/"getClosureResults", /*args=*/(ins) >, + InterfaceMethod< + /*desc=*/[{ + Returns the value captured by the closure operands or the provided value + if it was not captured. + }], + /*retTy=*/"Value", + /*methodName=*/"getClosureCapturedValue", + /*args=*/(ins "Value":$closureValue), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + if (auto arg = dyn_cast(closureValue)) { + if (arg.getParentRegion() == &$_op.getClosureBodyRegion()) { + return $_op.getClosureOperands()[arg.getArgNumber()]; + } + } + return closureValue; + }] + >, InterfaceMethod< /*desc=*/[{ Returns true if the given operation can exist in the closure.