Commit

Add a small int64 workaround for elementwise ops that don't support it (#391)

It turns out that DIVIDE, MODULUS_FLOOR, DIFFERENCE_SQUARE, MODULUS_TRUNCATE and POW don't yet support emulated int64 in DirectML, so we can't entirely get rid of int64 workarounds.
PatriceVignola authored Sep 6, 2022

1 parent 8e565a1 commit a4a0e27
Showing 2 changed files with 83 additions and 16 deletions.
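The new workarounds in dml_cwise_ops.cc below all follow the same pattern: narrow the int64 operands to int32, evaluate the operator through DirectML's natively supported int32 path, and cast the result back to int64. A minimal sketch of that pattern, assuming the DirectMLX expression API (DirectMLX.h) and already-built int64 graph expressions x and y; the helper name is illustrative only and does not appear in this commit:

#include "DirectMLX.h"  // dml::Expression, dml::Cast, dml::Pow

// Sketch of the cast-through-int32 workaround applied below to FloorMod,
// Pow, TruncateDiv and Div. Values outside the int32 range do not
// round-trip exactly, so this remains a stopgap until a native int64
// alternative exists (see the TODOs in the diff).
dml::Expression Int64PowViaInt32(dml::Expression x, dml::Expression y) {
  auto x32 = dml::Cast(x, DML_TENSOR_DATA_TYPE_INT32);
  auto y32 = dml::Cast(y, DML_TENSOR_DATA_TYPE_INT32);
  return dml::Cast(dml::Pow(x32, y32), DML_TENSOR_DATA_TYPE_INT64);
}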
tensorflow/core/common_runtime/dml/dml_common.h (3 additions, 0 deletions)
@@ -131,6 +131,9 @@ static constexpr uint32_t kNchwSpatialDimensionCount = 2;
 static constexpr uint32_t kNcdhwDimensionCount = 5;
 static constexpr uint32_t kNcdhwSpatialDimensionCount = 3;
 
+// 8 dimensions are supported for elementwise operators
+static constexpr uint32_t kBinaryCwiseOpMaxDimCount = 8;
+
 // The batch and channel dimensions of NCW, NCHW, NCDHW....
 static constexpr uint32_t kNonspatialDimensionCount = 2;

tensorflow/core/kernels/dml_cwise_ops.cc (80 additions, 16 deletions)
@@ -450,8 +450,16 @@ REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(RealDiv, x / y, 8, true, Eigen::half,
 // cwise_op_floor_div.cc).
 REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(FloorDiv, dml::Floor(x / y), 8, true,
                                        Eigen::half, float)
-REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(FloorMod, dml::ModulusFloor(x, y), 8,
-                                       true, Eigen::half, float, int64)
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(FloorMod, dml::ModulusFloor(x, y), 8,
+                                       true, Eigen::half, float)
+// TODO: Revisit this and consider having a native int64 alternative
+// TFDML #41163316
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_1(
+    FloorMod,
+    dml::Cast(dml::ModulusFloor(dml::Cast(x, DML_TENSOR_DATA_TYPE_INT32),
+                                dml::Cast(y, DML_TENSOR_DATA_TYPE_INT32)),
+              DML_TENSOR_DATA_TYPE_INT64),
+    8, false, int64)
 REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(SigmoidGrad, (y * x * (1 - x)), 8, false,
                                        Eigen::half, float)
 REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(TanhGrad, (y * (1 - x * x)), 8, false,
@@ -481,25 +489,53 @@ REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(Minimum, dml::Min(x, y), 8, true,
 // cwise_op_maximum.cc).
 REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(Maximum, dml::Max(x, y), 8, true,
                                        Eigen::half, float, int64)
-REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(SquaredDifference,
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(SquaredDifference,
                                        dml::DifferenceSquare(x, y), 8, true,
-                                       Eigen::half, float, int64)
+                                       Eigen::half, float)
+// TODO: Revisit this and consider having a native int64 alternative
+// TFDML #41163316
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_1(SquaredDifference, (x - y) * (x - y), 8,
+                                       false, int64)
 // TODO(b/25387198): A special kernel exists for int32 (see cwise_op_mul1.cc).
 REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(Mul, (x * y), 8, true, Eigen::half,
                                        float, int64)
-REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(Pow, dml::Pow(x, y), 8, true,
-                                       Eigen::half, float, int64)
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_2(Pow, dml::Pow(x, y), 8, true,
+                                       Eigen::half, float)
+// TODO: Revisit this and consider having a native int64 alternative
+// TFDML #41163316
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_1(
+    Pow,
+    dml::Cast(dml::Pow(dml::Cast(x, DML_TENSOR_DATA_TYPE_INT32),
+                       dml::Cast(y, DML_TENSOR_DATA_TYPE_INT32)),
+              DML_TENSOR_DATA_TYPE_INT64),
+    8, false, int64)
 // TODO(b/25387198): A special kernel exists for int32 (see cwise_op_add1.cc).
 REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(Add, x + y, 8, true, Eigen::half, float,
                                        int64)
 // TODO(b/25387198): A special kernel exists for int32 (see cwise_op_add1.cc).
 REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(AddV2, x + y, 8, true, Eigen::half,
                                        float, int64)
-REGISTER_DML_COMPOSITE_BINARY_KERNEL_4(TruncateDiv, x / y, 8, true, uint8,
-                                       uint16, int16, int64)
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_3(TruncateDiv, x / y, 8, true, uint8,
+                                       uint16, int16)
+// TODO: Revisit this and consider having a native int64 alternative
+// TFDML #41163316
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_1(
+    TruncateDiv,
+    dml::Cast(dml::Cast(x, DML_TENSOR_DATA_TYPE_INT32) /
+                  dml::Cast(y, DML_TENSOR_DATA_TYPE_INT32),
+              DML_TENSOR_DATA_TYPE_INT64),
+    8, false, int64)
 // TODO(b/25387198): A special kernel exists for int32 (see cwise_op_div.cc).
-REGISTER_DML_COMPOSITE_BINARY_KERNEL_6(Div, x / y, 8, true, Eigen::half, float,
-                                       uint8, uint16, int16, int64)
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_5(Div, x / y, 8, true, Eigen::half, float,
+                                       uint8, uint16, int16)
+// TODO: Revisit this and consider having a native int64 alternative
+// TFDML #41163316
+REGISTER_DML_COMPOSITE_BINARY_KERNEL_1(
+    Div,
+    dml::Cast(dml::Cast(x, DML_TENSOR_DATA_TYPE_INT32) /
+                  dml::Cast(y, DML_TENSOR_DATA_TYPE_INT32),
+              DML_TENSOR_DATA_TYPE_INT64),
+    8, false, int64)
 // TODO(b/25387198): A special kernel exists for int32 (see
 // cwise_op_greater.cc).
 REGISTER_DML_COMPOSITE_BINARY_KERNEL_6(Greater, x > y, 8, false, Eigen::half,
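Note that SquaredDifference in the hunk above takes a different route: rather than casting through int32, its int64 registration decomposes DIFFERENCE_SQUARE into subtract and multiply, which are not among the operators the commit message lists as lacking emulated int64. A sketch of that decomposition, again assuming DirectMLX expressions and an illustrative helper name:

// Equivalent to dml::DifferenceSquare(x, y), but built only from subtract
// and multiply so the existing emulated int64 paths can be used.
dml::Expression Int64SquaredDifference(dml::Expression x, dml::Expression y) {
  auto diff = x - y;
  return diff * diff;
}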
@@ -869,10 +905,40 @@ class DmlLeakyReluKernel : public DmlKernel {
 TF_CALL_DML_FLOAT_TYPES(DML_REGISTER_KERNEL);
 #undef DML_REGISTER_KERNEL
 
+class ApproximateEqualInitHelper
+    : public ElementWiseInitHelper<kBinaryCwiseOpMaxDimCount> {
+ public:
+  struct Attributes
+      : public ElementWiseInitHelper<kBinaryCwiseOpMaxDimCount>::Attributes {
+    explicit Attributes(OpKernelConstruction* ctx)
+        : ElementWiseInitHelper<kBinaryCwiseOpMaxDimCount>::Attributes(ctx) {
+      OP_REQUIRES_OK(ctx, ctx->GetAttr("tolerance", &tolerance));
+    }
+    float tolerance;
+  };
+  ApproximateEqualInitHelper(OpKernelContext* ctx,
+                             std::shared_ptr<const Attributes> attr)
+      : ElementWiseInitHelper<kBinaryCwiseOpMaxDimCount>(ctx, attr),
+        tolerance_(attr->tolerance) {
+    const Tensor& x_input = ctx->input(0);
+    const Tensor& y_input = ctx->input(1);
+    OP_REQUIRES(
+        ctx, x_input.shape() == y_input.shape(),
+        errors::InvalidArgument("x and y must be of the same shape. ",
+                                "x shape: ", x_input.shape().DebugString(),
+                                ". y shape: ", y_input.shape().DebugString()));
+  }
+
+  float GetTolerance() const { return tolerance_; }
+
+ private:
+  float tolerance_;
+};
+
 template <typename T>
 class DmlApproximateEqualKernel : public DmlKernel {
  public:
-  using InitHelper = ElementWiseInitHelper<kNchwDimensionCount>;
+  using InitHelper = ApproximateEqualInitHelper;
 
   explicit DmlApproximateEqualKernel(DmlKernelConstruction* ctx,
                                      const InitHelper* init_helper) {
@@ -891,11 +957,9 @@ class DmlApproximateEqualKernel : public DmlKernel {
     auto x = dml::InputTensor(scope, 0, inputs[0]);
     auto y = dml::InputTensor(scope, 1, inputs[1]);
 
-    float tolerance;
-    TF_CHECK_OK(ctx->GetAttr("tolerance", &tolerance));
-    auto tolerance_tensor =
-        dml::ScalarTensor<T>(scope, TfTensorTypeTraits<T>::FromFloat(tolerance),
-                             x.GetOutputDesc().sizes);
+    auto tolerance_tensor = dml::ScalarTensor<T>(
+        scope, TfTensorTypeTraits<T>::FromFloat(init_helper->GetTolerance()),
+        x.GetOutputDesc().sizes);
 
     auto result = dml::Abs(x - y) < tolerance_tensor;

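For reference, the ApproximateEqual kernel above still computes the same element-wise predicate, |x - y| < tolerance; the change only moves the tolerance attribute read (plus an x/y shape check) into ApproximateEqualInitHelper, so failures surface through OP_REQUIRES instead of TF_CHECK_OK inside the DML kernel. A scalar reference sketch of the predicate, with an illustrative name that does not appear in the code:

#include <cmath>

// Per-element meaning of ApproximateEqual; the DML kernel expresses the
// same comparison over whole tensors as dml::Abs(x - y) < tolerance_tensor.
bool ApproximateEqualRef(float x, float y, float tolerance) {
  return std::fabs(x - y) < tolerance;
}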