Skip to content

Commit

Permalink
[XLA] Add an autotune result wrapper that can be used for externally …
Browse files Browse the repository at this point in the history
…caching autotune results.

PiperOrigin-RevId: 646586644
  • Loading branch information
dimitar-asenov authored and copybara-github committed Jun 25, 2024
1 parent 223bb06 commit 93428dc
Show file tree
Hide file tree
Showing 4 changed files with 260 additions and 0 deletions.
26 changes: 26 additions & 0 deletions xla/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1215,6 +1215,32 @@ tf_proto_library(
]),
)

cc_library(
name = "autotune_result_wrapper",
srcs = ["autotune_result_wrapper.cc"],
hdrs = ["autotune_result_wrapper.h"],
visibility = ["//visibility:public"],
deps = [
":autotune_results_proto_cc_impl",
":autotuning_proto_cc_impl",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@tsl//tsl/lib/strings:proto_serialization",
],
)

cc_test(
name = "autotune_result_wrapper_test",
srcs = ["autotune_result_wrapper_test.cc"],
deps = [
":autotune_result_wrapper",
":autotuning_proto_cc_impl",
"//testing/base/public:gunit_main",
],
)

cc_library(
name = "printer",
srcs = ["printer.cc"],
Expand Down
88 changes: 88 additions & 0 deletions xla/autotune_result_wrapper.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/autotune_result_wrapper.h"

#include <string>
#include <vector>

#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "xla/autotuning.pb.h"
#include "tsl/lib/strings/proto_serialization.h"

namespace xla {

/*static*/ absl::StatusOr<AutotuneResultWrapper>
AutotuneResultWrapper::FromKeyAndValue(absl::string_view key,
absl::string_view value) {
AutotuneResults::Entry key_entry;
if (!key_entry.ParseFromString(key)) {
return absl::Status(absl::StatusCode::kInvalidArgument,
"Could not parse the provided key");
}

AutotuneResults::Entry value_entry;
if (!value_entry.ParseFromString(value)) {
return absl::Status(absl::StatusCode::kInvalidArgument,
"Could not parse the provided value");
}

AutotuneResults::Entry full_entry;
full_entry.set_device(key_entry.device());
full_entry.set_hlo(key_entry.hlo());
*full_entry.mutable_result() = value_entry.result();
return AutotuneResultWrapper(full_entry);
}

std::string AutotuneResultWrapper::Key() const {
AutotuneResults::Entry entry;
entry.set_device(autotune_result_.device());
entry.set_hlo(autotune_result_.hlo());
std::string serialized;
CHECK(tsl::SerializeToStringDeterministic(entry, &serialized));
return serialized;
}

std::string AutotuneResultWrapper::Value() const {
AutotuneResults::Entry entry;
*entry.mutable_result() = autotune_result_.result();
std::string serialized;
CHECK(tsl::SerializeToStringDeterministic(entry, &serialized));
return serialized;
}

/*static*/ std::vector<AutotuneResultWrapper>
AutotuneResultWrapper::AutotuneResultsToWrappers(
const AutotuneResults& autotune_results) {
std::vector<AutotuneResultWrapper> wrappers;
wrappers.reserve(autotune_results.results_size());
for (const auto& result : autotune_results.results()) {
wrappers.push_back(AutotuneResultWrapper(result));
}
return wrappers;
}

/*static*/ AutotuneResults AutotuneResultWrapper::AutotuneResultsFromWrappers(
const std::vector<AutotuneResultWrapper>& wrappers) {
AutotuneResults autotune_results;
for (const auto& wrapper : wrappers) {
*autotune_results.add_results() = wrapper.Entry();
}
return autotune_results;
}

} // namespace xla
65 changes: 65 additions & 0 deletions xla/autotune_result_wrapper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_AUTOTUNE_RESULT_WRAPPER_H_
#define XLA_AUTOTUNE_RESULT_WRAPPER_H_

#include <string>
#include <vector>

#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/autotune_results.pb.h"

namespace xla {

// This class is a thin wrapper around AutotuneResults::Entry. It is used to
// provide opaque accessors to an entry's key and value without exposing the
// internal structure of the entry.
class AutotuneResultWrapper {
public:
explicit AutotuneResultWrapper(const AutotuneResults::Entry& result)
: autotune_result_(result) {}

// Creates an AutotuneResultWrapper from a key and value. The provided key and
// value must be ones that were previously returned by calls to Key() and
// Value().
static absl::StatusOr<AutotuneResultWrapper> FromKeyAndValue(
absl::string_view key, absl::string_view value);

// An opaque string that can be used as a key for this Autotuning result.
// Do not rely on the format of this string.
std::string Key() const;

// An opaque string that encodes the autotuning result.
// Do not rely on the format of this string.
std::string Value() const;

// The AutotuneResults::Entry proto that corresponds to this wrapper.
const AutotuneResults::Entry& Entry() const { return autotune_result_; };

static std::vector<AutotuneResultWrapper> AutotuneResultsToWrappers(
const AutotuneResults& autotune_results);

static AutotuneResults AutotuneResultsFromWrappers(
const std::vector<AutotuneResultWrapper>& wrappers);

private:
AutotuneResults::Entry autotune_result_;
};

} // namespace xla

#endif // XLA_AUTOTUNE_RESULT_WRAPPER_H_
81 changes: 81 additions & 0 deletions xla/autotune_result_wrapper_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/autotune_result_wrapper.h"

#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "xla/autotuning.pb.h"

namespace xla {
namespace {

AutotuneResults ThreeAutotuneEntries() {
AutotuneResults results;
auto r1 = results.add_results();
r1->set_device("dev1");
r1->set_hlo("hlo1");
r1->mutable_result()->set_scratch_bytes(1);

auto r2 = results.add_results();
r2->set_device("dev2");
r2->set_hlo("hlo2");
r2->mutable_result()->set_scratch_bytes(2);

auto r3 = results.add_results();
r3->set_device("dev3");
r3->set_hlo("hlo3");
r3->mutable_result()->set_scratch_bytes(3);

return results;
}

TEST(AutotuneResultWrapperTest, FullRoundTrip) {
AutotuneResults results = ThreeAutotuneEntries();
std::vector<AutotuneResultWrapper> wrappers =
AutotuneResultWrapper::AutotuneResultsToWrappers(results);

std::vector<std::pair<std::string, std::string>> key_value_pairs;
for (const auto& wrapper : wrappers) {
key_value_pairs.push_back(std::make_pair(wrapper.Key(), wrapper.Value()));
}

std::vector<AutotuneResultWrapper> new_wrappers;
for (const auto& [key, value] : key_value_pairs) {
ASSERT_OK_AND_ASSIGN(AutotuneResultWrapper wrapper,
AutotuneResultWrapper::FromKeyAndValue(key, value));
new_wrappers.push_back(std::move(wrapper));
}

AutotuneResults round_tripped =
AutotuneResultWrapper::AutotuneResultsFromWrappers(new_wrappers);
EXPECT_EQ(round_tripped.results_size(), 3);
EXPECT_EQ(round_tripped.results(0).device(), "dev1");
EXPECT_EQ(round_tripped.results(0).hlo(), "hlo1");
EXPECT_EQ(round_tripped.results(0).result().scratch_bytes(), 1);
EXPECT_EQ(round_tripped.results(1).device(), "dev2");
EXPECT_EQ(round_tripped.results(1).hlo(), "hlo2");
EXPECT_EQ(round_tripped.results(1).result().scratch_bytes(), 2);
EXPECT_EQ(round_tripped.results(2).device(), "dev3");
EXPECT_EQ(round_tripped.results(2).hlo(), "hlo3");
EXPECT_EQ(round_tripped.results(2).result().scratch_bytes(), 3);
}

} // namespace
} // namespace xla

0 comments on commit 93428dc

Please sign in to comment.