Skip to content

Commit

Permalink
#12864: Restructure ExecuteUnaryBackwardTwoFloatWithDefault
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Sep 20, 2024
1 parent efcfaa5 commit 468c038
Showing 1 changed file with 13 additions and 23 deletions.
36 changes: 13 additions & 23 deletions ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,14 @@ struct ExecuteUnaryBackwardWoFloat {

};

struct ExecuteUnaryBackwardSoftplus {
static std::vector<Tensor> invoke(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_arg,
float parameter_a,
float parameter_b,
const std::optional<MemoryConfig> &memory_config = std::nullopt);
};

struct ExecuteUnaryBackwardHardtanh {
static std::vector<Tensor> invoke(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_arg,
float parameter_a,
float parameter_b,
const std::optional<MemoryConfig> &memory_config = std::nullopt);
#define DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(op_name) \
struct ExecuteUnaryBackward##op_name { \
static std::vector<Tensor> invoke( \
const Tensor &grad_tensor_arg, \
const Tensor &input_tensor_arg, \
float parameter_a, \
float parameter_b, \
const std::optional<MemoryConfig> &memory_config = std::nullopt); \
};

template <UnaryBackwardOpType unary_backward_op_type>
Expand Down Expand Up @@ -306,6 +298,8 @@ struct ExecuteUnaryBackwardGelu{

};

DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(Softplus)
DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(Hardtanh)

} // operations::unary

Expand Down Expand Up @@ -570,13 +564,9 @@ constexpr auto clamp_bw = ttnn::register_operation<
"ttnn::clamp_bw",
operations::unary_backward::ExecuteUnaryBackwardClamp>();

constexpr auto softplus_bw = ttnn::register_operation<
"ttnn::softplus_bw",
operations::unary_backward::ExecuteUnaryBackwardSoftplus>();

constexpr auto hardtanh_bw = ttnn::register_operation<
"ttnn::hardtanh_bw",
operations::unary_backward::ExecuteUnaryBackwardHardtanh>();
// Tensor + Float(Default) + Float(Default)
constexpr auto hardtanh_bw = ttnn::register_operation<"ttnn::hardtanh_bw", operations::unary_backward::ExecuteUnaryBackwardHardtanh>();
constexpr auto softplus_bw = ttnn::register_operation<"ttnn::softplus_bw", operations::unary_backward::ExecuteUnaryBackwardSoftplus>();

constexpr auto rdiv_bw = ttnn::register_operation<
"ttnn::rdiv_bw",
Expand Down

0 comments on commit 468c038

Please sign in to comment.