From 6d935a4d8b22db5d2603f4f5c6e0e917cfd6bef0 Mon Sep 17 00:00:00 2001 From: Zhibai Song Date: Fri, 27 Oct 2023 17:12:56 -0700 Subject: [PATCH] Support default as params for function and procedure (#242) * Support default as params for function and procedure Previously when we call functions/procedures, we don't support default keyword usage like : select foofunc(default); or exec fooproc default This commit support to use default keyword as function or procedure param when the default value is previously defined in create proc/func Task: BABEL-335 Signed-off-by: Zhibai Song --- src/backend/optimizer/util/clauses.c | 60 ++++++++++++++++++++++++++++ src/backend/parser/analyze.c | 16 ++++++-- src/backend/parser/parse_coerce.c | 4 ++ src/backend/parser/parse_expr.c | 22 +++++++--- src/include/optimizer/clauses.h | 3 ++ 5 files changed, 97 insertions(+), 8 deletions(-) diff --git a/src/backend/optimizer/util/clauses.c b/src/backend/optimizer/util/clauses.c index d7d4787a203..423e340850e 100644 --- a/src/backend/optimizer/util/clauses.c +++ b/src/backend/optimizer/util/clauses.c @@ -40,6 +40,7 @@ #include "optimizer/plancat.h" #include "optimizer/planmain.h" #include "parser/analyze.h" +#include "parser/parser.h" #include "parser/parse_agg.h" #include "parser/parse_coerce.h" #include "parser/parse_func.h" @@ -92,6 +93,7 @@ typedef struct } max_parallel_hazard_context; insert_pltsql_function_defaults_hook_type insert_pltsql_function_defaults_hook = NULL; +replace_pltsql_function_defaults_hook_type replace_pltsql_function_defaults_hook = NULL; static bool contain_agg_clause_walker(Node *node, void *context); static bool find_window_functions_walker(Node *node, WindowFuncLists *lists); @@ -129,6 +131,8 @@ static Expr *simplify_function(Oid funcid, eval_const_expressions_context *context); static List *reorder_function_arguments(List *args, int pronargs, HeapTuple func_tuple); +static List *replace_function_defaults(List *args, HeapTuple func_tuple, bool *need_replace); + static List *add_function_defaults(List *args, int pronargs, HeapTuple func_tuple); static List *fetch_function_defaults(HeapTuple func_tuple); @@ -4050,6 +4054,19 @@ expand_function_arguments(List *args, bool include_out_arguments, func_tuple); } + /* Not add else here, for the reason to also replace the default keyword after add function defaults */ + if (list_length(args) == pronargs && sql_dialect == SQL_DIALECT_TSQL) + { + bool need_replace = false; + args = replace_function_defaults(args, func_tuple, &need_replace); + if (need_replace) + { + recheck_cast_function_args(args, result_type, + proargtypes, pronargs, + func_tuple); + } + } + return args; } @@ -4129,6 +4146,49 @@ reorder_function_arguments(List *args, int pronargs, HeapTuple func_tuple) return args; } +/* + * replace_function_defaults: replace default keyword item with default values + */ +static List * +replace_function_defaults(List *args, HeapTuple func_tuple, bool *need_replace) +{ + List *defaults; + ListCell *lc; + List *ret = NIL; + int nargs = list_length(args); + + *need_replace = false; + if (nargs == 0) + return args; + + foreach(lc, args) + { + if (nodeTag((Node*)lfirst(lc)) == T_RelabelType) + { + if( nodeTag(((RelabelType*)lfirst(lc))->arg) == T_SetToDefault) + *need_replace = true; + } + else if (nodeTag((Node*)lfirst(lc)) == T_FuncExpr) + { + if(((FuncExpr*)lfirst(lc))->funcformat == COERCE_IMPLICIT_CAST && + nodeTag(linitial(((FuncExpr*)lfirst(lc))->args)) == T_SetToDefault) + *need_replace = true; + } + } + + if (!(*need_replace)) + return args; + + /* Get all the default expressions from the pg_proc tuple */ + defaults = fetch_function_defaults(func_tuple); + + if (replace_pltsql_function_defaults_hook) + ret = replace_pltsql_function_defaults_hook(func_tuple, defaults, args); + else return args; + + return ret; +} + /* * add_function_defaults: add missing function arguments from its defaults * diff --git a/src/backend/parser/analyze.c b/src/backend/parser/analyze.c index 5988dad4829..1e9202915f5 100644 --- a/src/backend/parser/analyze.c +++ b/src/backend/parser/analyze.c @@ -3119,9 +3119,19 @@ transformCallStmt(ParseState *pstate, CallStmt *stmt) targs = NIL; foreach(lc, stmt->funccall->args) { - targs = lappend(targs, transformExpr(pstate, - (Node *) lfirst(lc), - EXPR_KIND_CALL_ARGUMENT)); + if (sql_dialect == SQL_DIALECT_TSQL && nodeTag((Node*)lfirst(lc)) == T_SetToDefault) + { + // For Tsql Default in function call, we set it to UNKNOWN in parser stage + // In analyzer it'll use other types to detect the right func candidate + ((SetToDefault *)lfirst(lc))->typeId = UNKNOWNOID; + targs = lappend(targs, (Node *) lfirst(lc)); + } + else + { + targs = lappend(targs, transformExpr(pstate, + (Node *) lfirst(lc), + EXPR_KIND_CALL_ARGUMENT)); + } } node = ParseFuncOrColumn(pstate, diff --git a/src/backend/parser/parse_coerce.c b/src/backend/parser/parse_coerce.c index 88b37bbf6f1..b69d2bcc83e 100644 --- a/src/backend/parser/parse_coerce.c +++ b/src/backend/parser/parse_coerce.c @@ -176,6 +176,10 @@ coerce_type(ParseState *pstate, Node *node, /* no conversion needed */ return node; } + if (nodeTag((Node*)node) == T_SetToDefault) + { + return node; + } if (targetTypeId == ANYOID || targetTypeId == ANYELEMENTOID || targetTypeId == ANYNONARRAYOID || diff --git a/src/backend/parser/parse_expr.c b/src/backend/parser/parse_expr.c index 89d18b7391d..4ce91b6da39 100644 --- a/src/backend/parser/parse_expr.c +++ b/src/backend/parser/parse_expr.c @@ -1489,8 +1489,18 @@ transformFuncCall(ParseState *pstate, FuncCall *fn) targs = NIL; foreach(args, fn->args) { - targs = lappend(targs, transformExprRecurse(pstate, + if (sql_dialect == SQL_DIALECT_TSQL && nodeTag((Node*)lfirst(args)) == T_SetToDefault) + { + // For Tsql Default in function call, we set it to UNKNOWN in parser stage + // In analyzer it'll use other types to detect the right func candidate + ((SetToDefault *)lfirst(args))->typeId = UNKNOWNOID; + targs = lappend(targs, (Node *) lfirst(args)); + } + else + { + targs = lappend(targs, transformExprRecurse(pstate, (Node *) lfirst(args))); + } } /* @@ -1522,10 +1532,12 @@ transformFuncCall(ParseState *pstate, FuncCall *fn) */ if (!schemaname || (strlen(schemaname) == 3 && strncmp(schemaname, "sys", 3) == 0)) - if (strlen(functionname) == 8 && - strncmp(functionname, "checksum", 8) == 0 && - fn->agg_star == true) - targs = ExpandChecksumStar(pstate, fn, fn->location); + { + if (strlen(functionname) == 8 && + strncmp(functionname, "checksum", 8) == 0 && + fn->agg_star == true) + targs = ExpandChecksumStar(pstate, fn, fn->location); + } /* ... and hand off to ParseFuncOrColumn */ return ParseFuncOrColumn(pstate, diff --git a/src/include/optimizer/clauses.h b/src/include/optimizer/clauses.h index 0e527f4eb3f..bba0db788dc 100644 --- a/src/include/optimizer/clauses.h +++ b/src/include/optimizer/clauses.h @@ -59,4 +59,7 @@ extern Bitmapset *pull_paramids(Expr *expr); typedef void (*insert_pltsql_function_defaults_hook_type)(HeapTuple func_tuple, List *defaults, Node **argarray); extern PGDLLIMPORT insert_pltsql_function_defaults_hook_type insert_pltsql_function_defaults_hook; +typedef List* (*replace_pltsql_function_defaults_hook_type)(HeapTuple func_tuple, List *defaults, List *fargs); +extern PGDLLIMPORT replace_pltsql_function_defaults_hook_type replace_pltsql_function_defaults_hook; + #endif /* CLAUSES_H */