#0: Improve zeros_like perf
KalaivaniMCW committed Sep 21, 2024
1 parent 7a2ca61 commit aa0cacf
Showing 1 changed file with 49 additions and 2 deletions.
51 changes: 49 additions & 2 deletions ttnn/cpp/ttnn/operations/creation.hpp
@@ -13,6 +13,10 @@
#include "ttnn/decorators.hpp"
#include "ttnn/types.hpp"
#include "ttnn/common/constants.hpp"
#include "ttnn/operations/eltwise/unary/unary.hpp"
#include "ttnn/operations/core/core.hpp"
#include "ttnn/operations/data_movement/copy/copy.hpp"


namespace ttnn {
namespace operations {
@@ -199,10 +203,8 @@ struct FullLikeWith {
}
};

struct ZerosLike : FullLikeWith<0.0f> {};
struct OnesLike : FullLikeWith<1.0f> {};

inline constexpr ZerosLike zeros_like{};
inline constexpr OnesLike ones_like{};

struct Empty {
@@ -231,6 +233,51 @@ struct EmptyLike {
}
};


struct ZerosLike {
    static ttnn::Tensor invoke(
        uint8_t queue_id,
        const ttnn::Tensor& tensor,
        const std::optional<DataType>& dtype = std::nullopt,
        const std::optional<Layout>& layout = std::nullopt,
        const std::optional<std::reference_wrapper<Device>>& device_arg = std::nullopt,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
        // Allocate the output tensor if none was supplied, inheriting dtype, layout,
        // device, and memory config from the input unless overridden.
        if (!optional_output_tensor.has_value()) {
            Device* device = device_arg.has_value() ? &(device_arg.value().get()) : tensor.device();
            Layout layout_value = layout.value_or(tensor.get_layout());
            DataType dtype_value = dtype.value_or(tensor.get_dtype());
            MemoryConfig mem_cfg = memory_config.value_or(tensor.memory_config());
            optional_output_tensor = create_device_tensor(tensor.get_shape(), dtype_value, layout_value, device, mem_cfg);
        }

        // Row-major path: tilize, zero on the SFPU, untilize, then copy the result back
        // into the output tensor. This branch can be dropped if row-major support is not
        // needed for zeros_like.
        if (optional_output_tensor.value().get_layout() == Layout::ROW_MAJOR) {
            Tensor x = optional_output_tensor.value();
            x = ttnn::to_layout(x, ttnn::TILE_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr);
            ttnn::mul_sfpu(x, 0.0f, std::nullopt, x);
            x = ttnn::to_layout(x, ttnn::ROW_MAJOR_LAYOUT, std::nullopt, std::nullopt, (Device*)nullptr);
            ttnn::assign(x, optional_output_tensor.value());
            return optional_output_tensor.value();
        }

        // Tile layout: zero the output in place with a multiply-by-zero on the SFPU.
        ttnn::mul_sfpu(optional_output_tensor.value(), 0.0f, std::nullopt, optional_output_tensor);
        return optional_output_tensor.value();
    }

    static ttnn::Tensor invoke(
        const ttnn::Tensor& tensor,
        const std::optional<DataType>& dtype = std::nullopt,
        const std::optional<Layout>& layout = std::nullopt,
        const std::optional<std::reference_wrapper<Device>>& device = std::nullopt,
        const std::optional<MemoryConfig>& memory_config = std::nullopt,
        std::optional<ttnn::Tensor> optional_output_tensor = std::nullopt) {
        return invoke(ttnn::DefaultQueueId, tensor, dtype, layout, device, memory_config, optional_output_tensor);
    }
};

struct Full {
static ttnn::Tensor invoke(
uint8_t queue_id,

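For reference, a minimal usage sketch of the new code path (not part of the commit). It assumes the struct sits in the ttnn::operations::creation namespace alongside the other factories in creation.hpp, and that the caller already has a device tensor named input; device setup is omitted:

    // Usage sketch only: the namespace path and the input tensor are assumptions,
    // not part of the commit.
    #include "ttnn/operations/creation.hpp"

    ttnn::Tensor make_zeros_like(const ttnn::Tensor& input) {
        // Default-queue overload: allocates the output from input's dtype, layout,
        // device, and memory config, then zeroes it in place via mul_sfpu.
        return ttnn::operations::creation::ZerosLike::invoke(input);
    }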