diff --git a/xla/service/cpu/runtime/BUILD b/xla/service/cpu/runtime/BUILD index 24c791695e369..d6429b5e6d513 100644 --- a/xla/service/cpu/runtime/BUILD +++ b/xla/service/cpu/runtime/BUILD @@ -771,6 +771,29 @@ xla_cc_test( ], ) +cc_library( + name = "resource_use", + srcs = ["resource_use.cc"], + hdrs = ["resource_use.h"], + deps = [ + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/types:span", + ], +) + +xla_cc_test( + name = "resource_use_test", + srcs = ["resource_use_test.cc"], + deps = [ + ":resource_use", + "//xla/service:buffer_assignment", + "@tsl//tsl/platform:test", + "@tsl//tsl/platform:test_main", + ], +) + cc_library( name = "rng_state_thunk", srcs = ["rng_state_thunk.cc"], diff --git a/xla/service/cpu/runtime/resource_use.cc b/xla/service/cpu/runtime/resource_use.cc new file mode 100644 index 0000000000000..3e5ceabb9ac53 --- /dev/null +++ b/xla/service/cpu/runtime/resource_use.cc @@ -0,0 +1,77 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/cpu/runtime/resource_use.h" + +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/types/span.h" + +namespace xla::cpu { + +std::shared_ptr Resource::Create(Kind kind) { + return absl::WrapUnique(new Resource(kind)); +} + +Resource::Resource(Kind kind) : kind_(kind) {} + +ResourceUse::ResourceUse(std::shared_ptr resource, + ResourceAccess access) + : resource_(resource), access_(access) {} + +ResourceUse::ReadWriteSet::ReadWriteSet() = default; + +void ResourceUse::ReadWriteSet::Add(ResourceUse use) { + switch (use.access()) { + case ResourceUse::kRead: + read_.insert(use.resource()); + break; + case ResourceUse::kWrite: + write_.insert(use.resource()); + break; + } +} + +void ResourceUse::ReadWriteSet::AddAll(absl::Span uses) { + for (const auto& use : uses) Add(use); +} + +bool ResourceUse::ReadWriteSet::HasConflicts(const ResourceUse& use) const { + return use.access() == ResourceAccess::kWrite + ? write_.contains(use.resource()) || read_.contains(use.resource()) + : write_.contains(use.resource()); +} + +bool ResourceUse::ReadWriteSet::HasConflicts( + absl::Span uses) const { + return absl::c_any_of( + uses, [&](const ResourceUse& use) { return HasConflicts(use); }); +} + +bool ResourceUse::ReadWriteSet::HasConflicts(const ReadWriteSet& other) { + return absl::c_any_of(other.read_, + [&](const std::shared_ptr& resource) { + return HasConflicts(ResourceUse::Read(resource)); + }) || + absl::c_any_of(other.write_, + [&](const std::shared_ptr& resource) { + return HasConflicts(ResourceUse::Write(resource)); + }); +} + +} // namespace xla::cpu diff --git a/xla/service/cpu/runtime/resource_use.h b/xla/service/cpu/runtime/resource_use.h new file mode 100644 index 0000000000000..6ee1f1bfd6ac9 --- /dev/null +++ b/xla/service/cpu/runtime/resource_use.h @@ -0,0 +1,114 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_RUNTIME_RESOURCE_USE_H_ +#define XLA_SERVICE_CPU_RUNTIME_RESOURCE_USE_H_ + +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/types/span.h" + +namespace xla::cpu { + +// `Resource` models a run time resource that imposes ordering on the thunk +// execution in addition to thunk buffer uses. +class Resource { + public: + enum class Kind { + // Side-effecting operations (i.e., infeed and outfeed) define their + // execution order via token dependencies. We rely on token resource to + // enforce ordering at run time. + kToken, + + // Collective operations must be executed in the same order as they are + // defined in the HLO module. We rely on collective communicator resource + // to enforce ordering at run time. + kCollectiveCommunicator + }; + + static constexpr Kind kToken = Kind::kToken; + static constexpr Kind kCollectiveCommunicator = Kind::kCollectiveCommunicator; + + static std::shared_ptr Create(Kind kind); + + Kind kind() const { return kind_; } + + private: + explicit Resource(Kind kind); + Kind kind_; +}; + +// For consistency with BufferUse, we model resource uses as writes or reads +// to and from resource. Resources have referential equality: we rely on +// comparing pointers to check if resource is the same or not. +class ResourceUse { + public: + enum class ResourceAccess { kRead, kWrite }; + + static constexpr ResourceAccess kRead = ResourceAccess::kRead; + static constexpr ResourceAccess kWrite = ResourceAccess::kWrite; + + static ResourceUse Read(std::shared_ptr resource) { + return ResourceUse(std::move(resource), ResourceAccess::kRead); + } + + static ResourceUse Write(std::shared_ptr resource) { + return ResourceUse(std::move(resource), ResourceAccess::kWrite); + } + + const std::shared_ptr& resource() const { return resource_; } + ResourceAccess access() const { return access_; } + + // ReadWriteSet tracks a set of read and write resources. + class ReadWriteSet { + public: + ReadWriteSet(); + + void Add(ResourceUse use); + void AddAll(absl::Span uses); + + // Returns true if any of the resource use(s) has a conflict with tracked + // resource reads or writes. + bool HasConflicts(const ResourceUse& use) const; + bool HasConflicts(absl::Span uses) const; + bool HasConflicts(const ReadWriteSet& other); + + private: + absl::flat_hash_set> read_; + absl::flat_hash_set> write_; + }; + + bool operator==(const ResourceUse& other) const { + return resource_ == other.resource_ && access_ == other.access_; + } + + bool operator!=(const ResourceUse& other) const { return !(*this == other); } + + template + friend H AbslHashValue(H h, const ResourceUse& use) { + return H::combine(std::move(h), use.resource_, use.access_); + } + + private: + ResourceUse(std::shared_ptr resource, ResourceAccess access); + std::shared_ptr resource_; + ResourceAccess access_; +}; + +} // namespace xla::cpu + +#endif // XLA_SERVICE_CPU_RUNTIME_RESOURCE_USE_H_ diff --git a/xla/service/cpu/runtime/resource_use_test.cc b/xla/service/cpu/runtime/resource_use_test.cc new file mode 100644 index 0000000000000..4d3c9bbaf4cec --- /dev/null +++ b/xla/service/cpu/runtime/resource_use_test.cc @@ -0,0 +1,53 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/cpu/runtime/resource_use.h" + +#include "tsl/platform/test.h" + +namespace xla::cpu { +namespace { + +TEST(ResourceUseTest, Equality) { + auto token = Resource::Create(Resource::kToken); + auto use0 = ResourceUse::Read(token); + auto use1 = ResourceUse::Write(token); + auto use2 = ResourceUse::Read(token); + + EXPECT_NE(use0, use1); + EXPECT_EQ(use0, use2); +} + +TEST(ResourceUseTest, ReadWriteSet) { + ResourceUse::ReadWriteSet rwset; + + auto token0 = Resource::Create(Resource::kToken); + auto token1 = Resource::Create(Resource::kToken); + + rwset.Add(ResourceUse::Read(token0)); + EXPECT_FALSE(rwset.HasConflicts({ResourceUse::Read(token0)})); + EXPECT_TRUE(rwset.HasConflicts({ResourceUse::Write(token0)})); + EXPECT_FALSE(rwset.HasConflicts({ResourceUse::Read(token1)})); + EXPECT_FALSE(rwset.HasConflicts({ResourceUse::Write(token1)})); + + rwset.Add(ResourceUse::Write(token0)); + EXPECT_TRUE(rwset.HasConflicts({ResourceUse::Read(token0)})); + EXPECT_TRUE(rwset.HasConflicts({ResourceUse::Write(token0)})); + EXPECT_FALSE(rwset.HasConflicts({ResourceUse::Read(token1)})); + EXPECT_FALSE(rwset.HasConflicts({ResourceUse::Write(token1)})); +} + +} // namespace +} // namespace xla::cpu