From 8920c28038914951c14aeb9266d2acb16e40645b Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Tue, 5 Nov 2024 14:44:18 -0800 Subject: [PATCH] Adding stream lowering to hal.device.queue.fill/hal.device.queue.copy. (#19033) Stream execution regions that contain only a single transfer operation are now converted to queue operations. This avoids the additional command buffer overhead for individual operations that don't benefit from batching. At runtime these route to the HAL device queue methods which implementations can use to optimize standalone transfer requests (e.g. mapping to `cuMemcpyAsync`). There's currently no `hal.device.queue.update` equivalent in the stream dialect but that will be added as part of the dynamic uniform values for reusable command buffers workstream. --- .../HAL/Conversion/StreamToHAL/Patterns.cpp | 89 +++++++++++++++++ .../Conversion/StreamToHAL/test/cmd_ops.mlir | 98 ++++++++++++++++--- .../Dialect/Util/IR/UtilInterfaces.td | 18 ++++ 3 files changed, 190 insertions(+), 15 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index 491bd876ecb1..748483349f98 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp @@ -995,6 +995,49 @@ static void insertSerializationBarriers(Location loc, Block &block, } } +// Checks if |executeOp| contains only a single transfer operation and returns +// it. Non-transfer/dispatch operations like cache control will be ignored. +// +// Intended to match things like: +// stream.cmd.execute ... { +// stream.cmd.invalidate +// stream.cmd.fill <----- returned +// stream.cmd.flush +// } +// And not: +// stream.cmd.execute ... { +// stream.cmd.invalidate +// stream.cmd.fill +// stream.cmd.flush +// stream.cmd.dispatch +// } +static Operation *matchSingleTransferOp(IREE::Stream::CmdExecuteOp executeOp) { + Operation *foundOp = nullptr; + for (auto &block : executeOp.getBodyRegion()) { + for (auto &op : block) { + if (!TypeSwitch(&op) + // Ignore non-transfer/dispatch ops. + .Case( + [&](auto metaOp) { return true; }) + .Case( + [&](auto transferOp) { + if (!foundOp) { + foundOp = &op; // first found + return true; + } else { + return false; // more than one transfer op + } + }) + // Dispatch/collective/etc fail the search. + .Default([&](auto otherOp) { return false; })) { + return nullptr; + } + } + } + return foundOp; +} + struct CmdExecuteOpPattern : public StreamConversionPattern { using StreamConversionPattern::StreamConversionPattern; @@ -1005,6 +1048,52 @@ struct CmdExecuteOpPattern auto [device, queueAffinity] = lookupDeviceAndQueueAffinityFor(executeOp, rewriter); + // If the command buffer only contains a single transfer command we may be + // able to convert it to a queue operation instead. This will have + // significantly less overhead than a command buffer especially if we are + // not able to memoize it. + if (auto *singleTransferOp = matchSingleTransferOp(executeOp)) { + // Gather wait/signal fence, which are optional. + Value waitFence = + getOrCreateWaitFence(loc, adaptor.getAwaitTimepoint(), rewriter); + Value signalFence = getOrCreateSignalFence( + loc, device, executeOp.getResultTimepoint(), rewriter); + + // Replace the op with the queue operation. + // Note that since we are matching an op nested within the region we have + // to get the corresponding externally captured operand and lookup the + // remapped value from the conversion state. + // + // Example: + // stream.cmd.execute ... with(%operand as %capture: !stream.resource) + // stream.cmd.fill ... %capture + // -> + // hal.device.queue.fill ... target(%operand : !hal.buffer) + if (auto fillOp = dyn_cast(*singleTransferOp)) { + auto fillTargetBuffer = rewriter.getRemappedValue( + executeOp.getClosureCapturedValue(fillOp.getTarget())); + rewriter.create( + loc, device, queueAffinity, waitFence, signalFence, + fillTargetBuffer, fillOp.getTargetOffset(), + fillOp.getTargetLength(), fillOp.getValue(), + /*flags=*/0); + } else if (auto copyOp = + dyn_cast(*singleTransferOp)) { + auto copySourceBuffer = rewriter.getRemappedValue( + executeOp.getClosureCapturedValue(copyOp.getSource())); + auto copyTargetBuffer = rewriter.getRemappedValue( + executeOp.getClosureCapturedValue(copyOp.getTarget())); + rewriter.create( + loc, device, queueAffinity, waitFence, signalFence, + copySourceBuffer, copyOp.getSourceOffset(), copyTargetBuffer, + copyOp.getTargetOffset(), copyOp.getLength(), + /*flags=*/0); + } + + rewriter.replaceOp(executeOp, signalFence); + return success(); + } + // Until uniform buffers are implemented we can't reuse command buffers that // contain non-constant uniform values (i32, index, etc). We'll have a pass // that runs prior to conversion that creates new stream resources and diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir index 571f7291e631..e2e904770881 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/test/cmd_ops.mlir @@ -24,44 +24,114 @@ util.func public @cmdMemoryControl(%arg0: !stream.resource, %arg1: in // ----- +// Tests that an execution region with a fill and any other op is converted to +// a command buffer instead of a queue operation. + util.global private @device : !hal.device // CHECK-LABEL: @cmdFill -util.func public @cmdFill(%arg0: !stream.resource, %arg1: index) -> !stream.timepoint { +// CHECK-SAME: (%[[TARGET:.+]]: !hal.buffer, %[[TARGET_SIZE:.+]]: index) +util.func public @cmdFill(%target: !stream.resource, %target_size: index) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index %c255_i32 = arith.constant 255 : i32 // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute once on(#hal.device.affinity<@device>) with(%arg0 as %arg2: !stream.resource{%arg1}) { + %signal = stream.cmd.execute once on(#hal.device.affinity<@device>) with(%target as %target_capture: !stream.resource{%target_size}) { // CHECK-NEXT: hal.command_buffer.fill_buffer<%[[CMD]] : !hal.command_buffer> - // CHECK-SAME: target(%arg0 : !hal.buffer)[%c0, %c128] + // CHECK-SAME: target(%[[TARGET]] : !hal.buffer)[%c0, %c128] // CHECK-SAME: pattern(%c255_i32 : i32) - stream.cmd.fill %c255_i32, %arg2[%c0 for %c128] : i32 -> !stream.resource{%arg1} // CHECK-NEXT: hal.command_buffer.execution_barrier<%[[CMD]] + stream.cmd.fill %c255_i32, %target_capture[%c0 for %c128] : i32 -> !stream.resource{%target_size} + // CHECK-NEXT: hal.command_buffer.fill_buffer + // CHECK-NEXT: hal.command_buffer.execution_barrier<%[[CMD]] + stream.cmd.fill %c255_i32, %target_capture[%c0 for %target_size] : i32 -> !stream.resource{%target_size} } => !stream.timepoint // CHECK-NEXT: hal.command_buffer.finalize<%[[CMD]] - util.return %0 : !stream.timepoint + util.return %signal : !stream.timepoint } // ----- +// Tests that an execution region with a single fill is converted to a queue +// operation instead of a command buffer. The extra flush is ignored as queue +// operations have implicit flushes (today). + +util.global private @device : !hal.device + +// CHECK-LABEL: @cmdFillOnQueue +// CHECK-SAME: (%[[TARGET:.+]]: !hal.buffer, %[[TARGET_SIZE:.+]]: index, %[[WAIT:.+]]: !hal.fence) +util.func public @cmdFillOnQueue(%target: !stream.resource, %target_size: index, %wait: !stream.timepoint) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + %c255_i32 = arith.constant 255 : i32 + // CHECK: %[[SIGNAL:.+]] = hal.fence.create + // CHECK-NOT: hal.command_buffer.create + %signal = stream.cmd.execute once on(#hal.device.affinity<@device>) await(%wait) => with(%target as %target_capture: !stream.resource{%target_size}) { + // CHECK: hal.device.queue.fill<%{{.+}} : !hal.device> + // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]]) + // CHECK-SAME: target(%[[TARGET]] : !hal.buffer)[%c0] length(%c128) + // CHECK-SAME: pattern(%c255_i32 : i32) + stream.cmd.fill %c255_i32, %target_capture[%c0 for %c128] : i32 -> !stream.resource{%target_size} + stream.cmd.flush %target_capture[%c0 for %c128] : !stream.resource{%target_size} + } => !stream.timepoint + // CHECK: util.return %[[SIGNAL]] + util.return %signal : !stream.timepoint +} + +// ----- + +// Tests that an execution region with a copy and any other op is converted to +// a command buffer instead of a queue operation. + util.global private @device : !hal.device // CHECK-LABEL: @cmdCopy -util.func public @cmdCopy(%arg0: !stream.resource, %arg1: index, %arg2: !stream.resource, %arg3: index) -> !stream.timepoint { +// CHECK-SAME: (%[[SOURCE:.+]]: !hal.buffer, %[[SOURCE_SIZE:.+]]: index, %[[TARGET:.+]]: !hal.buffer, %[[TARGET_SIZE:.+]]: index) +util.func public @cmdCopy(%source: !stream.resource, %source_size: index, %target: !stream.resource, %target_size: index) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index // CHECK: %[[CMD:.+]] = hal.command_buffer.create - %0 = stream.cmd.execute once on(#hal.device.affinity<@device>) with(%arg0 as %arg4: !stream.resource{%arg1}, %arg2 as %arg5: !stream.resource{%arg3}) { + %signal = stream.cmd.execute once on(#hal.device.affinity<@device>) with(%source as %source_capture: !stream.resource{%source_size}, %target as %target_capture: !stream.resource{%target_size}) { // CHECK-NEXT: hal.command_buffer.copy_buffer<%[[CMD]] : !hal.command_buffer> - // CHECK-SAME: source(%arg0 : !hal.buffer)[%c0] - // CHECK-SAME: target(%arg2 : !hal.buffer)[%c0] + // CHECK-SAME: source(%[[SOURCE]] : !hal.buffer)[%c0] + // CHECK-SAME: target(%[[TARGET]] : !hal.buffer)[%c0] // CHECK-SAME: length(%c128) - stream.cmd.copy %arg4[%c0], %arg5[%c0], %c128 : !stream.resource{%arg1} -> !stream.resource{%arg3} // CHECK-NEXT: hal.command_buffer.execution_barrier<%[[CMD]] + stream.cmd.copy %source_capture[%c0], %target_capture[%c0], %c128 : !stream.resource{%source_size} -> !stream.resource{%target_size} + // CHECK-NEXT: hal.command_buffer.copy_buffer + // CHECK-NEXT: hal.command_buffer.execution_barrier<%[[CMD]] + stream.cmd.copy %source_capture[%c0], %target_capture[%c0], %target_size : !stream.resource{%source_size} -> !stream.resource{%target_size} } => !stream.timepoint // CHECK-NEXT: hal.command_buffer.finalize<%[[CMD]] - util.return %0 : !stream.timepoint + util.return %signal : !stream.timepoint +} + +// ----- + +// Tests that an execution region with a single copy is converted to a queue +// operation instead of a command buffer. The extra flush is ignored as queue +// operations have implicit flushes (today). + +util.global private @device : !hal.device + +// CHECK-LABEL: @cmdCopyOnQueue +// CHECK-SAME: (%[[SOURCE:.+]]: !hal.buffer, %[[SOURCE_SIZE:.+]]: index, %[[TARGET:.+]]: !hal.buffer, %[[TARGET_SIZE:.+]]: index, %[[WAIT:.+]]: !hal.fence) +util.func public @cmdCopyOnQueue(%source: !stream.resource, %source_size: index, %target: !stream.resource, %target_size: index, %wait: !stream.timepoint) -> !stream.timepoint { + %c0 = arith.constant 0 : index + %c128 = arith.constant 128 : index + // CHECK-NOT: hal.command_buffer.create + // CHECK: %[[SIGNAL:.+]] = hal.fence.create + %signal = stream.cmd.execute once on(#hal.device.affinity<@device>) await(%wait) => with(%source as %source_capture: !stream.resource{%source_size}, %target as %target_capture: !stream.resource{%target_size}) { + // CHECK: hal.device.queue.copy<%{{.+}} : !hal.device> + // CHECK-SAME: wait(%[[WAIT]]) signal(%[[SIGNAL]]) + // CHECK-SAME: source(%[[SOURCE]] : !hal.buffer)[%c0] + // CHECK-SAME: target(%[[TARGET]] : !hal.buffer)[%c0] + // CHECK-SAME: length(%c128) + stream.cmd.copy %source_capture[%c0], %target_capture[%c0], %c128 : !stream.resource{%source_size} -> !stream.resource{%target_size} + stream.cmd.flush %target_capture[%c0 for %c128] : !stream.resource{%target_size} + } => !stream.timepoint + // CHECK: util.return %[[SIGNAL]] + util.return %signal : !stream.timepoint } // ----- @@ -364,12 +434,10 @@ util.global private @device : !hal.device util.func public @cmdExecuteAffinities(%arg0: !stream.resource, %arg1: index, %arg2: !stream.resource, %arg3: index, %arg4: !stream.timepoint) -> !stream.timepoint { %c0 = arith.constant 0 : index %c128 = arith.constant 128 : index - // CHECK: %[[CMD:.+]] = hal.command_buffer.create + // CHECK: hal.device.queue.copy + // CHECK-SAME: affinity(%c3_i64) %0 = stream.cmd.execute once on(#hal.device.affinity<@device, [0, 1]>) await(%arg4) => with(%arg0 as %arg5: !stream.resource{%arg1}, %arg2 as %arg6: !stream.resource{%arg3}) { stream.cmd.copy %arg5[%c0], %arg6[%c0], %c128 : !stream.resource{%arg1} -> !stream.resource{%arg3} } => !stream.timepoint - // CHECK: hal.device.queue.execute - // CHECK-SAME: affinity(%c3_i64) - // CHECK-SAME: commands([%[[CMD]]]) util.return %0 : !stream.timepoint } diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td index f0b541096b21..6bf6a4c72bdc 100644 --- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td +++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilInterfaces.td @@ -50,6 +50,24 @@ def Util_ClosureOpInterface : OpInterface<"ClosureOpInterface"> { /*methodName=*/"getClosureResults", /*args=*/(ins) >, + InterfaceMethod< + /*desc=*/[{ + Returns the value captured by the closure operands or the provided value + if it was not captured. + }], + /*retTy=*/"Value", + /*methodName=*/"getClosureCapturedValue", + /*args=*/(ins "Value":$closureValue), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + if (auto arg = dyn_cast(closureValue)) { + if (arg.getParentRegion() == &$_op.getClosureBodyRegion()) { + return $_op.getClosureOperands()[arg.getArgNumber()]; + } + } + return closureValue; + }] + >, InterfaceMethod< /*desc=*/[{ Returns true if the given operation can exist in the closure.