
runtime support for full op #644

Merged
1 commit, merged Sep 10, 2024
30 changes: 27 additions & 3 deletions runtime/lib/ttnn/program.cpp
@@ -772,6 +772,28 @@ run(::tt::target::ttnn::GetDeviceOp const *op,
}
}

static void run(::tt::target::ttnn::FullOp const *op,
std::unordered_map<uint32_t, ::ttnn::Device *> &devicePool,
ProgramTensorPool &tensorPool) {
::ttnn::Device &device = getDevice(op->device(), devicePool);
::ttnn::DataType outputDataType = getDataType(op->out());
auto shape = ::ttnn::Shape(::tt::tt_metal::Shape(
utils::toShapeFromFBShape(*op->out()->desc()->shape())));
float fillValue = op->fill_value();
// TODO(bug #272), determine correct layout by tile shape in the future
Contributor:
For now we can probably assert that there are only 2 tile shapes and therefore infer, [1, 1] -> row major, or [32, 32] -> tiled

Contributor (author):
Confirmed with @nsmithtt @nobradovictt that currently we don't have compiler support for dynamically generating tile shapes, will leave this hardcoded and TODO for now until we bring up the support.
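
A minimal sketch of the inference suggested in the comment above, assuming the tile shape has already been read from the output tensor descriptor (the helper name and its parameters are illustrative, not part of this change):

static ::ttnn::Layout inferLayoutFromTileShape(uint32_t tileH, uint32_t tileW) {
  // Only two tile shapes are expected today: [1, 1] (row major) and
  // [32, 32] (tiled); anything else is an error.
  assert(((tileH == 1 && tileW == 1) || (tileH == 32 && tileW == 32)) &&
         "unsupported tile shape");
  return (tileH == 1 && tileW == 1) ? ::ttnn::Layout::ROW_MAJOR
                                    : ::ttnn::Layout::TILE;
}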

::ttnn::Layout outputLayout = ::ttnn::Layout::ROW_MAJOR;
std::optional<std::reference_wrapper<::ttnn::Device>> outputDevice =
std::make_optional(std::ref(device));
std::optional<::tt::tt_metal::MemoryConfig> outputMemoryConfig =
std::make_optional(createMemoryConfig(op->out()));

::ttnn::Tensor out =
::ttnn::full(shape, fillValue, outputDataType, outputLayout, outputDevice,
outputMemoryConfig);

tensorPool.insert_or_assign(op->out()->global_id(), std::move(out));
}

static void
run(::tt::target::ttnn::Operation const *op,
const std::unordered_map<uint32_t, ::ttnn::Device *> &allDevices,
@@ -789,7 +811,7 @@ run(::tt::target::ttnn::Operation const *op,
return run(op->type_as_EmptyOp(), devicePool, tensorPool);
}
case ::tt::target::ttnn::OpType::FullOp: {
// TODO(bug #626)
return run(op->type_as_FullOp(), devicePool, tensorPool);
break;
}
case ::tt::target::ttnn::OpType::EltwiseOp: {
@@ -850,12 +872,14 @@ void runProgram(::ttnn::Device &device,
::tt::target::ttnn::Program const *program,
std::vector<::ttnn::Tensor *> const &inputs,
std::vector<::ttnn::Tensor *> const &outputs) {
if (handleNopProgram(program, inputs, outputs)) {
return;
}
std::unordered_map<std::uint32_t, ::ttnn::Tensor *> liveTensors;
std::unordered_map<std::uint32_t, ::ttnn::Device *> allDevices;
std::unordered_map<std::uint32_t, ::ttnn::Device *> devicePool;
int inputIndex = 0;
assert(program->inputs()->size() == inputs.size());
bool isNop = handleNopProgram(program, inputs, outputs);
// Assuming single device for now until we support multichip
allDevices.try_emplace(device.id(), &device);
for (::tt::target::TensorRef const *input : *program->inputs()) {
@@ -869,7 +893,7 @@ void runProgram(::ttnn::Device &device,
for (::tt::target::TensorRef const *output : *program->outputs()) {
auto [iter, inserted] =
liveTensors.try_emplace(output->global_id(), outputs[outputIndex++]);
assert((isNop || inserted) && "Duplicate output tensor");
assert(inserted && "Duplicate output tensor");
}
ProgramTensorPool tensorPool(std::move(liveTensors));
for (::tt::target::ttnn::Operation const *op : *program->operations()) {