Add e2e maxpool2d op (minus runtime support)
Bring up maxpool2d in TTNNToFlatBuffer. Skeleton code for runtime

WIP

Add reshape insertion pass for maxpool2d

Remove reshapes in maxpool runtime

Add input_height/width as attributes to maxpool2d

Fix dilation attribute for maxpool2d

Fix program errors from faulty rebase

Revert builds to release mode

Clean up maxpool, fix shapes

Remove MLIR module dump in TTIRToTTNNPass.cpp, add more checks to verify() methods for TTIR/TTNN MaxPool2dOp

Fix verify condition for maxpool2d

Generalize reshape-inserting pass with template

Fix build warnings

Remove use of auto keyword

Rebase fixes

Make input height/width optional in TTIR as per Nick's suggestion

Remove stray prints
LPanosTT committed Sep 11, 2024
1 parent fc4a2f9 commit 479f05a
Showing 12 changed files with 424 additions and 4 deletions.
32 changes: 32 additions & 0 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -447,6 +447,38 @@ def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
let hasVerifier = 1;
}

def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> {
let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
let description = [{
Applies a 2D max pooling over an input signal composed of several input planes.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
SI32Attr:$kernel_height,
SI32Attr:$kernel_width,
SI32Attr:$stride_height,
SI32Attr:$stride_width,
SI32Attr:$dilation_height,
SI32Attr:$dilation_width,
BoolAttr:$ceil_mode,
SI32Attr:$padding_left,
SI32Attr:$padding_right,
SI32Attr:$padding_top,
SI32Attr:$padding_bottom,
TT_OperandConstraintArrayAttr:$operand_constraints,
OptionalAttr<SI32Attr>:$original_height,
OptionalAttr<SI32Attr>:$original_width);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}

def TTIR_ReshapeOp: TTIR_DPSOp<"reshape"> {
let summary = "Reshape op.";
let description = [{
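For reference, a ttir.max_pool2d instance in generic MLIR syntax might look like the following sketch (shapes and attribute values are illustrative, and the operand_constraints attribute is elided for brevity):

%output = tensor.empty() : tensor<1x64x64x32xf32>
%result = "ttir.max_pool2d"(%input, %output)
    <{kernel_height = 2 : si32, kernel_width = 2 : si32,
      stride_height = 2 : si32, stride_width = 2 : si32,
      dilation_height = 1 : si32, dilation_width = 1 : si32,
      ceil_mode = false,
      padding_left = 0 : si32, padding_right = 0 : si32,
      padding_top = 0 : si32, padding_bottom = 0 : si32}>
    : (tensor<1x128x128x32xf32>, tensor<1x64x64x32xf32>) -> tensor<1x64x64x32xf32>

A 2x2 kernel with stride 2 and no padding halves each spatial dim, so (1, 128, 128, 32) pools down to (1, 64, 64, 32).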
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
@@ -58,6 +58,13 @@ def TTIRLayout: Pass<"ttir-layout", "::mlir::ModuleOp"> {
];
}

def TTIRSlidingWindow2dFixShapes: Pass<"ttir-sliding-window-2d-fix-shapes", "::mlir::ModuleOp"> {
let summary = "Insert reshapes to flatten the input and unflatten the output of 2-dimensional sliding window ops: (N, H, W, C) --> (1, 1, N*H*W, C) on the input, and (1, 1, N*H*W, C) --> (N, H, W, C) on the output.";
let description = [{
Insert reshapes on the input and output of 2-dimensional sliding window ops that collapse N, H, W on the input, i.e. (N, H, W, C) --> (1, 1, N*H*W, C), and unflatten the output, i.e. (1, 1, N*H*W, C) --> (N, H, W, C).
}];
}

def TTIRSplitCompoundLayout: Pass<"ttir-split-compound-layout", "::mlir::ModuleOp"> {
let summary = "Split compound layouts.";
let description = [{
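As a sketch of what this pass does (shapes illustrative, attribute lists abbreviated with "..."): a max_pool2d over a (1, 32, 32, 16) input gets wrapped in reshapes so the op itself sees the flattened layout, with the pre-flatten spatial dims recorded in original_height/original_width (which the TTIR-to-TTNN conversion below asserts are set):

%flat = "ttir.reshape"(%input, %e0) {...} : (tensor<1x32x32x16xf32>, tensor<1x1x1024x16xf32>) -> tensor<1x1x1024x16xf32>
%pool = "ttir.max_pool2d"(%flat, %e1) {original_height = 32 : si32, original_width = 32 : si32, ...} : (tensor<1x1x1024x16xf32>, tensor<1x1x256x16xf32>) -> tensor<1x1x256x16xf32>
%result = "ttir.reshape"(%pool, %e2) {...} : (tensor<1x1x256x16xf32>, tensor<1x16x16x16xf32>) -> tensor<1x16x16x16xf32>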
32 changes: 32 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -350,6 +350,38 @@ def TTNN_Conv2dOp : TTNN_NamedDPSOp<"conv2d"> {
let hasVerifier = 1;
}

def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {
let summary = "Applies a 2D max pooling over an input signal composed of several input planes.";
let description = [{
Applies a 2D max pooling over an input signal composed of several input planes.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
TT_Device:$device,
SI32Attr:$batch_size,
SI32Attr:$input_height,
SI32Attr:$input_width,
SI32Attr:$channels,
SI32Attr:$kernel_height,
SI32Attr:$kernel_width,
SI32Attr:$stride_height,
SI32Attr:$stride_width,
SI32Attr:$dilation_height,
SI32Attr:$dilation_width,
BoolAttr:$ceil_mode,
SI32Attr:$padding_height,
SI32Attr:$padding_width);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}

def TTNN_EmptyOp : TTNN_Op<"empty"> {
let summary = "Empty op.";
let description = [{
20 changes: 20 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -121,6 +121,25 @@ table Conv2dOp {
groups: uint32;
}

table MaxPool2dOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
device: tt.target.DeviceRef;
batch_size: uint32;
input_height: uint32;
input_width: uint32;
channels: uint32;
kernel_height: uint32;
kernel_width: uint32;
stride_height: uint32;
stride_width: uint32;
dilation_height: uint32;
dilation_width: uint32;
ceil_mode: bool;
padding_height: uint32;
padding_width: uint32;
}

table DeallocOp {
in: tt.target.TensorRef;
}
@@ -139,6 +158,7 @@ union OpType {
Conv2dOp,
ConcatOp,
ReshapeOp,
MaxPool2dOp,
DeallocOp
}

51 changes: 48 additions & 3 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -364,18 +364,62 @@ class Conv2dOpConversionPattern : public OpConversionPattern<ttir::Conv2dOp> {
auto dilation_width =
rewriter.getI32IntegerAttr(adaptor.getDilationWidth());
auto groups = rewriter.getI32IntegerAttr(adaptor.getGroups());

rewriter.replaceOpWithNewOp<ttnn::Conv2dOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), adaptor.getWeight(), adaptor.getBias(),
adaptor.getOutput(), device, in_channels, out_channels, batch_size,
- input_width, input_height, kernel_height, kernel_width, stride_height,
+ input_height, input_width, kernel_height, kernel_width, stride_height,
stride_width, padding_height, padding_width, dilation_height,
dilation_width, groups);
return success();
}
};

class MaxPool2dOpConversionPattern
: public OpConversionPattern<ttir::MaxPool2dOp> {
public:
using OpConversionPattern<ttir::MaxPool2dOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::MaxPool2dOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

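// TTNN's max_pool2d takes one symmetric padding value per spatial dim
// (padding_height / padding_width), so asymmetric TTIR padding cannot be
// lowered directly.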
assert(adaptor.getPaddingBottom() == adaptor.getPaddingTop() &&
"TTNN max_pool2d does not support padding top/bottom/left/right "
"separately");
assert(adaptor.getPaddingLeft() == adaptor.getPaddingRight() &&
"TTNN max_pool2d does not support padding top/bottom/left/right "
"separately");

auto device = getOrInsertDevice(rewriter, op);
auto input_ty = mlir::cast<RankedTensorType>(adaptor.getInput().getType());
llvm::ArrayRef<std::int64_t> input_shape = input_ty.getShape();

auto batch_size =
rewriter.getSI32IntegerAttr(input_shape[input_shape.size() - 4]);
auto channels =
rewriter.getSI32IntegerAttr(input_shape[input_shape.size() - 1]);

assert(adaptor.getOriginalHeight().has_value() &&
"ttir::MaxPool2dOp must have original_height set before translating "
"to TTNN dialect.");
assert(adaptor.getOriginalWidth().has_value() &&
"ttir::MaxPool2dOp must have original_width set before translating "
"to TTNN dialect.");

rewriter.replaceOpWithNewOp<ttnn::MaxPool2dOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), adaptor.getOutput(), device, batch_size,
adaptor.getOriginalHeightAttr(), adaptor.getOriginalWidthAttr(),
channels, adaptor.getKernelHeightAttr(), adaptor.getKernelWidthAttr(),
adaptor.getStrideHeightAttr(), adaptor.getStrideWidthAttr(),
adaptor.getDilationHeightAttr(), adaptor.getDilationWidthAttr(),
adaptor.getCeilModeAttr(), adaptor.getPaddingTopAttr(),
adaptor.getPaddingRightAttr());
return success();
}
};

namespace mlir::tt {

void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
@@ -407,7 +451,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
SqueezeOpConversionPattern,
UnsqueezeOpConversionPattern,
MatmulOpConversionPattern,
Conv2dOpConversionPattern
Conv2dOpConversionPattern,
MaxPool2dOpConversionPattern
>(typeConverter, ctx);
// ANCHOR_END: op_rewriter_pattern_set
// clang-format on
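Continuing the sketch above, the flattened TTIR op converts to a ttnn.max_pool2d whose input_height/input_width are taken from original_height/original_width and whose batch_size/channels are read off the (already flattened) input shape; roughly (attribute list abbreviated, device type syntax illustrative):

%result = "ttnn.max_pool2d"(%flat, %output, %device)
    <{batch_size = 1 : si32, input_height = 32 : si32, input_width = 32 : si32,
      channels = 16 : si32, ...}>
    : (tensor<1x1x1024x16xf32>, tensor<1x1x256x16xf32>, !tt.device<...>) -> tensor<1x1x256x16xf32>

Note that padding_top and padding_right stand in for padding_height and padding_width, which the asserts above guarantee equal their bottom/left counterparts.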
40 changes: 40 additions & 0 deletions lib/Dialect/TTIR/IR/TTIROps.cpp
@@ -8,6 +8,7 @@
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

#include "ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.cpp.inc"
#include <llvm/ADT/ArrayRef.h>
#include <mlir/IR/BuiltinTypes.h>

#define GET_OP_CLASSES
@@ -268,6 +269,45 @@ ::mlir::LogicalResult mlir::tt::ttir::ReshapeOp::verify() {
return success();
}

::mlir::LogicalResult mlir::tt::ttir::MaxPool2dOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
std::vector<int64_t> inputShape = getInput().getType().getShape().vec();

if (inputType.getRank() != 4) {
return emitOpError()
<< "Input tensor rank must be 4. Recieved input with rank "
<< inputType.getRank() << ". Shape: (" << inputShape << ").";
}

if (getOriginalHeight().has_value() != getOriginalWidth().has_value()) {
std::string with_value =
getOriginalHeight().has_value() ? "original_height" : "original_width";
return emitOpError()
<< "If providing the original height and width as attributs, both "
"original_height and original_width must be set. However, only "
<< with_value << " was provided.";
}

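// If the sliding-window fix-shapes pass has already flattened the input to
// (1, 1, N*H*W, C), the original spatial dims live in original_height and
// original_width, so validate the kernel against those rather than the
// flattened shape.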
if (getOriginalHeight().has_value() && getOriginalWidth().has_value()) {
inputShape[1] = getOriginalHeight().value();
inputShape[2] = getOriginalWidth().value();
}

if (getKernelHeight() > inputShape[1]) {
return emitOpError() << "Kernel height " << getKernelHeight()
<< " is greater than input height " << inputShape[1]
<< ". This MaxPool2d configuration is invalid.";
}

if (getKernelWidth() > inputShape[2]) {
return emitOpError() << "Kernel width " << getKernelWidth()
<< " is greater than input width " << inputShape[2]
<< ". This MaxPool2d configuration is invalid.";
}

return success();
}

::mlir::LogicalResult mlir::tt::ttir::SqueezeOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
::mlir::RankedTensorType outputType = getOutput().getType();