Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenXLA|TSL] Replace tsl::Status to absl::Status in profiler.h/.cc #6655

Merged
merged 26 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions torch_xla/csrc/convolution_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ std::string ToString(TensorFormat format) {

// Performs some basic checks on ConvOpAttrs that are true for all kinds of
// XLA convolutions (as currently implemented).
tsl::Status CheckConvAttrs(const ConvOpAttrs& attrs) {
xla::Status CheckConvAttrs(const ConvOpAttrs& attrs) {
const int num_dims = attrs.num_spatial_dims + 2;
const int attrs_strides_size = attrs.strides.size();
if (attrs_strides_size != num_dims) {
Expand Down Expand Up @@ -94,7 +94,7 @@ xla::Shape GroupedFilterShapeForDepthwiseConvolution(
// This part of helpers are origionally from
// https://github.com/tensorflow/tensorflow/blob/7f39a389d5b82d6aca13240c21f2647c3ebdb765/tensorflow/core/framework/kernel_shape_util.cc

tsl::Status GetWindowedOutputSizeVerboseV2(
xla::Status GetWindowedOutputSizeVerboseV2(
int64_t input_size, int64_t filter_size, int64_t dilation_rate,
int64_t stride, Padding padding_type, int64_t* output_size,
int64_t* padding_before, int64_t* padding_after) {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,7 @@ void BuildProfilerSubmodule(py::module* m) {
absl::flat_hash_map<std::string, std::variant<int, std::string>> opts =
ConvertDictToMap(options);
std::chrono::seconds sleep_s(interval_s);
tsl::Status status;
xla::Status status;
{
NoGilSection nogil;
for (int i = 0; i <= timeout_s / interval_s; i++) {
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ cc_library(
"@xla//xla/backends/profiler/plugin:profiler_c_api_hdrs",
"@xla//xla/backends/profiler/plugin:plugin_tracer",
"@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",
"@tsl//tsl/platform:status",
"@xla//xla:status",
"@tsl//tsl/profiler/lib:profiler_factory",
"@tsl//tsl/profiler/rpc:profiler_server_impl",
"@tsl//tsl/profiler/rpc/client:capture_profile",
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/ifrt_computation_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
namespace torch_xla {
namespace runtime {

tsl::StatusOr<xla::XlaComputation> MakeComputation() {
xla::StatusOr<xla::XlaComputation> MakeComputation() {
xla::Shape input_shape =
xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 2});
xla::XlaBuilder builder("AddComputation");
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(
metrics::TimedSection timed(TransferFromDeviceMetric());
tsl::profiler::TraceMe activity("PjRtComputationClient::TransferFromDevice",
tsl::profiler::TraceMeLevel::kInfo);
std::vector<xla::PjRtFuture<tsl::Status>> futures;
std::vector<xla::PjRtFuture<absl::Status>> futures;
futures.reserve(handles.size());
std::vector<xla::Literal> literals;
literals.reserve(handles.size());
Expand All @@ -403,7 +403,7 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(
total_size += literal.size_bytes();
}
for (auto& future : futures) {
tsl::Status status = future.Await();
absl::Status status = future.Await();
XLA_CHECK_OK(status);
}
InboundDataMetric()->AddSample(total_size);
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
namespace torch_xla {
namespace runtime {

tsl::StatusOr<xla::XlaComputation> MakeComputation() {
xla::StatusOr<xla::XlaComputation> MakeComputation() {
xla::Shape input_shape =
xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 2});
xla::XlaBuilder builder("AddComputation");
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/pjrt_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ InitializePjRt(const std::string& device_type) {
env::kEnvTpuLibraryPath,
sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so"));
XLA_CHECK_OK(pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status());
tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu");
xla::Status tpu_status = pjrt::InitializePjrtPlugin("tpu");
XLA_CHECK_OK(tpu_status);
client = std::move(xla::GetCApiClient("TPU").value());
const PJRT_Api* c_api =
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/runtime/profiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

#include "absl/container/flat_hash_map.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "tsl/platform/status.h"
#include "tsl/profiler/lib/profiler_factory.h"
#include "tsl/profiler/rpc/client/capture_profile.h"
#include "tsl/profiler/rpc/profiler_server.h"
#include "xla/backends/profiler/plugin/plugin_tracer.h"
#include "xla/backends/profiler/plugin/profiler_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_profiler_extension.h"
#include "xla/status.h"

namespace torch_xla {
namespace runtime {
Expand Down Expand Up @@ -45,7 +45,7 @@ void ProfilerServer::Start(int port) {

ProfilerServer::~ProfilerServer() {}

tsl::Status Trace(
xla::Status Trace(
const char* service_addr, const char* logdir, int duration_ms,
int num_tracing_attempts,
const absl::flat_hash_map<std::string, std::variant<int, std::string>>&
Expand Down
4 changes: 2 additions & 2 deletions torch_xla/csrc/runtime/profiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
#include <memory>

#include "absl/container/flat_hash_map.h"
#include "tsl/platform/status.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/status.h"

namespace torch_xla {
namespace runtime {
Expand All @@ -23,7 +23,7 @@ class ProfilerServer {
std::unique_ptr<Impl> impl_;
};

tsl::Status Trace(
xla::Status Trace(
const char* service_addr, const char* logdir, int duration_ms,
int num_tracing_attempts,
const absl::flat_hash_map<std::string, std::variant<int, std::string>>&
Expand Down
1 change: 0 additions & 1 deletion torch_xla/csrc/runtime/tf_logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <sstream>

#include "tsl/platform/logging.h"
#include "tsl/platform/status.h"
#include "xla/status.h"

namespace torch_xla {
Expand Down
Loading