Skip to content

Commit

Permalink
It seems to work, now fix and add tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
joka921 committed Jul 8, 2024
1 parent 05089b7 commit aab72ec
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 6 deletions.
3 changes: 2 additions & 1 deletion src/engine/LocalVocab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,9 @@ auto LocalVocabEntry::lowerBoundInIndex() const -> BoundsInIndex {
result.upperBound_ = vocab.upper_bound(toStringRepresentation());
result.isContained_ =
index.getVocab().getId(toStringRepresentation(), &result.exactMatch_);
indexStatus = result.isContained_ ? IndexStatus::EQUAL : IndexStatus::GREATER;
lowerBoundInIndex_ = result.lowerBound_;
upperBoundInIndex_ = result.upperBound_;
exactMatchInIndex_ = result.exactMatch_;
indexStatus = result.isContained_ ? IndexStatus::EQUAL : IndexStatus::GREATER;
return result;
}
5 changes: 3 additions & 2 deletions src/global/ValueId.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <bit>
#include <cstdint>
#include <functional>
#include <iostream>
#include <limits>

#include "global/IndexTypes.h"
Expand Down Expand Up @@ -169,7 +170,7 @@ class ValueId {
auto lowerBound = x.exactMatch_;
if (lowerBound == getVocabIndex()) {
return x.isContained_ ? std::strong_ordering::equal
: std::strong_ordering::less;
: std::strong_ordering::greater;
} else {
return getVocabIndex() <=> lowerBound;
}
Expand All @@ -179,7 +180,7 @@ class ValueId {
auto lowerBound = x.exactMatch_;
if (lowerBound == other.getVocabIndex()) {
return x.isContained_ ? std::strong_ordering::equal
: std::strong_ordering::greater;
: std::strong_ordering::less;
} else {
return lowerBound <=> other.getVocabIndex();
}
Expand Down
22 changes: 19 additions & 3 deletions src/global/ValueIdComparators.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,25 @@ template <typename RandomIt>
inline std::pair<RandomIt, RandomIt> getRangeForDatatype(RandomIt begin,
RandomIt end,
Datatype datatype) {
return std::equal_range(
begin, end, datatype,
detail::makeSymmetricComparator(&ValueId::getDatatype));
auto comparator = detail::makeSymmetricComparator(&ValueId::getDatatype);
// In a sorted input, `VocabIndex` and `LocalVocabIndex` IDs might be
// interleaved because they logically both store strings. We therefore
// need the range where any of those Datatypes match.

// Binary search on the `Datatype` can only work in the string case, if the
// involved datatypes are directly adjacent.
static_assert(static_cast<int>(Datatype::LocalVocabIndex) ==
static_cast<int>(Datatype::VocabIndex) + 1);
if (ad_utility::contains(
std::array{Datatype::LocalVocabIndex, Datatype::VocabIndex},
datatype)) {
auto lower_bound =
std::lower_bound(begin, end, Datatype::VocabIndex, comparator);
auto upper_bound =
std::upper_bound(begin, end, Datatype::LocalVocabIndex, comparator);
return {lower_bound, upper_bound};
}
return std::equal_range(begin, end, datatype, comparator);
}

namespace detail {
Expand Down

0 comments on commit aab72ec

Please sign in to comment.