Skip to content

Commit

Permalink
Add concepts and concrete types for BlockZipperJoinImpl (#1625)
Browse files Browse the repository at this point in the history
This class (which is the central implementation of the lazy merge join) previously had a lot of `const auto&` etc. parameters, which made the (already rather complex) code harder to reason about. This PR adds concrete types wherever possible, and constrains the other parameters using concepts.
  • Loading branch information
RobinTF authored Nov 29, 2024
1 parent 9c41750 commit 527f7bf
Showing 1 changed file with 45 additions and 29 deletions.
74 changes: 45 additions & 29 deletions src/util/JoinAlgorithms/JoinAlgorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "util/JoinAlgorithms/FindUndefRanges.h"
#include "util/JoinAlgorithms/JoinColumnMapping.h"
#include "util/TransparentFunctors.h"
#include "util/TypeTraits.h"

namespace ad_utility {

Expand Down Expand Up @@ -742,8 +743,7 @@ struct AlwaysFalse {
// if they become empty at one point in the algorithm.
template <IsJoinSide LeftSide, IsJoinSide RightSide, typename LessThan,
typename CompatibleRowAction,
ad_utility::InvocableWithExactReturnType<
bool, typename LeftSide::ProjectedEl>
InvocableWithExactReturnType<bool, typename LeftSide::ProjectedEl>
IsUndef = AlwaysFalse>
struct BlockZipperJoinImpl {
// The left and right inputs of the join
Expand All @@ -755,6 +755,16 @@ struct BlockZipperJoinImpl {
CompatibleRowAction& compatibleRowAction_;
[[no_unique_address]] IsUndef isUndefined_{};

using LeftBlocks = typename LeftSide::CurrentBlocks;
using RightBlocks = typename RightSide::CurrentBlocks;

// We can't define aliases for these concepts, so we use macros instead.
#if defined(Side) || defined(Blocks)
#error Side or Blocks are already defined
#endif
#define Side SameAsAny<LeftSide, RightSide> auto
#define Blocks SameAsAny<LeftBlocks, RightBlocks> auto

// Type alias for the result of the projection. Elements from the left and
// right input must be projected to the same type.
using ProjectedEl = LeftSide::ProjectedEl;
Expand All @@ -770,7 +780,7 @@ struct BlockZipperJoinImpl {
// Recompute the `currentEl`. It is the minimum of the last element in the
// first block of either of the join sides.
ProjectedEl getCurrentEl() {
auto getFirst = [](const auto& side) {
auto getFirst = [](const Side& side) {
return side.projection_(side.currentBlocks_.front().back());
};
return std::min(getFirst(leftSide_), getFirst(rightSide_), lessThan_);
Expand All @@ -784,7 +794,7 @@ struct BlockZipperJoinImpl {
// blocks that contain elements <= `currentEl` have been added, and `false` if
// the function returned because 3 blocks were added without fulfilling the
// condition.
bool fillEqualToCurrentEl(auto& side, const auto& currentEl) {
bool fillEqualToCurrentEl(Side& side, const ProjectedEl& currentEl) {
auto& it = side.it_;
auto& end = side.end_;
for (size_t numBlocksRead = 0; it != end && numBlocksRead < 3;
Expand All @@ -808,7 +818,7 @@ struct BlockZipperJoinImpl {
// sides contain all the relevant blocks. Only filling one side is used for
// the optimization for the Cartesian product described in the documentation.
enum struct BlockStatus { leftMissing, rightMissing, allFilled };
BlockStatus fillEqualToCurrentElBothSides(const auto& currentEl) {
BlockStatus fillEqualToCurrentElBothSides(const ProjectedEl& currentEl) {
bool allBlocksFromLeft = false;
bool allBlocksFromRight = false;
while (!(allBlocksFromLeft || allBlocksFromRight)) {
Expand All @@ -828,14 +838,13 @@ struct BlockZipperJoinImpl {
// `rightSide_.currentBlocks`) s.t. only elements `> lastProcessedElement`
// remain. This effectively removes all blocks completely, except maybe the
// last one.
template <typename Blocks, typename ProjectedEl>
void removeEqualToCurrentEl(Blocks& blocks,
ProjectedEl lastProcessedElement) {
const ProjectedEl& lastProcessedElement) {
// Erase all but the last block.
AD_CORRECTNESS_CHECK(!blocks.empty());
if (blocks.size() > 1 && !blocks.front().empty()) {
AD_CORRECTNESS_CHECK(!lessThan_(lastProcessedElement,
std::as_const(blocks.front()).back()));
AD_CORRECTNESS_CHECK(
!lessThan_(lastProcessedElement, blocks.front().back()));
}
blocks.erase(blocks.begin(), blocks.end() - 1);

Expand All @@ -854,15 +863,16 @@ struct BlockZipperJoinImpl {
// * A reference to the first full block
// * The currently active subrange of that block
// * An iterator pointing to the first element `>= currentEl` in the block.
auto getFirstBlock(auto& currentBlocks, const auto& currentEl) {
auto getFirstBlock(const Blocks& currentBlocks,
const ProjectedEl& currentEl) {
AD_CORRECTNESS_CHECK(!currentBlocks.empty());
const auto& first = currentBlocks.at(0);
auto it = std::ranges::lower_bound(first.subrange(), currentEl, lessThan_);
return std::tuple{std::ref(first.fullBlock()), first.subrange(), it};
}

// Check if a side contains undefined values.
static bool hasUndef(const auto& side) {
static bool hasUndef(const Side& side) {
if constexpr (potentiallyHasUndef) {
return !side.undefBlocks_.empty();
}
Expand All @@ -871,7 +881,8 @@ struct BlockZipperJoinImpl {

// Combine all elements from all blocks on the left with all elements from all
// blocks on the right and add them to the result.
void addCartesianProduct(const auto& blocksLeft, const auto& blocksRight) {
void addCartesianProduct(const LeftBlocks& blocksLeft,
const RightBlocks& blocksRight) {
// TODO<C++23> use `std::views::cartesian_product`.
for (const auto& lBlock : blocksLeft) {
for (const auto& rBlock : blocksRight) {
Expand All @@ -888,8 +899,8 @@ struct BlockZipperJoinImpl {
// Handle non-matching rows from the left side for an optional join or a minus
// join.
template <bool DoOptionalJoin>
void addNonMatchingRowsFromLeftForOptionalJoin(const auto& blocksLeft,
const auto& blocksRight) {
void addNonMatchingRowsFromLeftForOptionalJoin(
const LeftBlocks& blocksLeft, const RightBlocks& blocksRight) {
if constexpr (DoOptionalJoin) {
if (!hasUndef(rightSide_) &&
std::ranges::all_of(
Expand All @@ -910,7 +921,7 @@ struct BlockZipperJoinImpl {
// Call `compatibleRowAction` for all pairs of elements in the Cartesian
// product of the blocks in `blocksLeft` and `blocksRight`.
template <bool DoOptionalJoin>
void addAll(const auto& blocksLeft, const auto& blocksRight) {
void addAll(const LeftBlocks& blocksLeft, const RightBlocks& blocksRight) {
addNonMatchingRowsFromLeftForOptionalJoin<DoOptionalJoin>(blocksLeft,
blocksRight);
addCartesianProduct(blocksLeft, blocksRight);
Expand All @@ -921,7 +932,7 @@ struct BlockZipperJoinImpl {
// `currentEl`. Effectively, these subranges cover all the blocks completely
// except maybe the last one, which might contain elements `> currentEl` at
// the end.
auto getEqualToCurrentEl(const auto& blocks, const auto& currentEl) {
auto getEqualToCurrentEl(const Blocks& blocks, const ProjectedEl& currentEl) {
auto result = blocks;
if (result.empty()) {
return result;
Expand Down Expand Up @@ -990,9 +1001,9 @@ struct BlockZipperJoinImpl {
// The fully joined parts of the block are then removed from
// `currentBlocksLeft/Right`, as they are not needed anymore.
template <bool DoOptionalJoin>
void joinAndRemoveLessThanCurrentEl(auto& currentBlocksLeft,
auto& currentBlocksRight,
const auto& currentEl) {
void joinAndRemoveLessThanCurrentEl(LeftBlocks& currentBlocksLeft,
RightBlocks& currentBlocksRight,
const ProjectedEl& currentEl) {
// Get the first blocks.
auto [fullBlockLeft, subrangeLeft, currentElItL] =
getFirstBlock(currentBlocksLeft, currentEl);
Expand Down Expand Up @@ -1045,7 +1056,7 @@ struct BlockZipperJoinImpl {

// If the `targetBuffer` is empty, read the next nonempty block from `[it,
// end)` if there is one.
void fillWithAtLeastOne(auto& side) {
void fillWithAtLeastOne(Side& side) {
auto& targetBuffer = side.currentBlocks_;
auto& it = side.it_;
const auto& end = side.end_;
Expand Down Expand Up @@ -1086,8 +1097,9 @@ struct BlockZipperJoinImpl {
// Based on `blockStatus` add the Cartesian product of the blocks in
// `leftBlocks` and/or `rightBlocks` with their respective counterpart in
// `undefBlocks_`.
void joinWithUndefBlocks(BlockStatus blockStatus, const auto& leftBlocks,
const auto& rightBlocks) {
void joinWithUndefBlocks(BlockStatus blockStatus,
const LeftBlocks& leftBlocks,
const RightBlocks& rightBlocks) {
if (blockStatus == BlockStatus::allFilled ||
blockStatus == BlockStatus::leftMissing) {
addCartesianProduct(leftBlocks, rightSide_.undefBlocks_);
Expand Down Expand Up @@ -1117,14 +1129,15 @@ struct BlockZipperJoinImpl {
auto equalToCurrentElRight =
getEqualToCurrentEl(currentBlocksRight, currentEl);

auto getNextBlocks = [&currentEl, self = this, &blockStatus](auto& target,
auto& side) {
self->removeEqualToCurrentEl(side.currentBlocks_, currentEl);
bool allBlocksWereFilled = self->fillEqualToCurrentEl(side, currentEl);
auto getNextBlocks = [this, &currentEl, &blockStatus](Blocks& target,
Side& side) {
// Explicit this to avoid false positive warning in clang.
this->removeEqualToCurrentEl(side.currentBlocks_, currentEl);
bool allBlocksWereFilled = fillEqualToCurrentEl(side, currentEl);
if (side.currentBlocks_.empty()) {
AD_CORRECTNESS_CHECK(allBlocksWereFilled);
}
target = self->getEqualToCurrentEl(side.currentBlocks_, currentEl);
target = getEqualToCurrentEl(side.currentBlocks_, currentEl);
if (allBlocksWereFilled) {
blockStatus = BlockStatus::allFilled;
}
Expand Down Expand Up @@ -1185,7 +1198,7 @@ struct BlockZipperJoinImpl {
// those blocks with the undef blocks from the other side.
// `reversed` is used to determine if the left or right side is consumed.
template <bool reversed>
void consumeRemainingBlocks(auto& side, const auto& undefBlocks) {
void consumeRemainingBlocks(Side& side, const Blocks& undefBlocks) {
while (side.it_ != side.end_) {
const auto& lBlock = *side.it_;
for (const auto& rBlock : undefBlocks) {
Expand Down Expand Up @@ -1228,7 +1241,7 @@ struct BlockZipperJoinImpl {
// `side.undefBlocks_` and skipped for subsequent processing. The first block
// containing defined values is split and the defined part is stored in
// `side.currentBlocks_`.
void findFirstBlockWithoutUndef(auto& side) {
void findFirstBlockWithoutUndef(Side& side) {
// The reference of `it` is there on purpose.
for (auto& it = side.it_; it != side.end_; ++it) {
auto& el = *it;
Expand Down Expand Up @@ -1293,6 +1306,9 @@ struct BlockZipperJoinImpl {
joinBuffers<DoOptionalJoin, ProjectedEl>(blockStatus);
}
}
// Undefine the helper macros so they don't leak into code that includes this
// header.
#undef Side
#undef Blocks
};

// Deduction guide for the above struct.
Expand Down

0 comments on commit 527f7bf

Please sign in to comment.