Commit

[XLA:Python] Avoid copying an nb::detail::dict_iterator.

Nanobind 2.2.0 makes dict iterators uncopyable.
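
Below is a minimal, self-contained sketch (standard C++ and Abseil only; HashableEntry, HashableIter, and Key are illustrative names, not the XLA code) of the pattern the change adopts: rather than copying a non-copyable iterator into the hashing adaptor, the adaptor holds a reference to an iterator owned by the caller, and dereferencing it yields a small hashable entry object.

#include <iterator>
#include <map>
#include <string>
#include <utility>

#include "absl/hash/hash.h"

namespace {

// Hashable snapshot of a single key/value pair.
struct HashableEntry {
  std::pair<std::string, int> entry;

  template <typename H>
  friend H AbslHashValue(H h, const HashableEntry& e) {
    return H::combine(std::move(h), e.entry.first, e.entry.second);
  }
};

// Adaptor that refers to (rather than copies) an iterator owned by the
// caller. It implements just enough of the input-iterator interface for
// H::combine_unordered.
class HashableIter {
 public:
  using iterator_category = std::input_iterator_tag;

  explicit HashableIter(std::map<std::string, int>::const_iterator& it)
      : it_(it) {}

  HashableEntry operator*() const { return HashableEntry{*it_}; }
  bool operator!=(const HashableIter& rhs) const { return it_ != rhs.it_; }
  void operator++() { ++it_; }

 private:
  std::map<std::string, int>::const_iterator& it_;
};

struct Key {
  std::map<std::string, int> kwargs;

  template <typename H>
  friend H AbslHashValue(H h, const Key& key) {
    // The iterators live in the caller; the adaptors only reference them,
    // so nothing non-copyable is ever copied.
    auto begin = key.kwargs.begin();
    auto end = key.kwargs.end();
    h = H::combine_unordered(std::move(h), HashableIter(begin),
                             HashableIter(end));
    return H::combine(std::move(h), key.kwargs.size());
  }
};

}  // namespace

// Usage: absl::Hash<Key>{}(some_key) hashes the map contents order-independently.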

In addition, avoid a possible exception-safety problem where Python equality (__eq__, which may raise an exception) was called from an equality test used by an ABSL hash table.
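
A minimal sketch of the hazard (hypothetical Key, ConstantHash, and ThrowingEq types, not the XLA code), with the Python-level comparison modeled as a C++ throw: std::unordered_map guarantees that a single-element insert or emplace whose equality test throws leaves the container unchanged, which is the property the change relies on when it moves the weakref table from absl::node_hash_map to std::unordered_map.

#include <cstddef>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct Key {
  std::string name;
};

// Constant hash forces every key into the same bucket, so the equality
// functor is guaranteed to run on the second insertion below.
struct ConstantHash {
  std::size_t operator()(const Key&) const { return 0; }
};

// Stands in for a Python equality test that may raise an exception.
struct ThrowingEq {
  bool operator()(const Key& a, const Key& b) const {
    if (a.name == "poison" || b.name == "poison") {
      throw std::runtime_error("equality raised");
    }
    return a.name == b.name;
  }
};

int main() {
  std::unordered_map<Key, int, ConstantHash, ThrowingEq> table;
  table.emplace(Key{"ok"}, 1);
  try {
    // The probe compares "poison" against "ok" and throws; per the standard,
    // the insertion then has no effect and the table stays valid.
    table.emplace(Key{"poison"}, 2);
  } catch (const std::runtime_error&) {
  }
  return table.size() == 1 ? 0 : 1;  // exits 0: the table was left unchanged
}

Abseil's hash tables make no comparable promise for throwing hash or equality functors, which is the problem the commit message describes.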

PiperOrigin-RevId: 679295293
hawkinsp authored and Google-ML-Automation committed Sep 26, 2024
1 parent 06bbcd1 commit 95beb0e
Showing 2 changed files with 67 additions and 78 deletions.
2 changes: 0 additions & 2 deletions xla/python/BUILD
@@ -1109,11 +1109,9 @@ cc_library(
     features = ["-use_header_modules"],
     visibility = ["//visibility:private"],
     deps = [
-        ":nb_helpers",
         # placeholder for index annotation deps
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/cleanup",
-        "@com_google_absl//absl/container:node_hash_map",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@nanobind",
143 changes: 67 additions & 76 deletions xla/python/weakref_lru_cache.cc
@@ -19,14 +19,15 @@ limitations under the License.
 #include <cstdint>
 #include <iterator>
 #include <memory>
+#include <optional>
 #include <string>
 #include <thread>  // NOLINT
+#include <unordered_map>
 #include <utility>
 #include <vector>
 
 #include "absl/base/thread_annotations.h"
 #include "absl/cleanup/cleanup.h"
-#include "absl/container/node_hash_map.h"
 #include "absl/strings/str_cat.h"
 #include "absl/synchronization/mutex.h"
 #include "absl/synchronization/notification.h"
@@ -35,7 +36,6 @@ limitations under the License.
 #include "nanobind/stl/string.h"  // IWYU pragma: keep
 #include "nanobind/stl/vector.h"  // IWYU pragma: keep
 #include "xla/pjrt/lru_cache.h"
-#include "xla/python/nb_helpers.h"
 
 namespace nb = nanobind;
 
@@ -44,36 +44,38 @@ namespace {
 
 // Minimal wrapper to expose a nb::dict_iterator's value as something
 // hashable with Abseil.
-class HashablePyDictValue {
- protected:
-  using Iter = nb::detail::dict_iterator;
+class HashablePyDictEntry {
+ public:
+  explicit HashablePyDictEntry(std::pair<nb::handle, nb::handle> entry)
+      : entry_(entry) {}
 
   template <typename H>
-  friend H AbslHashValue(H h, const HashablePyDictValue& value) {
-    auto kv = *value.iter_;
-    return H::combine(std::move(h), nb::hash(kv.first), nb::hash(kv.second));
+  friend H AbslHashValue(H h, const HashablePyDictEntry& v) {
+    return H::combine(std::move(h), nb::hash(v.entry_.first),
+                      nb::hash(v.entry_.second));
   }
 
-  explicit HashablePyDictValue(const Iter& iter) : iter_(iter) {}
-
-  Iter iter_;
+  std::pair<nb::handle, nb::handle> entry_;
 };
 
 // Similarly, a minimalist adaptor around the nb::detail::dict_iterator
 // itself. Note that the iterator "is" also a Value. Does not meet the full
 // standard iterator requirements, only enough to support H::combine_unordered.
-class HashablePyDictIter : protected HashablePyDictValue {
+class HashablePyDictIter {
  public:
   using iterator_category = std::input_iterator_tag;
 
-  explicit HashablePyDictIter(const Iter& iter) : HashablePyDictValue(iter) {}
+  explicit HashablePyDictIter(nb::detail::dict_iterator& iter) : iter_(iter) {}
 
   // Minimal set of iterator operations.
-  const HashablePyDictValue& operator*() const { return *this; }
+  HashablePyDictEntry operator*() const { return HashablePyDictEntry(*iter_); }
   bool operator!=(const HashablePyDictIter& rhs) const {
     return iter_ != rhs.iter_;
   }
   void operator++() { ++iter_; }
+
+ private:
+  nb::detail::dict_iterator& iter_;
 };
 
 }  // namespace
@@ -92,10 +94,15 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
 
   template <typename H>
   friend H AbslHashValue(H h, const Key& key) {
+    // Note: Despite the fact this is an ABSL hash function, it's safe to call
+    // functions that may throw exceptions such as nb::hash(), because it is
+    // used by an LRUCache, which uses a std::unordered_map, which is
+    // exception-safe.
     h = H::combine(std::move(h), nb::hash(key.context), nb::hash(key.args));
-    h = H::combine_unordered(std::move(h),
-                             HashablePyDictIter(key.kwargs.begin()),
-                             HashablePyDictIter(key.kwargs.end()));
+    nb::detail::dict_iterator begin = key.kwargs.begin();
+    nb::detail::dict_iterator end = key.kwargs.end();
+    h = H::combine_unordered(std::move(h), HashablePyDictIter(begin),
+                             HashablePyDictIter(end));
     h = H::combine(std::move(h), key.kwargs.size());
     return h;
   }
@@ -115,82 +122,65 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
     int64_t currsize;
   };
 
-  struct UnboundWeakrefCacheEntry {
+  struct WeakrefCacheKey {
     nb::handle object;
-    WeakrefLRUCache* cache;
     size_t cached_hash;
   };
 
-  struct WeakrefCacheEntry {
-    nb::weakref weakref;
-    size_t cached_hash;
+  using Cache = xla::LRUCache<Key, std::shared_ptr<CacheEntry>>;
+
+  struct WeakrefCacheValue {
+    std::optional<nb::weakref> weakref;
+    std::shared_ptr<Cache> cache;
   };
 
   struct WeakrefKeyHash {
-    using is_transparent = void;
-
-    size_t operator()(const UnboundWeakrefCacheEntry& v) const {
-      return v.cached_hash;
-    }
-    size_t operator()(const WeakrefCacheEntry& v) const {
-      return v.cached_hash;
-    }
+    size_t operator()(const WeakrefCacheKey& v) const { return v.cached_hash; }
   };
 
   struct WeakrefKeyEq {
-    using is_transparent = void;
-    bool operator()(const WeakrefCacheEntry& lhs,
-                    const WeakrefCacheEntry& rhs) const {
-      return lhs.weakref.equal(rhs.weakref);
-    }
-    bool operator()(const WeakrefCacheEntry& lhs,
-                    const UnboundWeakrefCacheEntry& rhs) const {
-      PyObject* obj = PyWeakref_GET_OBJECT(lhs.weakref.ptr());
-      if (obj == Py_None) {
-        return false;
-      }
-      return nb::borrow<nb::object>(obj).equal(rhs.object);
+    bool operator()(const WeakrefCacheKey& lhs,
+                    const WeakrefCacheKey& rhs) const {
+      return lhs.object.equal(rhs.object);
     }
   };
 
-  using Cache = xla::LRUCache<Key, std::shared_ptr<CacheEntry>>;
   WeakrefLRUCache(nb::callable cache_context_fn, nb::callable fn,
                   int64_t maxsize)
      : cache_context_fn_(cache_context_fn), fn_(fn), lru_list_(maxsize) {}
 
-  std::shared_ptr<Cache> GetCache(const UnboundWeakrefCacheEntry& key) {
-    auto it = entries_.find(key);
-    if (it != entries_.end()) {
-      return (it->second);
+  std::shared_ptr<Cache> GetCache(WeakrefCacheKey key) {
+    auto [it, inserted] = entries_.emplace(key, WeakrefCacheValue());
+    if (!inserted) {
+      return it->second.cache;
     }
-    nb::weakref weakref(
-        key.object,
-        nb::cpp_function([this_weak = weak_from_this(),
-                          cached_hash = key.cached_hash](nb::handle weakref) {
-          auto cache = this_weak.lock();
-          if (cache == nullptr) {
-            return;
-          }
-          auto it = cache->entries_.find(
-              WeakrefCacheEntry{nb::borrow<nb::weakref>(weakref), cached_hash});
-          if (it == cache->entries_.end()) {
-            return;
-          }
-          // Create temp-var to avoid re-entrant erase.
-          auto tmp = std::move(it->second);
-          cache->entries_.erase(it);
-        }));
-    return (entries_
-                .emplace(WeakrefCacheEntry{std::move(weakref), key.cached_hash},
-                         std::make_shared<Cache>(&lru_list_))
-                .first->second);
+
+    auto& value = it->second;
+
+    value.cache = std::make_shared<Cache>(&lru_list_);
+    value.weakref =
+        nb::weakref(key.object, nb::cpp_function([this_weak = weak_from_this(),
+                                                  key](nb::handle weakref) {
+                      auto cache = this_weak.lock();
+                      if (cache == nullptr) {
+                        return;
+                      }
+                      auto it = cache->entries_.find(key);
+                      if (it == cache->entries_.end()) {
+                        return;
+                      }
+                      // Create temp-var to avoid re-entrant erase.
+                      auto tmp = std::move(it->second);
+                      cache->entries_.erase(it);
+                    }));
+    return value.cache;
   }
 
   nb::object Call(nb::object weakref_key, nb::args args,
                   nb::kwargs kwargs) ABSL_NO_THREAD_SAFETY_ANALYSIS {
     nb::object context = cache_context_fn_();
-    std::shared_ptr<Cache> cache_ptr = GetCache(UnboundWeakrefCacheEntry{
-        weakref_key, this, static_cast<size_t>(nb::hash(weakref_key))});
+    std::shared_ptr<Cache> cache_ptr = GetCache(WeakrefCacheKey{
        weakref_key, static_cast<size_t>(nb::hash(weakref_key))});
     Cache& cache = *cache_ptr;
     ++total_queries_;
 
@@ -246,10 +236,10 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
   std::vector<nb::object> GetKeys() {
     std::vector<nb::object> results;
     mu_.Lock();
-    for (const auto& wr_key : entries_) {
-      for (const auto& rest : *wr_key.second) {
+    for (const auto& wr_entry : entries_) {
+      for (const auto& rest : *wr_entry.second.cache) {
         nb::tuple result =
-            nb::make_tuple(wr_key.first.weakref, rest.first.context,
+            nb::make_tuple(*wr_entry.second.weakref, rest.first.context,
                            rest.first.args, rest.first.kwargs);
         results.push_back(std::move(result));
       }
@@ -268,8 +258,9 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
   void Clear() {
     total_queries_ = misses_ = 0;
     std::vector<std::shared_ptr<Cache>> deferred_deletes;
+    deferred_deletes.reserve(entries_.size());
     for (auto& entry : entries_) {
-      deferred_deletes.push_back(std::move(entry.second));
+      deferred_deletes.push_back(std::move(entry.second.cache));
     }
     entries_.clear();
     deferred_deletes.clear();
@@ -278,8 +269,8 @@ class WeakrefLRUCache : public std::enable_shared_from_this<WeakrefLRUCache> {
   nb::callable cache_context_fn_;
   nb::callable fn_;
   Cache::LRUList lru_list_;
-  absl::node_hash_map<WeakrefCacheEntry, std::shared_ptr<Cache>, WeakrefKeyHash,
-                      WeakrefKeyEq>
+  std::unordered_map<WeakrefCacheKey, WeakrefCacheValue, WeakrefKeyHash,
+                     WeakrefKeyEq>
       entries_;
   int64_t misses_ = 0;
   int64_t total_queries_ = 0;
