Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support of AnyNonArray type, fixes for array_agg #1958

Merged
merged 3 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ydb/docs/en/core/postgresql/_includes/functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -3866,16 +3866,16 @@ Table 9.57. General-Purpose Aggregate Functions
#|
||Function|Description|Partial Mode|Example||
||array_agg ( anynonarray ) → anyarray|
Collects all the input values, including nulls, into an array. (NOT SUPPORTED)|
Collects all the input values, including nulls, into an array.|
No|
```sql
#SELECT array_agg(x) FROM (VALUES (1),(2)) a(x) → {1,2}
SELECT array_agg(x) FROM (VALUES (1),(2)) a(x) → {1,2}
```||
||array_agg ( anyarray ) → anyarray|
Concatenates all the input arrays into an array of one higher dimension. (The inputs must all have the same dimensionality, and cannot be empty or null.) (NOT SUPPORTED)|
Concatenates all the input arrays into an array of one higher dimension. (The inputs must all have the same dimensionality, and cannot be empty or null.)|
No|
```sql
#SELECT array_agg(x) FROM (VALUES (Array[1,2]),(Array[3,4])) a(x) → {{1,2},{3,4}}
SELECT array_agg(x) FROM (VALUES (Array[1,2]),(Array[3,4])) a(x) → {{1,2},{3,4}}
```||
||avg ( smallint ) → numeric
avg ( integer ) → numeric
Expand Down
8 changes: 4 additions & 4 deletions ydb/docs/ru/core/postgresql/_includes/functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -3866,16 +3866,16 @@ Table 9.57. General-Purpose Aggregate Functions
#|
||Function|Description|Partial Mode|Example||
||array_agg ( anynonarray ) → anyarray|
Collects all the input values, including nulls, into an array. (NOT SUPPORTED)|
Collects all the input values, including nulls, into an array.|
No|
```sql
#SELECT array_agg(x) FROM (VALUES (1),(2)) a(x) → {1,2}
SELECT array_agg(x) FROM (VALUES (1),(2)) a(x) → {1,2}
```||
||array_agg ( anyarray ) → anyarray|
Concatenates all the input arrays into an array of one higher dimension. (The inputs must all have the same dimensionality, and cannot be empty or null.) (NOT SUPPORTED)|
Concatenates all the input arrays into an array of one higher dimension. (The inputs must all have the same dimensionality, and cannot be empty or null.)|
No|
```sql
#SELECT array_agg(x) FROM (VALUES (Array[1,2]),(Array[3,4])) a(x) → {{1,2},{3,4}}
SELECT array_agg(x) FROM (VALUES (Array[1,2]),(Array[3,4])) a(x) → {{1,2},{3,4}}
```||
||avg ( smallint ) → numeric
avg ( integer ) → numeric
Expand Down
26 changes: 23 additions & 3 deletions ydb/library/yql/core/type_ann/type_ann_pg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ bool IsCastRequired(ui32 fromTypeId, ui32 toTypeId) {
if (toTypeId == fromTypeId) {
return false;
}
if (toTypeId == NPg::AnyOid || toTypeId == NPg::AnyArrayOid) {
if (toTypeId == NPg::AnyOid || toTypeId == NPg::AnyArrayOid || toTypeId == NPg::AnyNonArrayOid) {
return false;
}
return true;
Expand Down Expand Up @@ -224,6 +224,7 @@ IGraphTransformer::TStatus PgCallWrapper(const TExprNode::TPtr& input, TExprNode
}

bool rangeFunction = false;
ui32 refinedType = 0;
for (const auto& setting : input->Child(isResolved ? 2 : 1)->Children()) {
if (!EnsureTupleMinSize(*setting, 1, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
Expand All @@ -236,6 +237,16 @@ IGraphTransformer::TStatus PgCallWrapper(const TExprNode::TPtr& input, TExprNode
auto content = setting->Head().Content();
if (content == "range") {
rangeFunction = true;
} else if (content == "type") {
if (!EnsureTupleSize(*setting, 2, ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}

if (!EnsureAtom(setting->Tail(), ctx.Expr)) {
return IGraphTransformer::TStatus::Error;
}

refinedType = NPg::LookupType(TString(setting->Tail().Content())).TypeId;
} else {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
TStringBuilder() << "Unexpected setting " << content << " in function " << name));
Expand Down Expand Up @@ -286,9 +297,17 @@ IGraphTransformer::TStatus PgCallWrapper(const TExprNode::TPtr& input, TExprNode
return IGraphTransformer::TStatus::Error;
}

const TTypeAnnotationNode* result = ctx.Expr.MakeType<TPgExprType>(proc.ResultType);
auto resultType = proc.ResultType;
AdjustReturnType(resultType, proc.ArgTypes, argTypes);
if (resultType == NPg::AnyArrayOid && refinedType) {
const auto& refinedDesc = NPg::LookupType(refinedType);
YQL_ENSURE(refinedDesc.ArrayTypeId == refinedDesc.TypeId);
resultType = refinedDesc.TypeId;
}

const TTypeAnnotationNode* result = ctx.Expr.MakeType<TPgExprType>(resultType);
TMaybe<TColumnOrder> resultColumnOrder;
if (proc.ResultType == NPg::RecordOid && rangeFunction) {
if (resultType == NPg::RecordOid && rangeFunction) {
if (proc.OutputArgNames.empty()) {
ctx.Expr.AddError(TIssue(ctx.Expr.GetPosition(input->Pos()),
TStringBuilder() << "Aggregate function " << name << " cannot be used in FROM"));
Expand Down Expand Up @@ -780,6 +799,7 @@ IGraphTransformer::TStatus PgAggWrapper(const TExprNode::TPtr& input, TExprNode:
resultType = NPg::LookupProc(aggDesc.FinalFuncId).ResultType;
}

AdjustReturnType(resultType, aggDesc.ArgTypes, argTypes);
auto result = ctx.Expr.MakeType<TPgExprType>(resultType);
input->SetTypeAnn(result);

Expand Down
61 changes: 60 additions & 1 deletion ydb/library/yql/core/yql_expr_type_annotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6331,22 +6331,42 @@ TExprNode::TPtr ExpandPgAggregationTraits(TPositionHandle pos, const NPg::TAggre
auto saveLambda = idLambda;
auto loadLambda = idLambda;
auto finishLambda = idLambda;
auto nullValue = ctx.NewCallable(pos, "Null", {});
if (aggDesc.FinalFuncId) {
const ui32 originalAggResultType = NPg::LookupProc(aggDesc.FinalFuncId).ResultType;
ui32 aggResultType = originalAggResultType;
AdjustReturnType(aggResultType, aggDesc.ArgTypes, argTypes);
finishLambda = ctx.Builder(pos)
.Lambda()
.Param("state")
.Callable("PgResolvedCallCtx")
.Atom(0, NPg::LookupProc(aggDesc.FinalFuncId).Name)
.Atom(1, ToString(aggDesc.FinalFuncId))
.List(2)
.Do([aggResultType, originalAggResultType](TExprNodeBuilder& builder) -> TExprNodeBuilder& {
if (aggResultType != originalAggResultType) {
builder.List(0)
.Atom(0, "type")
.Atom(1, NPg::LookupType(aggResultType).Name)
.Seal();
}

return builder;
})
.Seal()
.Arg(3, "state")
.Do([&aggDesc, nullValue](TExprNodeBuilder& builder) -> TExprNodeBuilder& {
if (aggDesc.FinalExtra) {
builder.Add(4, nullValue);
}

return builder;
})
.Seal()
.Seal()
.Build();
}

auto nullValue = ctx.NewCallable(pos, "Null", {});
auto initValue = nullValue;
if (aggDesc.InitValue) {
initValue = ctx.Builder(pos)
Expand Down Expand Up @@ -6618,6 +6638,45 @@ TExprNode::TPtr ExpandPgAggregationTraits(TPositionHandle pos, const NPg::TAggre
}
}

// Narrows a polymorphic `anyarray` return type to a concrete type id, based on
// the actual argument types supplied at the call site.
//
// returnType   - in/out: only touched when it equals NPg::AnyArrayOid.
// procArgTypes - declared argument type ids of the procedure/aggregate; must
//                have at least as many entries as argTypes.
// argTypes     - actual argument type ids; entries equal to 0 are ignored.
//
// If the actual types bound to `anynonarray` parameters (or to `anyarray`
// parameters) disagree with each other, returnType is left unchanged.
void AdjustReturnType(ui32& returnType, const TVector<ui32>& procArgTypes, const TVector<ui32>& argTypes) {
    YQL_ENSURE(procArgTypes.size() >= argTypes.size());
    if (returnType != NPg::AnyArrayOid) {
        return;
    }

    TMaybe<ui32> elementType; // actual type seen for `anynonarray` parameters
    TMaybe<ui32> arrayType;   // actual type seen for `anyarray` parameters

    // Stores the first observed type in `slot`; afterwards reports whether a
    // newly observed type matches the stored one.
    const auto remember = [](TMaybe<ui32>& slot, ui32 actual) {
        if (!slot) {
            slot = actual;
            return true;
        }

        return *slot == actual;
    };

    for (ui32 idx = 0; idx < argTypes.size(); ++idx) {
        const ui32 actual = argTypes[idx];
        if (!actual) {
            continue;
        }

        const ui32 declared = procArgTypes[idx];
        if (declared == NPg::AnyNonArrayOid && !remember(elementType, actual)) {
            return;
        }

        if (declared == NPg::AnyArrayOid && !remember(arrayType, actual)) {
            return;
        }
    }

    // An element type takes precedence: the result is that element's array type.
    if (elementType) {
        returnType = NPg::LookupType(*elementType).ArrayTypeId;
    } else if (arrayType) {
        returnType = *arrayType;
    }
}

const TTypeAnnotationNode* GetOriginalResultType(TPositionHandle pos, bool isMany, const TTypeAnnotationNode* originalExtractorType, TExprContext& ctx) {
if (!EnsureStructType(pos, *originalExtractorType, ctx)) {
return nullptr;
Expand Down
1 change: 1 addition & 0 deletions ydb/library/yql/core/yql_expr_type_annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ bool GetMinMaxResultType(const TPositionHandle& pos, const TTypeAnnotationNode&
IGraphTransformer::TStatus ExtractPgTypesFromMultiLambda(TExprNode::TPtr& lambda, TVector<ui32>& argTypes,
bool& needRetype, TExprContext& ctx);

void AdjustReturnType(ui32& returnType, const TVector<ui32>& procArgTypes, const TVector<ui32>& argTypes);
TExprNode::TPtr ExpandPgAggregationTraits(TPositionHandle pos, const NPg::TAggregateDesc& aggDesc, bool onWindow,
const TExprNode::TPtr& lambda, const TVector<ui32>& argTypes, const TTypeAnnotationNode* itemType, TExprContext& ctx);

Expand Down
15 changes: 14 additions & 1 deletion ydb/library/yql/parser/pg_catalog/catalog.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ bool IsCompatibleTo(ui32 actualTypeId, ui32 expectedTypeId, const TTypes& types)
return actualDescPtr->ArrayTypeId == actualDescPtr->TypeId;
}

if (expectedTypeId == AnyNonArrayOid) {
const auto& actualDescPtr = types.FindPtr(actualTypeId);
Y_ENSURE(actualDescPtr);
return actualDescPtr->ArrayTypeId != actualDescPtr->TypeId;
}

return false;
}

Expand Down Expand Up @@ -753,6 +759,8 @@ class TAggregationsParser : public TParser {
}
} else if (key == "agginitval") {
LastAggregation.InitValue = value;
} else if (key == "aggfinalextra") {
LastAggregation.FinalExtra = (value == "t");;
}
}

Expand Down Expand Up @@ -1920,7 +1928,6 @@ bool IsCoercible(ui32 fromTypeId, ui32 toTypeId, ECoercionCode coercionType, con
if (toTypeId == AnyOid) {
return true;
}

//TODO: support polymorphic types

if (fromTypeId == UnknownOid) {
Expand All @@ -1943,6 +1950,12 @@ bool IsCoercible(ui32 fromTypeId, ui32 toTypeId, ECoercionCode coercionType, con
return actualDescPtr->ArrayTypeId == actualDescPtr->TypeId;
}

if (toTypeId == AnyNonArrayOid) {
const auto& actualDescPtr = catalog.Types.FindPtr(fromTypeId);
Y_ENSURE(actualDescPtr);
return actualDescPtr->ArrayTypeId != actualDescPtr->TypeId;
}

return false;
}

Expand Down
2 changes: 2 additions & 0 deletions ydb/library/yql/parser/pg_catalog/catalog.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ namespace NYql::NPg {
constexpr ui32 UnknownOid = 705;
constexpr ui32 AnyOid = 2276;
constexpr ui32 AnyArrayOid = 2277;
constexpr ui32 AnyNonArrayOid = 2776;
constexpr ui32 RecordOid = 2249;
constexpr ui32 VarcharOid = 1043;
constexpr ui32 TextOid = 25;
Expand Down Expand Up @@ -156,6 +157,7 @@ struct TAggregateDesc {
ui32 SerializeFuncId = 0;
ui32 DeserializeFuncId = 0;
TString InitValue;
bool FinalExtra = false;
};

enum class EAmType {
Expand Down
2 changes: 2 additions & 0 deletions ydb/library/yql/parser/pg_catalog/ut/catalog_consts_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ Y_UNIT_TEST_SUITE(TConstantsTests) {
UNIT_ASSERT_VALUES_EQUAL(typeDesc.TypeId, VarcharOid);
typeDesc = LookupType("text");
UNIT_ASSERT_VALUES_EQUAL(typeDesc.TypeId, TextOid);
typeDesc = LookupType("anynonarray");
UNIT_ASSERT_VALUES_EQUAL(typeDesc.TypeId, AnyNonArrayOid);
}

Y_UNIT_TEST(TRelationOidConsts) {
Expand Down
5 changes: 4 additions & 1 deletion ydb/library/yql/parser/pg_wrapper/syscache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ extern "C" {
#include "catalog/pg_type_d.h"
#include "catalog/pg_authid.h"
#include "access/htup_details.h"
#include "utils/fmgroids.h"
}

#undef TypeName
Expand Down Expand Up @@ -183,7 +184,7 @@ struct TSysCache {
std::fill_n(nulls, Natts_pg_type, true);
std::fill_n(nulls, Anum_pg_type_typcollation, false); // fixed part of Form_pg_type
FillDatum(Natts_pg_type, values, nulls, Anum_pg_type_oid, oid);
auto name = MakeFixedString(desc.Name, NPg::LookupType(NAMEOID).TypeLen);
auto name = MakeFixedString(desc.Name, NAMEDATALEN);
FillDatum(Natts_pg_type, values, nulls, Anum_pg_type_typname, (Datum)name);
FillDatum(Natts_pg_type, values, nulls, Anum_pg_type_typbyval, desc.PassByValue);
FillDatum(Natts_pg_type, values, nulls, Anum_pg_type_typlen, desc.TypeLen);
Expand All @@ -193,6 +194,8 @@ struct TSysCache {
FillDatum(Natts_pg_type, values, nulls, Anum_pg_type_typisdefined, true);
FillDatum(Natts_pg_type, values, nulls, Anum_pg_type_typdelim, desc.TypeDelim);
FillDatum(Natts_pg_type, values, nulls, Anum_pg_type_typarray, desc.ArrayTypeId);
FillDatum(Natts_pg_type, values, nulls, Anum_pg_type_typsubscript,
(desc.ArrayTypeId == desc.TypeId) ? F_ARRAY_SUBSCRIPT_HANDLER : desc.TypeSubscriptFuncId);
FillDatum(Natts_pg_type, values, nulls, Anum_pg_type_typelem, desc.ElementTypeId);
FillDatum(Natts_pg_type, values, nulls, Anum_pg_type_typinput, desc.InFuncId);
FillDatum(Natts_pg_type, values, nulls, Anum_pg_type_typoutput, desc.OutFuncId);
Expand Down
Loading