Skip to content

Commit

Permalink
Fix DQ input union/merge values with zero input channels (#7515)
Browse files Browse the repository at this point in the history
  • Loading branch information
nepal authored Aug 6, 2024
1 parent 8ee38ab commit 2f2468c
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 36 deletions.
61 changes: 30 additions & 31 deletions ydb/library/yql/dq/runtime/dq_input_producer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ template<bool IsWide>
class TDqInputUnionStreamValue : public TComputationValue<TDqInputUnionStreamValue<IsWide>> {
using TBase = TComputationValue<TDqInputUnionStreamValue<IsWide>>;
public:
TDqInputUnionStreamValue(TMemoryUsageInfo* memInfo, TVector<IDqInput::TPtr>&& inputs, TDqMeteringStats::TInputStatsMeter stats)
TDqInputUnionStreamValue(TMemoryUsageInfo* memInfo, const NKikimr::NMiniKQL::TType* type, TVector<IDqInput::TPtr>&& inputs, TDqMeteringStats::TInputStatsMeter stats)
: TBase(memInfo)
, Inputs(std::move(inputs))
, Alive(Inputs.size())
, Batch(Inputs.empty() ? nullptr : Inputs.front()->GetInputType())
, Batch(type)
, Stats(stats)
{}

Expand Down Expand Up @@ -114,13 +114,15 @@ template<bool IsWide>
class TDqInputMergeStreamValue : public TComputationValue<TDqInputMergeStreamValue<IsWide>> {
using TBase = TComputationValue<TDqInputMergeStreamValue<IsWide>>;
public:
TDqInputMergeStreamValue(TMemoryUsageInfo* memInfo, TVector<IDqInput::TPtr>&& inputs,
TDqInputMergeStreamValue(TMemoryUsageInfo* memInfo, const NKikimr::NMiniKQL::TType* type, TVector<IDqInput::TPtr>&& inputs,
TVector<TSortColumnInfo>&& sortCols, TDqMeteringStats::TInputStatsMeter stats)
: TBase(memInfo)
, Inputs(std::move(inputs))
, Width(type->IsMulti() ? static_cast<const NMiniKQL::TMultiType*>(type)->GetElementsCount() : TMaybe<ui32>())
, SortCols(std::move(sortCols))
, Stats(stats)
{
YQL_ENSURE(!IsWide ^ Width.Defined());
CurrentBuffers.reserve(Inputs.size());
CurrentItemIndexes.reserve(Inputs.size());
for (ui32 idx = 0; idx < Inputs.size(); ++idx) {
Expand Down Expand Up @@ -216,7 +218,7 @@ class TDqInputMergeStreamValue : public TComputationValue<TDqInputMergeStreamVal
return status;
}

YQL_ENSURE(!Inputs.empty() && *Inputs.front()->GetInputWidth() == width);
YQL_ENSURE(*Width == width);
CopyResult(result, width);
if (Stats) {
Stats.Add(result, width);
Expand Down Expand Up @@ -300,6 +302,7 @@ class TDqInputMergeStreamValue : public TComputationValue<TDqInputMergeStreamVal

private:
TVector<IDqInput::TPtr> Inputs;
const TMaybe<ui32> Width;
TVector<TSortColumnInfo> SortCols;
TVector<TUnboxedValueBatch> CurrentBuffers;
TVector<TUnboxedValuesIterator<IsWide>> CurrentItemIndexes;
Expand All @@ -308,20 +311,6 @@ class TDqInputMergeStreamValue : public TComputationValue<TDqInputMergeStreamVal
TDqMeteringStats::TInputStatsMeter Stats;
};

bool IsWideInputs(const TVector<IDqInput::TPtr>& inputs) {
NKikimr::NMiniKQL::TType* type = nullptr;
bool isWide = false;
for (auto& input : inputs) {
if (!type) {
type = input->GetInputType();
isWide = input->GetInputWidth().Defined();
} else {
YQL_ENSURE(type->IsSameType(*input->GetInputType()));
}
}
return isWide;
}

TVector<NKikimr::NMiniKQL::TType*> ExtractBlockItemTypes(const NKikimr::NMiniKQL::TType* type) {
TVector<NKikimr::NMiniKQL::TType*> result;

Expand Down Expand Up @@ -390,18 +379,17 @@ TVector<IBlockItemComparator::TPtr> MakeComparators(const TVector<TSortColumnInf
class TDqInputMergeBlockStreamValue : public TComputationValue<TDqInputMergeBlockStreamValue> {
using TBase = TComputationValue<TDqInputMergeBlockStreamValue>;
public:
TDqInputMergeBlockStreamValue(TMemoryUsageInfo* memInfo, TVector<IDqInput::TPtr>&& inputs,
TDqInputMergeBlockStreamValue(TMemoryUsageInfo* memInfo, const NKikimr::NMiniKQL::TType* type, TVector<IDqInput::TPtr>&& inputs,
TVector<TSortColumnInfo>&& sortCols, const NKikimr::NMiniKQL::THolderFactory& factory, TDqMeteringStats::TInputStatsMeter stats)
: TBase(memInfo)
, SortCols_(std::move(sortCols))
, ItemTypes_(ExtractBlockItemTypes(inputs.front()->GetInputType()))
, ItemTypes_(ExtractBlockItemTypes(type))
, MaxOutputBlockLen_(CalcMaxBlockLength(ItemTypes_.begin(), ItemTypes_.end(), TTypeInfoHelper()))
, Comparators_(MakeComparators(SortCols_, ItemTypes_))
, Builders_(MakeBuilders(MaxOutputBlockLen_, ItemTypes_))
, Factory_(factory)
, Stats_(stats)
{
YQL_ENSURE(!inputs.empty());
YQL_ENSURE(MaxOutputBlockLen_ > 0);
InputData_.reserve(inputs.size());
for (auto& input : inputs) {
Expand Down Expand Up @@ -697,6 +685,15 @@ class TDqInputMergeBlockStreamValue : public TComputationValue<TDqInputMergeBloc
bool IsFinished_ = false;
};

// Checks that every input reports exactly the expected type.
//
// type   - the expected type; must be non-null.
// inputs - inputs to check; may be empty, in which case only `type` is checked.
//
// Each violation is reported through YQL_ENSURE. A type mismatch names the
// zero-based position of the offending input together with both types.
void ValidateInputTypes(const NKikimr::NMiniKQL::TType* type, const TVector<IDqInput::TPtr>& inputs) {
    YQL_ENSURE(type);
    size_t position = 0;
    for (const auto& input : inputs) {
        const auto actualType = input->GetInputType();
        // Guard the dereference below: an input without a type is an error on its own.
        YQL_ENSURE(actualType);
        YQL_ENSURE(type->IsSameType(*actualType), "Unexpected type for input #" << position << ": expected " << *type << ", got " << *actualType);
        ++position;
    }
}

} // namespace

void TDqMeteringStats::TInputStatsMeter::Add(const NKikimr::NUdf::TUnboxedValue& val) {
Expand Down Expand Up @@ -737,31 +734,33 @@ void TDqMeteringStats::TInputStatsMeter::Add(const NKikimr::NUdf::TUnboxedValue*
}
}

NUdf::TUnboxedValue CreateInputUnionValue(TVector<IDqInput::TPtr>&& inputs,
NUdf::TUnboxedValue CreateInputUnionValue(const NKikimr::NMiniKQL::TType* type, TVector<IDqInput::TPtr>&& inputs,
const NMiniKQL::THolderFactory& factory, TDqMeteringStats::TInputStatsMeter stats)
{
if (IsWideInputs(inputs)) {
return factory.Create<TDqInputUnionStreamValue<true>>(std::move(inputs), stats);
ValidateInputTypes(type, inputs);
if (type->IsMulti()) {
return factory.Create<TDqInputUnionStreamValue<true>>(type, std::move(inputs), stats);
}
return factory.Create<TDqInputUnionStreamValue<false>>(std::move(inputs), stats);
return factory.Create<TDqInputUnionStreamValue<false>>(type, std::move(inputs), stats);
}

NKikimr::NUdf::TUnboxedValue CreateInputMergeValue(TVector<IDqInput::TPtr>&& inputs,
NKikimr::NUdf::TUnboxedValue CreateInputMergeValue(const NKikimr::NMiniKQL::TType* type, TVector<IDqInput::TPtr>&& inputs,
TVector<TSortColumnInfo>&& sortCols, const NKikimr::NMiniKQL::THolderFactory& factory, TDqMeteringStats::TInputStatsMeter stats)
{
ValidateInputTypes(type, inputs);
YQL_ENSURE(!inputs.empty());
if (IsWideInputs(inputs)) {
if (type->IsMulti()) {
if (AnyOf(sortCols, [](const auto& sortCol) { return sortCol.IsBlockOrScalar(); })) {
// we can ignore scalar columns, since all they have exactly the same value in all inputs
EraseIf(sortCols, [](const auto& sortCol) { return *sortCol.IsScalar; });
if (sortCols.empty()) {
return factory.Create<TDqInputUnionStreamValue<true>>(std::move(inputs), stats);
return factory.Create<TDqInputUnionStreamValue<true>>(type, std::move(inputs), stats);
}
return factory.Create<TDqInputMergeBlockStreamValue>(std::move(inputs), std::move(sortCols), factory, stats);
return factory.Create<TDqInputMergeBlockStreamValue>(type, std::move(inputs), std::move(sortCols), factory, stats);
}
return factory.Create<TDqInputMergeStreamValue<true>>(std::move(inputs), std::move(sortCols), stats);
return factory.Create<TDqInputMergeStreamValue<true>>(type, std::move(inputs), std::move(sortCols), stats);
}
return factory.Create<TDqInputMergeStreamValue<false>>(std::move(inputs), std::move(sortCols), stats);
return factory.Create<TDqInputMergeStreamValue<false>>(type, std::move(inputs), std::move(sortCols), stats);
}

} // namespace NYql::NDq
4 changes: 2 additions & 2 deletions ydb/library/yql/dq/runtime/dq_input_producer.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ struct TDqMeteringStats {
}
};

NKikimr::NUdf::TUnboxedValue CreateInputUnionValue(TVector<IDqInput::TPtr>&& inputs,
NKikimr::NUdf::TUnboxedValue CreateInputUnionValue(const NKikimr::NMiniKQL::TType* type, TVector<IDqInput::TPtr>&& inputs,
const NKikimr::NMiniKQL::THolderFactory& holderFactory, TDqMeteringStats::TInputStatsMeter = {});

NKikimr::NUdf::TUnboxedValue CreateInputMergeValue(TVector<IDqInput::TPtr>&& inputs,
NKikimr::NUdf::TUnboxedValue CreateInputMergeValue(const NKikimr::NMiniKQL::TType* type, TVector<IDqInput::TPtr>&& inputs,
TVector<TSortColumnInfo>&& sortCols, const NKikimr::NMiniKQL::THolderFactory& factory,
TDqMeteringStats::TInputStatsMeter = {});

Expand Down
6 changes: 3 additions & 3 deletions ydb/library/yql/dq/runtime/dq_tasks_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,14 @@ NUdf::TUnboxedValue DqBuildInputValue(const NDqProto::TTaskInput& inputDesc, con
Y_ABORT_UNLESS(inputs.size() == 1);
[[fallthrough]];
case NYql::NDqProto::TTaskInput::kUnionAll:
return CreateInputUnionValue(std::move(inputs), holderFactory, stats);
return CreateInputUnionValue(type, std::move(inputs), holderFactory, stats);
case NYql::NDqProto::TTaskInput::kMerge: {
const auto& protoSortCols = inputDesc.GetMerge().GetSortColumns();
TVector<TSortColumnInfo> sortColsInfo;
GetSortColumnsInfo(type, protoSortCols, sortColsInfo);
YQL_ENSURE(!sortColsInfo.empty());

return CreateInputMergeValue(std::move(inputs), std::move(sortColsInfo), holderFactory, stats);
return CreateInputMergeValue(type, std::move(inputs), std::move(sortColsInfo), holderFactory, stats);
}
default:
YQL_ENSURE(false, "Unknown input type: " << (ui32) inputDesc.GetTypeCase());
Expand Down Expand Up @@ -576,7 +576,7 @@ class TDqTaskRunner : public IDqTaskRunner {
inputs.clear();
inputs.emplace_back(transform->TransformOutput);
entryNode->SetValue(AllocatedHolder->ProgramParsed.CompGraph->GetContext(),
CreateInputUnionValue(std::move(inputs), holderFactory,
CreateInputUnionValue(transform->TransformOutput->GetInputType(), std::move(inputs), holderFactory,
{&inputStats, transform->TransformOutputType}));
} else {
entryNode->SetValue(AllocatedHolder->ProgramParsed.CompGraph->GetContext(),
Expand Down

0 comments on commit 2f2468c

Please sign in to comment.