Skip to content

Commit

Permalink
[xla:cpu] Add extern templates for Conv2D and Conv3D.
Browse files Browse the repository at this point in the history
These templates were instantiated twice (once for current runtime, once for thunks runtime). Now they are instantiated once. It reduces binary size and compilation time.

PiperOrigin-RevId: 645720154
  • Loading branch information
Adam-Banas authored and copybara-github committed Jun 22, 2024
1 parent 82b0e6e commit 6f3d842
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 0 deletions.
4 changes: 4 additions & 0 deletions xla/service/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ filegroup(
srcs = [
# Single-threaded support.
"runtime_custom_call_status.cc",
"runtime_conv_impl.cc",
"runtime_fp16.cc",
"runtime_key_value_sort.cc",
"runtime_pow.cc",
Expand Down Expand Up @@ -1012,12 +1013,15 @@ cc_library(

cc_library(
name = "runtime_conv_impl",
srcs = ["runtime_conv_impl.cc"],
hdrs = ["runtime_conv_impl.h"],
copts = runtime_copts(),
visibility = internal_visibility([":friends"]),
deps = [
"//xla/tsl/framework/contraction:eigen_contraction_kernel",
"//xla/tsl/framework/convolution:eigen_helpers",
"@eigen_archive//:eigen3",
"@tsl//tsl/platform:mutex", # build_cleaner: keep
],
)

Expand Down
66 changes: 66 additions & 0 deletions xla/service/cpu/runtime_conv_impl.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS

#include "xla/service/cpu/runtime_conv_impl.h"

namespace tensorflow::xla {

// Instantiate Conv2D template for all supported devices and data types.
#define CONV2D_INSTANTIATE_TEMPLATE(EigenDevice, ScalarType) \
template void EigenConv2DImpl<EigenDevice, ScalarType>( \
const EigenDevice& device, ScalarType* out, ScalarType* lhs, \
ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, \
Eigen::Index input_y, Eigen::Index input_channels, \
Eigen::Index kernel_x, Eigen::Index kernel_y, \
Eigen::Index kernel_channels, Eigen::Index kernel_filters, \
Eigen::Index output_x, Eigen::Index output_y, Eigen::Index x_stride, \
Eigen::Index y_stride, Eigen::Index padding_x_before, \
Eigen::Index padding_x_after, Eigen::Index padding_y_before, \
Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, \
Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, \
Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count)

CONV2D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, Eigen::half);
CONV2D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, float);
CONV2D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, Eigen::half);
CONV2D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, float);

#undef CONV2D_INSTANTIATE_TEMPLATE

// Instantiate Conv3D template for all supported devices and data types.
#define CONV3D_INSTANTIATE_TEMPLATE(EigenDevice, ScalarType) \
template void EigenConv3DImpl<EigenDevice, ScalarType>( \
const EigenDevice& device, ScalarType* out, ScalarType* lhs, \
ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, \
Eigen::Index input_y, Eigen::Index input_z, Eigen::Index input_channels, \
Eigen::Index kernel_x, Eigen::Index kernel_y, Eigen::Index kernel_z, \
Eigen::Index kernel_channels, Eigen::Index kernel_filters, \
Eigen::Index output_x, Eigen::Index output_y, Eigen::Index output_z, \
Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index z_stride, \
Eigen::Index padding_x_before, Eigen::Index padding_x_after, \
Eigen::Index padding_y_before, Eigen::Index padding_y_after, \
Eigen::Index padding_z_before, Eigen::Index padding_z_after, \
Eigen::Index lhs_x_dilation, Eigen::Index lhs_y_dilation, \
Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation, \
Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation, \
Eigen::Index feature_group_count)

CONV3D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, Eigen::half);
CONV3D_INSTANTIATE_TEMPLATE(Eigen::DefaultDevice, float);
CONV3D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, Eigen::half);
CONV3D_INSTANTIATE_TEMPLATE(Eigen::ThreadPoolDevice, float);

} // namespace tensorflow::xla
50 changes: 50 additions & 0 deletions xla/service/cpu/runtime_conv_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,56 @@ void EigenConv3DImpl(
}
}

// Extern Conv2D template for all supported devices and data types.
// TODO(abanas): These templates are instantiated in convolution_thunk.cc. Move
// the definitions from this file to convolution thunk, and make all runtime
// conv targets depend on it.
#define CONV2D_EXTERN_TEMPLATE(EigenDevice, ScalarType) \
extern template void EigenConv2DImpl<EigenDevice, ScalarType>( \
const EigenDevice& device, ScalarType* out, ScalarType* lhs, \
ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, \
Eigen::Index input_y, Eigen::Index input_channels, \
Eigen::Index kernel_x, Eigen::Index kernel_y, \
Eigen::Index kernel_channels, Eigen::Index kernel_filters, \
Eigen::Index output_x, Eigen::Index output_y, Eigen::Index x_stride, \
Eigen::Index y_stride, Eigen::Index padding_x_before, \
Eigen::Index padding_x_after, Eigen::Index padding_y_before, \
Eigen::Index padding_y_after, Eigen::Index lhs_x_dilation, \
Eigen::Index lhs_y_dilation, Eigen::Index rhs_x_dilation, \
Eigen::Index rhs_y_dilation, Eigen::Index feature_group_count)

CONV2D_EXTERN_TEMPLATE(Eigen::DefaultDevice, Eigen::half);
CONV2D_EXTERN_TEMPLATE(Eigen::DefaultDevice, float);
CONV2D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, Eigen::half);
CONV2D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, float);

#undef CONV2D_EXTERN_TEMPLATE

// Extern Conv3D template for all supported devices and data types.
#define CONV3D_EXTERN_TEMPLATE(EigenDevice, ScalarType) \
extern template void EigenConv3DImpl<EigenDevice, ScalarType>( \
const EigenDevice& device, ScalarType* out, ScalarType* lhs, \
ScalarType* rhs, Eigen::Index input_batch, Eigen::Index input_x, \
Eigen::Index input_y, Eigen::Index input_z, Eigen::Index input_channels, \
Eigen::Index kernel_x, Eigen::Index kernel_y, Eigen::Index kernel_z, \
Eigen::Index kernel_channels, Eigen::Index kernel_filters, \
Eigen::Index output_x, Eigen::Index output_y, Eigen::Index output_z, \
Eigen::Index x_stride, Eigen::Index y_stride, Eigen::Index z_stride, \
Eigen::Index padding_x_before, Eigen::Index padding_x_after, \
Eigen::Index padding_y_before, Eigen::Index padding_y_after, \
Eigen::Index padding_z_before, Eigen::Index padding_z_after, \
Eigen::Index lhs_x_dilation, Eigen::Index lhs_y_dilation, \
Eigen::Index lhs_z_dilation, Eigen::Index rhs_x_dilation, \
Eigen::Index rhs_y_dilation, Eigen::Index rhs_z_dilation, \
Eigen::Index feature_group_count)

CONV3D_EXTERN_TEMPLATE(Eigen::DefaultDevice, Eigen::half);
CONV3D_EXTERN_TEMPLATE(Eigen::DefaultDevice, float);
CONV3D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, Eigen::half);
CONV3D_EXTERN_TEMPLATE(Eigen::ThreadPoolDevice, float);

#undef CONV3D_EXTERN_TEMPLATE

} // namespace xla
} // namespace tensorflow

Expand Down

0 comments on commit 6f3d842

Please sign in to comment.