Skip to content

Commit

Permalink
Rebase fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
LPanosTT committed Sep 10, 2024
1 parent 71b7e4a commit 11a40cd
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 11 deletions.
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def TTNN_MaxPool2dOp : TTNN_NamedDPSOp<"max_pool2d"> {

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
TT_Device:$device,
SI32Attr:$batch_size,
SI32Attr:$input_height,
SI32Attr:$input_width,
Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/Target/TTNN/program.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ table Conv2dOp {
table MaxPool2dOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
device: tt.target.DeviceRef;
batch_size: uint32;
input_height: uint32;
input_width: uint32;
Expand Down
3 changes: 2 additions & 1 deletion lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ class MaxPool2dOpConversionPattern
"TTNN max_pool2d does not support padding top/bottom/left/right "
"separately");

auto device = getOrInsertDevice(rewriter, op);
auto input_ty = mlir::cast<RankedTensorType>(adaptor.getInput().getType());
llvm::ArrayRef<std::int64_t> input_shape = input_ty.getShape();

Expand All @@ -401,7 +402,7 @@ class MaxPool2dOpConversionPattern

rewriter.replaceOpWithNewOp<ttnn::MaxPool2dOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), adaptor.getOutput(), batch_size,
adaptor.getInput(), adaptor.getOutput(), device, batch_size,
adaptor.getInputHeightAttr(), adaptor.getInputWidthAttr(), channels,
adaptor.getKernelHeightAttr(), adaptor.getKernelWidthAttr(),
adaptor.getStrideHeightAttr(), adaptor.getStrideWidthAttr(),
Expand Down
12 changes: 7 additions & 5 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,14 @@ createMaxPool2dOp(FlatbufferObjectCache &cache, MaxPool2dOp op) {
auto out = cache.at<::tt::target::TensorRef>(
getOperandThroughDPSOps(op.getResult()));

auto device = getOperandThroughDPSOps(op.getDevice());
return ::tt::target::ttnn::CreateMaxPool2dOp(
*cache.fbb, in, out, op.getBatchSize(), op.getInputHeight(),
op.getInputWidth(), op.getChannels(), op.getKernelHeight(),
op.getKernelWidth(), op.getStrideHeight(), op.getStrideWidth(),
op.getDilationHeight(), op.getDilationWidth(), op.getCeilMode(),
op.getPaddingHeight(), op.getPaddingWidth());
*cache.fbb, in, out, cache.at<::tt::target::DeviceRef>(device),
op.getBatchSize(), op.getInputHeight(), op.getInputWidth(),
op.getChannels(), op.getKernelHeight(), op.getKernelWidth(),
op.getStrideHeight(), op.getStrideWidth(), op.getDilationHeight(),
op.getDilationWidth(), op.getCeilMode(), op.getPaddingHeight(),
op.getPaddingWidth());
}

template <typename SoftmaxOp>
Expand Down
7 changes: 5 additions & 2 deletions runtime/lib/ttnn/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -751,11 +751,13 @@ static void run(::tt::target::ttnn::Conv2dOp const *op,
}

static void run(::tt::target::ttnn::MaxPool2dOp const *op,
::ttnn::Device &device, ProgramTensorPool &tensorPool) {
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool,
ProgramTensorPool &tensorPool) {
const ::ttnn::Tensor &input = tensorPool.at(op->in()->global_id());
const ::ttnn::operations::pool::MaxPoolNewOp operation =
::ttnn::operations::pool::MaxPoolNewOp();

::ttnn::Device &device = getDevice(op->device(), devicePool);
::ttnn::Tensor out = operation.invoke(
0, input, op->batch_size(), op->input_height(), op->input_width(),
op->channels(), {op->kernel_height(), op->kernel_width()},
Expand Down Expand Up @@ -836,14 +838,15 @@ run(::tt::target::ttnn::Operation const *op,
}
case ::tt::target::ttnn::OpType::ConcatOp: {
return run(op->type_as_ConcatOp(), devicePool, tensorPool);
}
case ::tt::target::ttnn::OpType::ReshapeOp: {
return run(op->type_as_ReshapeOp(), devicePool, tensorPool);
}
case ::tt::target::ttnn::OpType::DeallocOp: {
return run(op->type_as_DeallocOp(), devicePool, tensorPool);
}
case ::tt::target::ttnn::OpType::MaxPool2dOp: {
return run(op->type_as_MaxPool2dOp(), device, tensorPool);
return run(op->type_as_MaxPool2dOp(), devicePool, tensorPool);
}
default: {
throw std::runtime_error("Unsupported operation type");
Expand Down
3 changes: 0 additions & 3 deletions test/ttmlir/Dialect/TTNN/simple_maxpool2d.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@
// Lowering test: a 2x2 kernel with stride 2 halves the 128x128 spatial dims,
// so the 1x128x128x32 input pools down to a 1x64x64x32 result (NHWC layout
// presumed from the dim order — confirm against the dialect's convention).
// Operand-constraint alias permitting any memory space / device placement.
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<1x128x128x32xbf16>) -> tensor<1x64x64x32xbf16> {
// CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
// CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
// Destination tensor for the DPS-style pooling op below.
%0 = tensor.empty() : tensor<1x64x64x32xbf16>
// CHECK: %[[C:.*]] = "ttnn.max_pool2d"[[C:.*]]
// kernel=2x2, stride=2x2, dilation=1, no padding, floor rounding (ceil_mode=false).
%1 = "ttir.max_pool2d"(%arg0, %0) <{input_height=128: si32, input_width=128: si32, kernel_height=2: si32, kernel_width=2: si32, stride_height=2: si32, stride_width=2: si32, dilation_height=1: si32, dilation_width=1: si32, ceil_mode=false, padding_left=0: si32, padding_right=0: si32, padding_top=0: si32, padding_bottom=0: si32, operand_constraints = [#any_device, #any_device]}> : (tensor<1x128x128x32xbf16>, tensor<1x64x64x32xbf16>) -> tensor<1x64x64x32xbf16>
// CHECK: "ttnn.close_device"[[C:.*]]
return %1 : tensor<1x64x64x32xbf16>
}
}

0 comments on commit 11a40cd

Please sign in to comment.