
#12829: Cleanup softplus_bw, hardtanh_bw, prod_bw (#12864)
* #12829: Cleanup softplus_bw

* #12829: Cleanup hardtanh_bw

* #12864: Update files

* #12829: Cleanup prod_bw

* #12864: Restructure ExecuteUnaryBackwardTwoFloatWithDefault
VirdhatchaniKN authored Sep 20, 2024
1 parent 927ff4b commit e2fd289
Showing 3 changed files with 19 additions and 58 deletions.
@@ -54,7 +54,7 @@ std::vector<Tensor> ExecuteUnaryBackwardClamp::invoke(

// Hardtanh
// result: torch.where((input <= min) | (input >= max), 0.0, grad)
-std::vector<Tensor> _hardtanh_bw(
+std::vector<Tensor> ExecuteUnaryBackwardHardtanh::invoke(
const Tensor& grad, const Tensor& input, float min, float max, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_result = ttnn::where(
@@ -81,7 +81,7 @@ std::vector<Tensor> ExecuteUnaryBackwardThreshold::invoke(
}

// Softplus
-std::vector<Tensor> _softplus_bw(
+std::vector<Tensor> ExecuteUnaryBackwardSoftplus::invoke(
const Tensor& grad, const Tensor& input, float beta, float threshold, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor mul_input_beta = ttnn::multiply(input, beta, std::nullopt, output_mem_config);
@@ -1466,7 +1466,7 @@ Tensor change_layout_to_tile(const Tensor& temp, const MemoryConfig& output_mem_config)

// Prod
// along a single dimension --> result: grad_data * (y / input )
-std::vector<Tensor> _prod_bw(
+std::vector<Tensor> ExecuteUnaryBackwardProd::invoke(
const Tensor& grad, const Tensor& input, bool all_dimensions, int64_t dim, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
auto output_memory_config = output_mem_config.value_or(input.memory_config()); //TODO: Remove after ternary forward ops migration is completed
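For reference, a minimal scalar sketch of the gradient rules these kernels implement (an illustration of the math only, not the ttnn device code; the softplus clause assumes the PyTorch convention that the op is linear once beta * input exceeds threshold, and the float defaults mirror the removed _softplus_bw/_hardtanh_bw declarations):

#include <cmath>

// hardtanh_bw: grad flows only strictly inside (min, max); at or beyond the
// clamp points the gradient is zero, per the torch.where comment above.
float hardtanh_bw_scalar(float grad, float input, float min = -1.0f, float max = 1.0f) {
    return (input <= min || input >= max) ? 0.0f : grad;
}

// softplus_bw: d/dx softplus(x) = sigmoid(beta * x), so grad is scaled by
// 1 / (1 + exp(-beta * x)); past the threshold grad passes through unchanged.
float softplus_bw_scalar(float grad, float input, float beta = 1.0f, float threshold = 20.0f) {
    const float z = beta * input;
    return (z > threshold) ? grad : grad / (1.0f + std::exp(-z));
}

// prod_bw along a single dim: with y = prod(input) over that dim, each
// element receives grad * (y / input), per the comment above.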
@@ -13,8 +13,6 @@
namespace ttnn::operations::unary_backward {

enum class UnaryBackwardOpType {
-    HARDTANH_BW,
-    SOFTPLUS_BW,
DIV_BW,
RDIV_BW,
MULTIGAMMALN_BW,
@@ -76,7 +74,6 @@ enum class UnaryBackwardOpType {
DEG2RAD_BW,
POLYGAMMA_BW,
REPEAT_BW,
-    PROD_BW,
};

std::vector<Tensor> _acos_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
@@ -132,9 +129,6 @@ std::vector<Tensor> _floor_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
std::vector<Tensor> _round_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
std::vector<Tensor> _log_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);

-std::vector<Tensor> _softplus_bw( const Tensor& grad, const Tensor& input, float beta = 1.0, float threshold = 20.0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
-std::vector<Tensor> _hardtanh_bw( const Tensor& grad, const Tensor& input, float min = -1.0, float max = 1.0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);

std::vector<Tensor> _add_bw( const Tensor& grad, const Tensor& input, float alpha, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
std::vector<Tensor> _eq_bw( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);

@@ -148,21 +142,12 @@ std::vector<Tensor> _logiteps_bw( const Tensor& grad, const Tensor& input, float
std::vector<Tensor> _rdiv_bw( const Tensor& grad, const Tensor& input, float scalar, string round_mode = "None", const std::optional<MemoryConfig>& output_mem_config = std::nullopt);

std::vector<Tensor> _repeat_bw(const Tensor& grad, const Tensor& input, const tt::tt_metal::LegacyShape& shape, const std::optional<MemoryConfig>& output_mem_config);

-std::vector<Tensor> _prod_bw( const Tensor& grad, const Tensor& input, bool all_dimensions = true, int64_t dim = 0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
Tensor change_layout_to_tile(const Tensor& temp, const MemoryConfig& output_mem_config);

// OpHandler struct template
template <UnaryBackwardOpType OpType>
struct OpHandler;

-template <>
-struct OpHandler<UnaryBackwardOpType::HARDTANH_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float min, float max, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _hardtanh_bw(grad, input, min, max, output_mem_config);
-    }
-};

template <>
struct OpHandler<UnaryBackwardOpType::RPOW_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float exponent, const std::optional<MemoryConfig>& output_mem_config ) {
@@ -555,13 +540,6 @@ struct OpHandler<UnaryBackwardOpType::SUB_BW> {
}
};

-template <>
-struct OpHandler<UnaryBackwardOpType::SOFTPLUS_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float beta, float threshold, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _softplus_bw(grad, input, beta, threshold, output_mem_config);
-    }
-};

template <>
struct OpHandler<UnaryBackwardOpType::RDIV_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float scalar, string round_mode, const std::optional<MemoryConfig>& output_mem_config ) {
@@ -576,13 +554,6 @@ struct OpHandler<UnaryBackwardOpType::REPEAT_BW> {
}
};

-template <>
-struct OpHandler<UnaryBackwardOpType::PROD_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, bool all_dimensions, int64_t dim, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _prod_bw(grad, input, all_dimensions, dim, output_mem_config);
-    }
-};

template <>
struct OpHandler<UnaryBackwardOpType::EQ_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config ) {
ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp (42 changes: 16 additions & 26 deletions)
@@ -66,17 +66,14 @@ struct ExecuteUnaryBackwardWoFloat {

};

-template <UnaryBackwardOpType unary_backward_op_type>
-struct ExecuteUnaryBackwardTwoFloatWithDefault {
-    static std::vector<Tensor> invoke(
-        const Tensor &grad_tensor_arg,
-        const Tensor &input_tensor_arg,
-        float parameter_a,
-        float parameter_b,
-        const std::optional<MemoryConfig> &memory_config = std::nullopt) {
-        auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config());
-        return OpHandler<unary_backward_op_type>::handle(grad_tensor_arg, input_tensor_arg, parameter_a, parameter_b, output_memory_config);
-    }
+#define DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(op_name) \
+struct ExecuteUnaryBackward##op_name { \
+    static std::vector<Tensor> invoke( \
+        const Tensor &grad_tensor_arg, \
+        const Tensor &input_tensor_arg, \
+        float parameter_a, \
+        float parameter_b, \
+        const std::optional<MemoryConfig> &memory_config = std::nullopt); \
};
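Each DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(op_name) invocation stamps out a dedicated struct that only declares invoke; the definition now lives in the .cpp. For example, the Softplus invocation further down expands to:

struct ExecuteUnaryBackwardSoftplus {
    static std::vector<Tensor> invoke(
        const Tensor &grad_tensor_arg,
        const Tensor &input_tensor_arg,
        float parameter_a,
        float parameter_b,
        const std::optional<MemoryConfig> &memory_config = std::nullopt);
};

Trading the enum-keyed ExecuteUnaryBackwardTwoFloatWithDefault<OpType> template for named per-op structs is what lets the HARDTANH_BW and SOFTPLUS_BW enum values and their OpHandler specializations be deleted from the op-types header above.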

template <UnaryBackwardOpType unary_backward_op_type>
@@ -248,17 +245,13 @@ struct ExecuteUnaryBackwardFill {
std::optional<Tensor> input_grad = std::nullopt);
};

-template <UnaryBackwardOpType unary_backward_op_type>
-struct ExecuteUnaryBackwardProdBW {
+struct ExecuteUnaryBackwardProd {
static std::vector<Tensor> invoke(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_arg,
bool all_dimensions = true,
int64_t dim = 0,
-        const std::optional<MemoryConfig> &memory_config = std::nullopt) {
-        auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config());
-        return OpHandler<unary_backward_op_type>::handle(grad_tensor_arg, input_tensor_arg, all_dimensions, dim, output_memory_config);
-    }
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
};
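Aside from dropping the template parameter and the BW suffix, the invoke signature and its defaults are unchanged, so call sites are unaffected. A hypothetical call, assuming grad and input are device tensors of compatible shape:

// Backward of an all-dimensions product (defaults: all_dimensions = true, dim = 0):
auto grads = ttnn::prod_bw(grad, input);

// Backward of a product taken along dim 1 only:
auto grads_dim1 = ttnn::prod_bw(grad, input, /*all_dimensions=*/false, /*dim=*/1);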

struct ExecuteUnaryBackwardRecip {
@@ -305,6 +298,8 @@ struct ExecuteUnaryBackwardGelu{

};

+DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(Softplus)
+DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(Hardtanh)

} // operations::unary

@@ -569,14 +564,9 @@ constexpr auto clamp_bw = ttnn::register_operation<
"ttnn::clamp_bw",
operations::unary_backward::ExecuteUnaryBackwardClamp>();

-constexpr auto softplus_bw = ttnn::register_operation<
-    "ttnn::softplus_bw",
-    operations::unary_backward::ExecuteUnaryBackwardTwoFloatWithDefault<
-        operations::unary_backward::UnaryBackwardOpType::SOFTPLUS_BW>>();
-constexpr auto hardtanh_bw = ttnn::register_operation<
-    "ttnn::hardtanh_bw",
-    operations::unary_backward::ExecuteUnaryBackwardTwoFloatWithDefault<
-        operations::unary_backward::UnaryBackwardOpType::HARDTANH_BW>>();
+// Tensor + Float(Default) + Float(Default)
+constexpr auto hardtanh_bw = ttnn::register_operation<"ttnn::hardtanh_bw", operations::unary_backward::ExecuteUnaryBackwardHardtanh>();
+constexpr auto softplus_bw = ttnn::register_operation<"ttnn::softplus_bw", operations::unary_backward::ExecuteUnaryBackwardSoftplus>();
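The registered names are untouched, so existing callers keep working; only the backing struct changed. A hypothetical invocation (note the macro-declared invoke takes both floats explicitly; the values shown are the defaults from the removed free-function declarations):

auto softplus_grads = ttnn::softplus_bw(grad, input, /*beta=*/1.0f, /*threshold=*/20.0f);
auto hardtanh_grads = ttnn::hardtanh_bw(grad, input, /*min=*/-1.0f, /*max=*/1.0f);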

constexpr auto rdiv_bw = ttnn::register_operation<
"ttnn::rdiv_bw",
@@ -612,7 +602,7 @@ constexpr auto silu_bw = ttnn::register_operation<

constexpr auto prod_bw = ttnn::register_operation<
"ttnn::prod_bw",
-    operations::unary_backward::ExecuteUnaryBackwardProdBW<operations::unary_backward::UnaryBackwardOpType::PROD_BW>>();
+    operations::unary_backward::ExecuteUnaryBackwardProd>();

constexpr auto relu_bw = ttnn::register_operation<
"ttnn::relu_bw",
