From 93428dcb65e0a1a28f09b6853670eb4b360e0dc5 Mon Sep 17 00:00:00 2001 From: "Dimitar (Mitko) Asenov" Date: Tue, 25 Jun 2024 13:27:32 -0700 Subject: [PATCH] [XLA] Add an autotune result wrapper that can be used for externally caching autotune results. PiperOrigin-RevId: 646586644 --- xla/BUILD | 26 +++++++++ xla/autotune_result_wrapper.cc | 88 +++++++++++++++++++++++++++++ xla/autotune_result_wrapper.h | 65 +++++++++++++++++++++ xla/autotune_result_wrapper_test.cc | 81 ++++++++++++++++++++++++++ 4 files changed, 260 insertions(+) create mode 100644 xla/autotune_result_wrapper.cc create mode 100644 xla/autotune_result_wrapper.h create mode 100644 xla/autotune_result_wrapper_test.cc diff --git a/xla/BUILD b/xla/BUILD index ff2a86f8004182..ee846e149bcf20 100644 --- a/xla/BUILD +++ b/xla/BUILD @@ -1215,6 +1215,32 @@ tf_proto_library( ]), ) +cc_library( + name = "autotune_result_wrapper", + srcs = ["autotune_result_wrapper.cc"], + hdrs = ["autotune_result_wrapper.h"], + visibility = ["//visibility:public"], + deps = [ + ":autotune_results_proto_cc_impl", + ":autotuning_proto_cc_impl", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@tsl//tsl/lib/strings:proto_serialization", + ], +) + +cc_test( + name = "autotune_result_wrapper_test", + srcs = ["autotune_result_wrapper_test.cc"], + deps = [ + ":autotune_result_wrapper", + ":autotuning_proto_cc_impl", + "//testing/base/public:gunit_main", + ], +) + cc_library( name = "printer", srcs = ["printer.cc"], diff --git a/xla/autotune_result_wrapper.cc b/xla/autotune_result_wrapper.cc new file mode 100644 index 00000000000000..674671d472e380 --- /dev/null +++ b/xla/autotune_result_wrapper.cc @@ -0,0 +1,88 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/autotune_result_wrapper.h" + +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "xla/autotuning.pb.h" +#include "tsl/lib/strings/proto_serialization.h" + +namespace xla { + +/*static*/ absl::StatusOr +AutotuneResultWrapper::FromKeyAndValue(absl::string_view key, + absl::string_view value) { + AutotuneResults::Entry key_entry; + if (!key_entry.ParseFromString(key)) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Could not parse the provided key"); + } + + AutotuneResults::Entry value_entry; + if (!value_entry.ParseFromString(value)) { + return absl::Status(absl::StatusCode::kInvalidArgument, + "Could not parse the provided value"); + } + + AutotuneResults::Entry full_entry; + full_entry.set_device(key_entry.device()); + full_entry.set_hlo(key_entry.hlo()); + *full_entry.mutable_result() = value_entry.result(); + return AutotuneResultWrapper(full_entry); +} + +std::string AutotuneResultWrapper::Key() const { + AutotuneResults::Entry entry; + entry.set_device(autotune_result_.device()); + entry.set_hlo(autotune_result_.hlo()); + std::string serialized; + CHECK(tsl::SerializeToStringDeterministic(entry, &serialized)); + return serialized; +} + +std::string AutotuneResultWrapper::Value() const { + AutotuneResults::Entry entry; + *entry.mutable_result() = autotune_result_.result(); + std::string serialized; + CHECK(tsl::SerializeToStringDeterministic(entry, &serialized)); + return serialized; +} + +/*static*/ std::vector +AutotuneResultWrapper::AutotuneResultsToWrappers( + const AutotuneResults& autotune_results) { + std::vector wrappers; + wrappers.reserve(autotune_results.results_size()); + for (const auto& result : autotune_results.results()) { + wrappers.push_back(AutotuneResultWrapper(result)); + } + return wrappers; +} + +/*static*/ AutotuneResults AutotuneResultWrapper::AutotuneResultsFromWrappers( + const std::vector& wrappers) { + AutotuneResults autotune_results; + for (const auto& wrapper : wrappers) { + *autotune_results.add_results() = wrapper.Entry(); + } + return autotune_results; +} + +} // namespace xla diff --git a/xla/autotune_result_wrapper.h b/xla/autotune_result_wrapper.h new file mode 100644 index 00000000000000..b4614eb2fe3307 --- /dev/null +++ b/xla/autotune_result_wrapper.h @@ -0,0 +1,65 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_AUTOTUNE_RESULT_WRAPPER_H_ +#define XLA_AUTOTUNE_RESULT_WRAPPER_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/autotune_results.pb.h" + +namespace xla { + +// This class is a thin wrapper around AutotuneResults::Entry. It is used to +// provide opaque accessors to an entry's key and value without exposing the +// internal structure of the entry. +class AutotuneResultWrapper { + public: + explicit AutotuneResultWrapper(const AutotuneResults::Entry& result) + : autotune_result_(result) {} + + // Creates an AutotuneResultWrapper from a key and value. The provided key and + // value must be ones that were previously returned by calls to Key() and + // Value(). + static absl::StatusOr FromKeyAndValue( + absl::string_view key, absl::string_view value); + + // An opaque string that can be used as a key for this Autotuning result. + // Do not rely on the format of this string. + std::string Key() const; + + // An opaque string that encodes the autotuning result. + // Do not rely on the format of this string. + std::string Value() const; + + // The AutotuneResults::Entry proto that corresponds to this wrapper. + const AutotuneResults::Entry& Entry() const { return autotune_result_; }; + + static std::vector AutotuneResultsToWrappers( + const AutotuneResults& autotune_results); + + static AutotuneResults AutotuneResultsFromWrappers( + const std::vector& wrappers); + + private: + AutotuneResults::Entry autotune_result_; +}; + +} // namespace xla + +#endif // XLA_AUTOTUNE_RESULT_WRAPPER_H_ diff --git a/xla/autotune_result_wrapper_test.cc b/xla/autotune_result_wrapper_test.cc new file mode 100644 index 00000000000000..400121473edbaf --- /dev/null +++ b/xla/autotune_result_wrapper_test.cc @@ -0,0 +1,81 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/autotune_result_wrapper.h" + +#include +#include +#include + +#include +#include +#include "xla/autotuning.pb.h" + +namespace xla { +namespace { + +AutotuneResults ThreeAutotuneEntries() { + AutotuneResults results; + auto r1 = results.add_results(); + r1->set_device("dev1"); + r1->set_hlo("hlo1"); + r1->mutable_result()->set_scratch_bytes(1); + + auto r2 = results.add_results(); + r2->set_device("dev2"); + r2->set_hlo("hlo2"); + r2->mutable_result()->set_scratch_bytes(2); + + auto r3 = results.add_results(); + r3->set_device("dev3"); + r3->set_hlo("hlo3"); + r3->mutable_result()->set_scratch_bytes(3); + + return results; +} + +TEST(AutotuneResultWrapperTest, FullRoundTrip) { + AutotuneResults results = ThreeAutotuneEntries(); + std::vector wrappers = + AutotuneResultWrapper::AutotuneResultsToWrappers(results); + + std::vector> key_value_pairs; + for (const auto& wrapper : wrappers) { + key_value_pairs.push_back(std::make_pair(wrapper.Key(), wrapper.Value())); + } + + std::vector new_wrappers; + for (const auto& [key, value] : key_value_pairs) { + ASSERT_OK_AND_ASSIGN(AutotuneResultWrapper wrapper, + AutotuneResultWrapper::FromKeyAndValue(key, value)); + new_wrappers.push_back(std::move(wrapper)); + } + + AutotuneResults round_tripped = + AutotuneResultWrapper::AutotuneResultsFromWrappers(new_wrappers); + EXPECT_EQ(round_tripped.results_size(), 3); + EXPECT_EQ(round_tripped.results(0).device(), "dev1"); + EXPECT_EQ(round_tripped.results(0).hlo(), "hlo1"); + EXPECT_EQ(round_tripped.results(0).result().scratch_bytes(), 1); + EXPECT_EQ(round_tripped.results(1).device(), "dev2"); + EXPECT_EQ(round_tripped.results(1).hlo(), "hlo2"); + EXPECT_EQ(round_tripped.results(1).result().scratch_bytes(), 2); + EXPECT_EQ(round_tripped.results(2).device(), "dev3"); + EXPECT_EQ(round_tripped.results(2).hlo(), "hlo3"); + EXPECT_EQ(round_tripped.results(2).result().scratch_bytes(), 3); +} + +} // namespace +} // namespace xla