Skip to content

Commit

Permalink
#12867: Cleanup repeat_bw
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Sep 20, 2024
1 parent b5001dd commit fa41d12
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1408,7 +1408,7 @@ std::vector<std::optional<ttnn::Tensor>> ExecuteUnaryBackwardGelu::invoke(
return ExecuteUnaryBackwardGelu::invoke(DefaultQueueId, grad, input, approximate, output_mem_config);
}

std::vector<Tensor> _repeat_bw(
std::vector<Tensor> ExecuteUnaryBackwardRepeat::invoke(
const Tensor& grad, const Tensor& input, const tt::tt_metal::LegacyShape& shape, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
auto output_memory_config = output_mem_config.value_or(input.memory_config()); //TODO: Remove after ternary forward ops migration is completed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ enum class UnaryBackwardOpType {
ERF_BW,
DEG2RAD_BW,
POLYGAMMA_BW,
REPEAT_BW,
};

std::vector<Tensor> _acos_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
Expand Down Expand Up @@ -128,7 +127,6 @@ std::vector<Tensor> _eq_bw( const Tensor& grad, const Tensor& input, float other

std::vector<Tensor> _rdiv_bw( const Tensor& grad, const Tensor& input, float scalar, string round_mode = "None", const std::optional<MemoryConfig>& output_mem_config = std::nullopt);

std::vector<Tensor> _repeat_bw(const Tensor& grad, const Tensor& input, const tt::tt_metal::LegacyShape& shape, const std::optional<MemoryConfig>& output_mem_config);
Tensor change_layout_to_tile(const Tensor& temp, const MemoryConfig& output_mem_config);

// OpHandler struct template
Expand Down Expand Up @@ -492,13 +490,6 @@ struct OpHandler<UnaryBackwardOpType::RDIV_BW> {
}
};

template <>
struct OpHandler<UnaryBackwardOpType::REPEAT_BW> {
    // Thin dispatch shim: routes the REPEAT_BW op-type through the generic
    // OpHandler interface to the free-function implementation `_repeat_bw`.
    // `output_mem_config` is passed through untouched; defaulting (if any)
    // happens inside `_repeat_bw` — TODO(review) confirm against its definition.
    static std::vector<Tensor> handle(
        const Tensor& grad,
        const Tensor& input,
        const tt::tt_metal::LegacyShape& shape,
        const std::optional<MemoryConfig>& output_mem_config) {
        return _repeat_bw(grad, input, shape, output_mem_config);
    }
};

template <>
struct OpHandler<UnaryBackwardOpType::EQ_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config ) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,16 +144,12 @@ struct ExecuteUnaryBackwardStringDefault {
}
};

template <UnaryBackwardOpType unary_backward_op_type>
struct ExecuteUnaryBackwardShape {
struct ExecuteUnaryBackwardRepeat {
static std::vector<Tensor> invoke(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_arg,
const tt::tt_metal::LegacyShape &parameter_a,
const std::optional<MemoryConfig> &memory_config = std::nullopt) {
auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config());
return OpHandler<unary_backward_op_type>::handle(grad_tensor_arg, input_tensor_arg, parameter_a, output_memory_config);
}
const std::optional<MemoryConfig> &memory_config = std::nullopt);
};

struct ExecuteUnaryBackwardPow {
Expand Down Expand Up @@ -562,8 +558,7 @@ constexpr auto gelu_bw = ttnn::register_operation<

constexpr auto repeat_bw = ttnn::register_operation<
"ttnn::repeat_bw",
operations::unary_backward::ExecuteUnaryBackwardShape<
operations::unary_backward::UnaryBackwardOpType::REPEAT_BW>>();
operations::unary_backward::ExecuteUnaryBackwardRepeat>();

constexpr auto pow_bw = ttnn::register_operation<
"ttnn::pow_bw",
Expand Down

0 comments on commit fa41d12

Please sign in to comment.