From c15a1cbb0ed9859859653a814b41d804bbcbbaa4 Mon Sep 17 00:00:00 2001 From: Anikait Agrawal <54908236+Anikait143@users.noreply.github.com> Date: Wed, 31 Jan 2024 22:37:21 +0530 Subject: [PATCH] Modified Geospatial TSQL Syntax Support (#2313) (#2324) * Modified Geospatial TSQL Syntax Support (#2313) This commit contains a few changes from https://github.com/babelfish-for-postgresql/babelfish_extensions/pull/2271 Moved Duplicated Code to common helper functions for Exit functions handling Geospatial functional call in both TSQLBuilder and TSQLCommonMutator. To handle local id's we are now quoting them in the geospatial handling itself rather than modifying their start positions, so as to remove the possibility of multiple rewrites within a single context. Signed-off-by: Anikait Agrawal [agraani@amazon.com](mailto:agraani@amazon.com) * Retrigger Tests --------- Signed-off-by: Anikait Agrawal [agraani@amazon.com](mailto:agraani@amazon.com) Co-authored-by: Anikait Agrawal --- contrib/babelfishpg_tsql/antlr/TSqlParser.g4 | 12 +- contrib/babelfishpg_tsql/src/tsqlIface.cpp | 517 +++++++------------ 2 files changed, 195 insertions(+), 334 deletions(-) diff --git a/contrib/babelfishpg_tsql/antlr/TSqlParser.g4 b/contrib/babelfishpg_tsql/antlr/TSqlParser.g4 index 5e7c71f6b0..ebc1fbf719 100644 --- a/contrib/babelfishpg_tsql/antlr/TSqlParser.g4 +++ b/contrib/babelfishpg_tsql/antlr/TSqlParser.g4 @@ -3436,10 +3436,7 @@ constant_LOCAL_ID // https://docs.microsoft.com/en-us/sql/t-sql/language-elements/expressions-transact-sql // Operator precendence: https://docs.microsoft.com/en-us/sql/t-sql/language-elements/operator-precedence-transact-sql expression - : local_id (DOT calls+=method_call)* #local_id_expr - | subquery (DOT calls+=method_call)* #subquery_expr - | LR_BRACKET expression RR_BRACKET (DOT calls+=method_call)* #bracket_expr - | function_call (DOT calls+=method_call)* #func_call_expr + : clr_udt_func_call #clr_udt_expr | expression collation #collate_expr | expression AT_KEYWORD TIME ZONE expression #time_zone_expr | op=(MINUS | PLUS | BIT_NOT) expression #unary_op_expr @@ -3456,6 +3453,13 @@ expression | DOLLAR_ACTION #dollar_action_expr ; +clr_udt_func_call + : local_id (DOT calls+=method_call)* + | subquery (DOT calls+=method_call)* + | LR_BRACKET expression RR_BRACKET (DOT calls+=method_call)* + | function_call (DOT calls+=method_call)* + ; + method_call : xml_methods | hierarchyid_methods diff --git a/contrib/babelfishpg_tsql/src/tsqlIface.cpp b/contrib/babelfishpg_tsql/src/tsqlIface.cpp index 2967472336..2dbf21c32d 100644 --- a/contrib/babelfishpg_tsql/src/tsqlIface.cpp +++ b/contrib/babelfishpg_tsql/src/tsqlIface.cpp @@ -179,6 +179,10 @@ template static void rewrite_geospatial_func_ref_no_arg_query_helper(T template static void rewrite_geospatial_func_ref_args_query_helper(T ctx, TSqlParser::Method_callContext *method, size_t geospatial_start_index); template static void rewrite_function_call_geospatial_func_ref_args(T ctx); template static void rewrite_function_call_geospatial_func_ref_no_arg(T ctx); +static void handleGeospatialFunctionsInFunctionCall(TSqlParser::Function_callContext *ctx); +static void handleClrUdtFuncCall(TSqlParser::Clr_udt_func_callContext *ctx); +static void handleFullColumnNameCtx(TSqlParser::Full_column_nameContext *ctx); +template static void handleLocalIdQuotingFuncRefNoArg(T ctx, size_t geospatial_start_index, int &offset1, std::string &expr, std::vector keysToRemove); static bool does_object_name_need_delimiter(TSqlParser::IdContext *id); static std::string delimit_identifier(TSqlParser::IdContext *id); static bool does_msg_exceeds_params_limit(const std::string& msg); @@ -897,160 +901,20 @@ class tsqlCommonMutator : public TSqlParserBaseListener void exitFunction_call(TSqlParser::Function_callContext *ctx) override { - /* Handles rewrite of geospatial function calls but inside body of CREATE/ALTER View, Procedure, Function */ - if (ctx->spatial_proc_name_server_database_schema()) - { - if (ctx->spatial_proc_name_server_database_schema()->schema) throw PGErrorWrapperException(ERROR, ERRCODE_FEATURE_NOT_SUPPORTED, "Remote procedure/function reference with 4-part object name is not currently supported in Babelfish", getLineAndPos(ctx)); - - /* This if-elseIf clause rewrites the query in case of Geospatial function Call */ - if (ctx->spatial_proc_name_server_database_schema()->geospatial_func_arg() && ctx->function_arg_list()) - rewrite_function_call_geospatial_func_ref_args(ctx); - else if (ctx->spatial_proc_name_server_database_schema()->geospatial_func_no_arg() && !ctx->function_arg_list()) - rewrite_function_call_geospatial_func_ref_no_arg(ctx); - } - } - - /* We are adding handling for Spatial Types in: - * tsqlCommonMutator: for CREATE/ALTER View, Procedure, Function - * tsqlBuilder: for other cases handling - */ - /* Here we are Rewriting Geospatial query query: Func_call DOT Geospatial_Func: - * Func_call DOT Geospatial_col -> ( Func_call ) DOT Geospatial_col - * Func_call DOT Geospatial_func (arg_list) -> Geospatial_func (agr_list, Func_call) - */ - void exitFunc_call_expr(TSqlParser::Func_call_exprContext *ctx) override - { - if(ctx != NULL && !ctx->DOT().empty()) - { - std::vector method_calls = ctx->method_call(); - for (size_t i = 0; i < method_calls.size(); ++i) - { - TSqlParser::Method_callContext *method = method_calls[i]; - /* rewriting the query in case of Geospatial function Call -> function_call.method_call */ - if(method->spatial_methods()) - { - size_t ind; - if (i == 0) ind = ctx->function_call()->stop->getStopIndex(); - else ind = method_calls[i-1]->stop->getStopIndex(); - rewrite_geospatial_query_helper(ctx, method, ind); - } - } - } - } - - /* Handles rewrite of geospatial query but inside body of CREATE/ALTER View, Procedure, Function: - * local_id DOT Geospatial_col -> ( local_id ) DOT Geospatial_col - * local_id DOT Geospatial_func (arg_list) -> Geospatial_func (agr_list, local_id) - */ - void exitLocal_id_expr(TSqlParser::Local_id_exprContext *ctx) override - { - if(ctx != NULL && !ctx->DOT().empty()) - { - std::vector method_calls = ctx->method_call(); - for (size_t i = 0; i < method_calls.size(); ++i) - { - TSqlParser::Method_callContext *method = method_calls[i]; - /* rewriting the query in case of Geospatial function Call -> local_id.method_call */ - if(method->spatial_methods()) - { - size_t ind; - if (i == 0) ind = ctx->local_id()->stop->getStopIndex(); - else ind = method_calls[i-1]->stop->getStopIndex(); - rewrite_geospatial_query_helper(ctx, method, ind); - } - } - } - } - - /* Handles rewrite of geospatial query but inside body of CREATE/ALTER View, Procedure, Function: - * bracket_expr DOT Geospatial_col -> ( bracket_expr ) DOT Geospatial_col - * bracket_expr DOT Geospatial_func (arg_list) -> Geospatial_func (agr_list, bracket_expr) - */ - void exitBracket_expr(TSqlParser::Bracket_exprContext *ctx) override - { - if(ctx != NULL && !ctx->DOT().empty()) - { - std::vector method_calls = ctx->method_call(); - for (size_t i = 0; i < method_calls.size(); ++i) - { - TSqlParser::Method_callContext *method = method_calls[i]; - /* rewriting the query in case of Geospatial function Call -> LR_BRACKET expression RR_BRACKET.method_call */ - if(method->spatial_methods()) - { - size_t ind; - std::string context = ::getFullText(ctx); - size_t spaces = 0; - for (size_t x = ctx->expression()->stop->getStopIndex() + 1 - ctx->start->getStartIndex(); x <= ctx->stop->getStopIndex() - ctx->start->getStartIndex(); x++) - { - if(context[x] == ' ') spaces++; - else if(context[x] == ')') break; - } - if (i == 0) ind = ctx->expression()->stop->getStopIndex() + 1 + spaces; - else ind = method_calls[i-1]->stop->getStopIndex(); - rewrite_geospatial_query_helper(ctx, method, ind); - } - } - } + handleGeospatialFunctionsInFunctionCall(ctx); } - /* Handles rewrite of geospatial query but inside body of CREATE/ALTER View, Procedure, Function: - * Subquery_expr DOT Geospatial_col -> ( Subquery_expr ) DOT Geospatial_col - * Subquery_expr DOT Geospatial_func (arg_list) -> Geospatial_func (agr_list, Subquery_expr) + /* We are adding handling for CLR_UDT Types in: + * tsqlCommonMutator: for cases CREATE/ALTER View, Procedure, Function */ - void exitSubquery_expr(TSqlParser::Subquery_exprContext *ctx) override + void exitClr_udt_func_call(TSqlParser::Clr_udt_func_callContext *ctx) override { - if(ctx != NULL && !ctx->DOT().empty()) - { - std::vector method_calls = ctx->method_call(); - for (size_t i = 0; i < method_calls.size(); ++i) - { - TSqlParser::Method_callContext *method = method_calls[i]; - /* rewriting the query in case of Geospatial function Call -> subquery.method_call */ - if(method->spatial_methods()) - { - size_t ind; - if (i == 0) ind = ctx->subquery()->stop->getStopIndex(); - else ind = method_calls[i-1]->stop->getStopIndex(); - rewrite_geospatial_query_helper(ctx, method, ind); - } - } - } + handleClrUdtFuncCall(ctx); } void exitFull_column_name(TSqlParser::Full_column_nameContext *ctx) override { - GetCtxFunc getSchema = [](TSqlParser::Full_column_nameContext *o) { return o->schema; }; - GetCtxFunc getTablename = [](TSqlParser::Full_column_nameContext *o) { return o->tablename; }; - - /* This if clause rewrites the query in case of Geospatial function Call */ - std::string func_name; - /* Handles rewrite of geospatial query but inside body of CREATE/ALTER View, Procedure, Function: */ - if(ctx->column_name) func_name = stripQuoteFromId(ctx->column_name); - else if (ctx->geospatial_col()) - { - /* Throwing error similar to TSQL as we do not allow 4-Part name for geospatial function call */ - if(ctx->schema) throw PGErrorWrapperException(ERROR, ERRCODE_SYNTAX_ERROR, format_errmsg("The multi-part identifier \"%s\" could not be bound.", ::getFullText(ctx).c_str()), getLineAndPos(ctx)); - - /* Rewriting the query as: table.col.STX -> (table.col).STX */ - std::string rewritten_func_name = "("; - if(ctx->table) rewritten_func_name += stripQuoteFromId(ctx->table) + "."; - rewritten_func_name += stripQuoteFromId(ctx->column) + ")." + ::getFullText(ctx->geospatial_col()); - rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(::getFullText(ctx), rewritten_func_name.c_str()))); - } - - std::string rewritten_name = rewrite_column_name_with_omitted_schema_name(ctx, getSchema, getTablename); - std::string rewritten_schema_name = rewrite_information_schema_to_information_schema_tsql(ctx, getSchema); - if (!rewritten_name.empty()) - rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(::getFullText(ctx), rewritten_name))); - if (pltsql_enable_tsql_information_schema && !rewritten_schema_name.empty()) - rewritten_query_fragment.emplace(std::make_pair(ctx->schema->start->getStartIndex(), std::make_pair(::getFullText(ctx->schema), rewritten_schema_name))); - - if (does_object_name_need_delimiter(ctx->tablename)) - rewritten_query_fragment.emplace(std::make_pair(ctx->tablename->start->getStartIndex(), std::make_pair(::getFullText(ctx->tablename), delimit_identifier(ctx->tablename)))); - - // qualified identifier doesn't need delimiter - if (ctx->DOT().empty() && does_object_name_need_delimiter(ctx->column_name)) - rewritten_query_fragment.emplace(std::make_pair(ctx->column_name->start->getStartIndex(), std::make_pair(::getFullText(ctx->column_name), delimit_identifier(ctx->column_name)))); + handleFullColumnNameCtx(ctx); } /* Object Name */ @@ -2373,128 +2237,15 @@ class tsqlBuilder : public tsqlCommonMutator void exitFull_column_name(TSqlParser::Full_column_nameContext *ctx) override { - GetCtxFunc getSchema = [](TSqlParser::Full_column_nameContext *o) { return o->schema; }; - GetCtxFunc getTablename = [](TSqlParser::Full_column_nameContext *o) { return o->tablename; }; - - /* Handles rewrite of geospatial query except inside body of CREATE/ALTER View, Procedure, Function: */ - std::string func_name; - if(ctx->column_name) func_name = stripQuoteFromId(ctx->column_name); - else if (ctx->geospatial_col()) - { - /* Throwing error similar to TSQL as we do not allow 4-Part name for geospatial function call */ - if(ctx->schema) throw PGErrorWrapperException(ERROR, ERRCODE_SYNTAX_ERROR, format_errmsg("The multi-part identifier \"%s\" could not be bound.", ::getFullText(ctx).c_str()), getLineAndPos(ctx)); - - /* Rewriting the query as: table.col.STX -> (table.col).STX */ - std::string rewritten_func_name = "("; - if(ctx->table) rewritten_func_name += stripQuoteFromId(ctx->table) + "."; - rewritten_func_name += stripQuoteFromId(ctx->column) + ")." + ::getFullText(ctx->geospatial_col()); - rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(::getFullText(ctx), rewritten_func_name.c_str()))); - } - - std::string rewritten_name = rewrite_column_name_with_omitted_schema_name(ctx, getSchema, getTablename); - std::string rewritten_schema_name = rewrite_information_schema_to_information_schema_tsql(ctx, getSchema); - if (!rewritten_name.empty()) - rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(::getFullText(ctx), rewritten_name))); - if (pltsql_enable_tsql_information_schema && !rewritten_schema_name.empty()) - rewritten_query_fragment.emplace(std::make_pair(ctx->schema->start->getStartIndex(), std::make_pair(::getFullText(ctx->schema), rewritten_schema_name))); - - if (does_object_name_need_delimiter(ctx->tablename)) - rewritten_query_fragment.emplace(std::make_pair(ctx->tablename->start->getStartIndex(), std::make_pair(::getFullText(ctx->tablename), delimit_identifier(ctx->tablename)))); - - /* qualified identifier doesn't need delimiter */ - if (ctx->DOT().empty() && does_object_name_need_delimiter(ctx->column_name)) - rewritten_query_fragment.emplace(std::make_pair(ctx->column_name->start->getStartIndex(), std::make_pair(::getFullText(ctx->column_name), delimit_identifier(ctx->column_name)))); - } - - /* Handles rewrite of geospatial query except inside body of CREATE/ALTER View, Procedure, Function: */ - void exitFunc_call_expr(TSqlParser::Func_call_exprContext *ctx) override - { - if(ctx != NULL && !ctx->DOT().empty()) - { - std::vector method_calls = ctx->method_call(); - for (size_t i = 0; i < method_calls.size(); ++i) - { - TSqlParser::Method_callContext *method = method_calls[i]; - /* rewriting the query in case of Geospatial function Call -> function_call.method_call */ - if(method->spatial_methods()) - { - size_t ind; - if (i == 0) ind = ctx->function_call()->stop->getStopIndex(); - else ind = method_calls[i-1]->stop->getStopIndex(); - rewrite_geospatial_query_helper(ctx, method, ind); - } - } - } - } - - /* Handles rewrite of geospatial query except inside body of CREATE/ALTER View, Procedure, Function: */ - void exitLocal_id_expr(TSqlParser::Local_id_exprContext *ctx) override - { - if(ctx != NULL && !ctx->DOT().empty()) - { - std::vector method_calls = ctx->method_call(); - for (size_t i = 0; i < method_calls.size(); ++i) - { - TSqlParser::Method_callContext *method = method_calls[i]; - /* rewriting the query in case of Geospatial function Call -> local_id.method_call */ - if(method->spatial_methods()) - { - size_t ind; - if (i == 0) ind = ctx->local_id()->stop->getStopIndex(); - else ind = method_calls[i-1]->stop->getStopIndex(); - rewrite_geospatial_query_helper(ctx, method, ind); - } - } - } + handleFullColumnNameCtx(ctx); } - /* Handles rewrite of geospatial query except inside body of CREATE/ALTER View, Procedure, Function: */ - void exitBracket_expr(TSqlParser::Bracket_exprContext *ctx) override - { - if(ctx != NULL && !ctx->DOT().empty()) - { - std::vector method_calls = ctx->method_call(); - for (size_t i = 0; i < method_calls.size(); ++i) - { - TSqlParser::Method_callContext *method = method_calls[i]; - /* rewriting the query in case of Geospatial function Call -> LR_BRACKET expression RR_BRACKET.method_call */ - if(method->spatial_methods()) - { - size_t ind; - std::string context = ::getFullText(ctx); - size_t spaces = 0; - for (size_t x = ctx->expression()->stop->getStopIndex() + 1 - ctx->start->getStartIndex(); x <= ctx->stop->getStopIndex() - ctx->start->getStartIndex(); x++) - { - if(context[x] == ' ') spaces++; - else if(context[x] == ')') break; - } - if (i == 0) ind = ctx->expression()->stop->getStopIndex() + 1 + spaces; - else ind = method_calls[i-1]->stop->getStopIndex(); - rewrite_geospatial_query_helper(ctx, method, ind); - } - } - } - } - - /* Handles rewrite of geospatial query except inside body of CREATE/ALTER View, Procedure, Function: */ - void exitSubquery_expr(TSqlParser::Subquery_exprContext *ctx) override + /* We are adding handling for CLR_UDT Types in: + * tsqlBuilder: for cases other than inside CREATE/ALTER View, Procedure, Function + */ + void exitClr_udt_func_call(TSqlParser::Clr_udt_func_callContext *ctx) override { - if(ctx != NULL && !ctx->DOT().empty()) - { - std::vector method_calls = ctx->method_call(); - for (size_t i = 0; i < method_calls.size(); ++i) - { - TSqlParser::Method_callContext *method = method_calls[i]; - /* rewriting the query in case of Geospatial function Call -> subquery.method_call */ - if(method->spatial_methods()) - { - size_t ind; - if (i == 0) ind = ctx->subquery()->stop->getStopIndex(); - else ind = method_calls[i-1]->stop->getStopIndex(); - rewrite_geospatial_query_helper(ctx, method, ind); - } - } - } + handleClrUdtFuncCall(ctx); } ////////////////////////////////////////////////////////////////////////////// @@ -2570,17 +2321,7 @@ class tsqlBuilder : public tsqlCommonMutator rewritten_query_fragment.emplace(std::make_pair(bctx->bif_no_brackets->getStartIndex(), std::make_pair(::getFullText(bctx->SESSION_USER()), "sys.session_user()"))); } - /* Handles rewrite of geospatial query except inside body of CREATE/ALTER View, Procedure, Function: */ - if (ctx->spatial_proc_name_server_database_schema()) - { - if (ctx->spatial_proc_name_server_database_schema()->schema) throw PGErrorWrapperException(ERROR, ERRCODE_FEATURE_NOT_SUPPORTED, "Remote procedure/function reference with 4-part object name is not currently supported in Babelfish", getLineAndPos(ctx)); - - /* This if-elseIf clause rewrites the query in case of Geospatial function Call */ - if (ctx->spatial_proc_name_server_database_schema()->geospatial_func_arg() && ctx->function_arg_list()) - rewrite_function_call_geospatial_func_ref_args(ctx); - else if (ctx->spatial_proc_name_server_database_schema()->geospatial_func_no_arg() && !ctx->function_arg_list()) - rewrite_function_call_geospatial_func_ref_no_arg(ctx); - } + handleGeospatialFunctionsInFunctionCall(ctx); /* analyze scalar function call */ if (ctx->func_proc_name_server_database_schema()) @@ -7970,42 +7711,31 @@ rewrite_geospatial_col_ref_query_helper(T ctx, TSqlParser::Method_callContext *m { std::vector keysToRemove; std::string ctx_str = ::getFullText(ctx); + ctx_str = ctx_str.substr(0, method->stop->getStopIndex() - ctx->start->getStartIndex() + 1); int func_call_len = (int)geospatial_start_index - ctx->start->getStartIndex(); int method_len = (int)method->stop->getStopIndex() - method->start->getStartIndex(); std::string expr = ""; int index = 0; int offset1 = 0; - int offset2 = 0; /* writting the previously rewritten Geospatial context */ for (auto &entry : rewritten_query_fragment) { - if(entry.first >= ctx->start->getStartIndex() && entry.first <= ctx->stop->getStopIndex()) + if(entry.first >= ctx->start->getStartIndex() && entry.first <= method->stop->getStopIndex()) { expr += ctx_str.substr(index, (int)entry.first - ctx->start->getStartIndex() - index) + entry.second.second; index = (int)entry.first - ctx->start->getStartIndex() + entry.second.first.size(); keysToRemove.push_back(entry.first); if(entry.first <= geospatial_start_index) offset1 += (int)entry.second.second.size() - entry.second.first.size(); - else offset2 += (int)entry.second.second.size() - entry.second.first.size(); } } for (const auto &key : keysToRemove) rewritten_query_fragment.erase(key); keysToRemove.clear(); - - /* shifting the local id positions to new positions after rewriting the query since they will be quoted later */ - for (auto &entry : local_id_positions) - { - if(entry.first >= ctx->start->getStartIndex() && entry.first <= geospatial_start_index) - { - keysToRemove.push_back(entry.first); - local_id_positions.emplace(std::make_pair(entry.first + 1, entry.second)); - } - } - for (const auto &key : keysToRemove) local_id_positions.erase(key); - keysToRemove.clear(); expr += ctx_str.substr(index); - std::string rewritten_exp = "(" + expr.substr(0, func_call_len + offset1 + 1) + ")." + expr.substr((int)method->start->getStartIndex() - ctx->start->getStartIndex() + offset1, method_len + offset2 + 1); - if ((int)method->stop->getStopIndex() - ctx->start->getStartIndex() + 1 < ctx_str.size()) rewritten_exp += expr.substr(method->stop->getStopIndex() + offset1 - ctx->start->getStartIndex() + 1); + + handleLocalIdQuotingFuncRefNoArg(ctx, geospatial_start_index, offset1, expr, keysToRemove); + + std::string rewritten_exp = "(" + expr.substr(0, func_call_len + offset1 + 1) + ")." + expr.substr((int)method->start->getStartIndex() - ctx->start->getStartIndex() + offset1, method_len + 1); rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(ctx_str.c_str(), rewritten_exp.c_str()))); } @@ -8019,42 +7749,31 @@ rewrite_geospatial_func_ref_no_arg_query_helper(T ctx, TSqlParser::Method_callCo { std::vector keysToRemove; std::string ctx_str = ::getFullText(ctx); + ctx_str = ctx_str.substr(0, method->stop->getStopIndex() - ctx->start->getStartIndex() + 1); int func_call_len = (int)geospatial_start_index - ctx->start->getStartIndex(); int method_len = (int)method->stop->getStopIndex() - method->start->getStartIndex(); std::string expr = ""; int index = 0; int offset1 = 0; - int offset2 = 0; /* writting the previously rewritten Geospatial context */ for (auto &entry : rewritten_query_fragment) { - if(entry.first >= ctx->start->getStartIndex() && entry.first <= ctx->stop->getStopIndex()) + if(entry.first >= ctx->start->getStartIndex() && entry.first <= method->stop->getStopIndex()) { expr += ctx_str.substr(index, (int)entry.first - ctx->start->getStartIndex() - index) + entry.second.second; index = (int)entry.first - ctx->start->getStartIndex() + entry.second.first.size(); keysToRemove.push_back(entry.first); if(entry.first <= geospatial_start_index) offset1 += (int)entry.second.second.size() - entry.second.first.size(); - else offset2 += (int)entry.second.second.size() - entry.second.first.size(); } } for (const auto &key : keysToRemove) rewritten_query_fragment.erase(key); keysToRemove.clear(); - - /* shifting the local id positions to new positions after rewriting the query since they will be quoted later */ - for (auto &entry : local_id_positions) - { - if(entry.first >= ctx->start->getStartIndex() && entry.first <= geospatial_start_index) - { - keysToRemove.push_back(entry.first); - local_id_positions.emplace(std::make_pair(entry.first + method->spatial_methods()->geospatial_func_no_arg()->stop->getStopIndex() - method->spatial_methods()->geospatial_func_no_arg()->start->getStartIndex() + 1, entry.second)); - } - } - for (const auto &key : keysToRemove) local_id_positions.erase(key); - keysToRemove.clear(); expr += ctx_str.substr(index); - std::string rewritten_exp = expr.substr((int)method->start->getStartIndex() - ctx->start->getStartIndex() + offset1, method_len + offset2) + expr.substr(0, func_call_len + offset1 + 1) + ")"; - if ((int)method->stop->getStopIndex() - ctx->start->getStartIndex() + 1 < ctx_str.size()) rewritten_exp += expr.substr(method->stop->getStopIndex() + offset1 - ctx->start->getStartIndex() + 1); + + handleLocalIdQuotingFuncRefNoArg(ctx, geospatial_start_index, offset1, expr, keysToRemove); + + std::string rewritten_exp = expr.substr((int)method->start->getStartIndex() - ctx->start->getStartIndex() + offset1, method_len) + expr.substr(0, func_call_len + offset1 + 1) + ")"; rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(ctx_str.c_str(), rewritten_exp.c_str()))); } @@ -8068,50 +7787,71 @@ rewrite_geospatial_func_ref_args_query_helper(T ctx, TSqlParser::Method_callCont { std::vector keysToRemove; std::string ctx_str = ::getFullText(ctx); + ctx_str = ctx_str.substr(0, method->stop->getStopIndex() - ctx->start->getStartIndex() + 1); int func_call_len = (int)geospatial_start_index - ctx->start->getStartIndex(); int method_len = (int)method->stop->getStopIndex() - method->start->getStartIndex(); std::string expr = ""; int index = 0; int offset1 = 0; int offset2 = 0; + std::vector> arg_offset_list; + int local_id_end_offset = 0; /* writting the previously rewritten Geospatial context */ for (auto &entry : rewritten_query_fragment) { - if(entry.first >= ctx->start->getStartIndex() && entry.first <= ctx->stop->getStopIndex()) + if(entry.first >= ctx->start->getStartIndex() && entry.first <= method->stop->getStopIndex()) { expr += ctx_str.substr(index, (int)entry.first - ctx->start->getStartIndex() - index) + entry.second.second; index = (int)entry.first - ctx->start->getStartIndex() + entry.second.first.size(); keysToRemove.push_back(entry.first); if(entry.first <= geospatial_start_index) offset1 += (int)entry.second.second.size() - entry.second.first.size(); - else offset2 += (int)entry.second.second.size() - entry.second.first.size(); + else if(entry.first > geospatial_start_index && entry.first <= method->stop->getStopIndex()) + { + offset2 += (int)entry.second.second.size() - entry.second.first.size(); + /* storing these values in a list so that we could correctly calculate the offset for local_id argument rewrites */ + arg_offset_list.push_back(std::make_pair((int)entry.first, (int)entry.second.second.size() - entry.second.first.size())); + } } } for (const auto &key : keysToRemove) rewritten_query_fragment.erase(key); keysToRemove.clear(); + expr += ctx_str.substr(index); - /* shifting the local id positions to new positions after rewriting the query since they will be quoted later */ + /* quoting local_id here so as to remove possibility of multiple rewrites in a single context */ for (auto &entry : local_id_positions) { if(entry.first >= ctx->start->getStartIndex() && entry.first <= geospatial_start_index) { - keysToRemove.push_back(entry.first); - local_id_positions.emplace(std::make_pair(entry.first + method->spatial_methods()->expression_list()->stop->getStopIndex() - method->spatial_methods()->geospatial_func_arg()->start->getStartIndex() + 2, entry.second)); + /* Here we are quoting local_id which are before the function name */ + int local_index = (int)entry.first - ctx->start->getStartIndex() + offset1; + if(expr.substr(local_index, entry.second.size()) == entry.second) + { + keysToRemove.push_back(entry.first); + expr = expr.substr(0, local_index) + "\"" + entry.second + "\"" + expr.substr(local_index + entry.second.size()); + offset1 += 2; + } } else if(entry.first >= method->spatial_methods()->expression_list()->start->getStartIndex() && entry.first <= method->spatial_methods()->expression_list()->stop->getStopIndex()) { - size_t pos = entry.first; - size_t offset = method->start->getStartIndex() - ctx->start->getStartIndex(); - pos -= offset; - keysToRemove.push_back(entry.first); - local_id_positions.emplace(std::make_pair(pos, entry.second)); + /* Here we are quoting local_id which are within the argument list of the function */ + int local_index = (int)entry.first - ctx->start->getStartIndex() + offset1 + local_id_end_offset; + for (size_t i = 0; i < arg_offset_list.size(); i++) + { + if((size_t)arg_offset_list[i].first < entry.first) local_index += arg_offset_list[i].second; + } + if(expr.substr(local_index, entry.second.size()) == entry.second) + { + keysToRemove.push_back(entry.first); + expr = expr.substr(0, local_index) + "\"" + entry.second + "\"" + expr.substr(local_index + entry.second.size()); + offset2 += 2; + local_id_end_offset += 2; + } } } for (const auto &key : keysToRemove) local_id_positions.erase(key); keysToRemove.clear(); - expr += ctx_str.substr(index); std::string rewritten_exp = expr.substr((int)method->start->getStartIndex() - ctx->start->getStartIndex() + offset1, method_len + offset2) + "," + expr.substr(0, func_call_len + offset1 + 1) + ")"; - if ((int)method->stop->getStopIndex() - ctx->start->getStartIndex() + 1 < ctx_str.size()) rewritten_exp += expr.substr(method->stop->getStopIndex() + offset1 - ctx->start->getStartIndex() + 1); rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(ctx_str.c_str(), rewritten_exp.c_str()))); } @@ -8181,6 +7921,8 @@ rewrite_function_call_geospatial_func_ref_args(T ctx) int index = 0; int offset1 = 0; int offset2 = 0; + std::vector> arg_offset_list; + int local_id_end_offset = 0; /* writting the previously rewritten Geospatial context */ for (auto &entry : rewritten_query_fragment) @@ -8191,22 +7933,36 @@ rewrite_function_call_geospatial_func_ref_args(T ctx) index = (int)entry.first - ctx->start->getStartIndex() + entry.second.first.size(); keysToRemove.push_back(entry.first); if(entry.first <= ctx->spatial_proc_name_server_database_schema()->column->stop->getStopIndex()) offset1 += (int)entry.second.second.size() - entry.second.first.size(); - else offset2 += (int)entry.second.second.size() - entry.second.first.size(); + else + { + offset2 += (int)entry.second.second.size() - entry.second.first.size(); + /* storing these values in a list so that we could correctly calculate the offset for local_id argument rewrites */ + arg_offset_list.push_back(std::make_pair((int)entry.first, (int)entry.second.second.size() - entry.second.first.size())); + } } } for (const auto &key : keysToRemove) rewritten_query_fragment.erase(key); keysToRemove.clear(); + expr += func_ctx.substr(index); - /* Shifting the local id positions to new positions after rewriting the query since they will be quoted later */ + /* quoting local_id here so as to remove possibility of multiple rewrites in a single context */ for (auto &entry : local_id_positions) { if(entry.first >= ctx->function_arg_list()->start->getStartIndex() && entry.first <= ctx->function_arg_list()->stop->getStopIndex()) { - size_t pos = entry.first; - size_t offset = ctx->spatial_proc_name_server_database_schema()->geospatial_func_arg()->start->getStartIndex() - ctx->spatial_proc_name_server_database_schema()->start->getStartIndex(); - pos -= offset; - keysToRemove.push_back(entry.first); - local_id_positions.emplace(std::make_pair(pos, entry.second)); + /* Here we are quoting local_id which are within the argument list of the function */ + int local_index = (int)entry.first - ctx->start->getStartIndex() + offset1 + local_id_end_offset; + for (size_t i = 0; i < arg_offset_list.size(); i++) + { + if((size_t)arg_offset_list[i].first < entry.first) local_index += arg_offset_list[i].second; + } + if(expr.substr(local_index, entry.second.size()) == entry.second) + { + keysToRemove.push_back(entry.first); + expr = expr.substr(0, local_index) + "\"" + entry.second + "\"" + expr.substr(local_index + entry.second.size()); + offset2 += 2; + local_id_end_offset += 2; + } } } for (const auto &key : keysToRemove) local_id_positions.erase(key); @@ -8215,7 +7971,6 @@ rewrite_function_call_geospatial_func_ref_args(T ctx) /* * Rewriting the query as: table.col.STDistance(arg) -> STDistance(arg, table.col) */ - expr += func_ctx.substr(index); std::string rewritten_func = expr.substr((int)ctx->spatial_proc_name_server_database_schema()->geospatial_func_arg()->start->getStartIndex() - ctx->start->getStartIndex() + offset1, method_len + offset2) + "," + expr.substr(0, col_len + offset1 + 1) + ")"; rewritten_query_fragment.emplace(std::make_pair(ctx->spatial_proc_name_server_database_schema()->start->getStartIndex(), std::make_pair(::getFullText(ctx), rewritten_func.c_str()))); } @@ -8224,6 +7979,108 @@ rewrite_function_call_geospatial_func_ref_args(T ctx) // End of Spatial Query Helper for Function Calls //////////////////////////////////////////////////////////////////////////////// +template +static void +handleLocalIdQuotingFuncRefNoArg(T ctx, size_t geospatial_start_index, int &offset1, std::string &expr, std::vector keysToRemove) +{ + /* quoting local_id here so as to remove possibility of multiple rewrites in a single context */ + for (auto &entry : local_id_positions) + { + if(entry.first >= ctx->start->getStartIndex() && entry.first <= geospatial_start_index) + { + /* Here we are quoting local_id which are before the function name */ + int local_index = (int)entry.first - ctx->start->getStartIndex() + offset1; + if(expr.substr(local_index, entry.second.size()) == entry.second) + { + keysToRemove.push_back(entry.first); + expr = expr.substr(0, local_index) + "\"" + entry.second + "\"" + expr.substr(local_index + entry.second.size()); + offset1 += 2; + } + } + } + for (const auto &key : keysToRemove) local_id_positions.erase(key); + keysToRemove.clear(); +} + +static void +handleGeospatialFunctionsInFunctionCall(TSqlParser::Function_callContext *ctx) +{ + /* Handles rewrite of geospatial function calls */ + if (ctx->spatial_proc_name_server_database_schema()) + { + if (ctx->spatial_proc_name_server_database_schema()->schema) throw PGErrorWrapperException(ERROR, ERRCODE_FEATURE_NOT_SUPPORTED, "Remote procedure/function reference with 4-part object name is not currently supported in Babelfish", getLineAndPos(ctx)); + + /* This if-elseIf clause rewrites the query in case of geospatial function calls */ + if (ctx->spatial_proc_name_server_database_schema()->geospatial_func_arg() && ctx->function_arg_list()) + rewrite_function_call_geospatial_func_ref_args(ctx); + else if (ctx->spatial_proc_name_server_database_schema()->geospatial_func_no_arg() && !ctx->function_arg_list()) + rewrite_function_call_geospatial_func_ref_no_arg(ctx); + } +} + +static void +handleClrUdtFuncCall(TSqlParser::Clr_udt_func_callContext *ctx) +{ + /* checking if CLR_UDT types */ + if(ctx != NULL && !ctx->DOT().empty()) + { + std::vector method_calls = ctx->method_call(); + for (size_t i = 0; i < method_calls.size(); ++i) + { + TSqlParser::Method_callContext *method = method_calls[i]; + /* rewriting the query in case of geospatial function calls */ + if(method->spatial_methods()) + { + size_t ind = -1; + if (i == 0) + { + if(ctx->local_id()) ind = ctx->local_id()->stop->getStopIndex(); + else if(ctx->subquery()) ind = ctx->subquery()->stop->getStopIndex(); + else if(ctx->function_call()) ind = ctx->function_call()->stop->getStopIndex(); + else if(ctx->RR_BRACKET()) ind = ctx->RR_BRACKET()->getSymbol()->getStartIndex(); + } + else ind = method_calls[i-1]->stop->getStopIndex(); + rewrite_geospatial_query_helper(ctx, method, ind); + } + } + } +} + +static void +handleFullColumnNameCtx(TSqlParser::Full_column_nameContext *ctx) +{ + GetCtxFunc getSchema = [](TSqlParser::Full_column_nameContext *o) { return o->schema; }; + GetCtxFunc getTablename = [](TSqlParser::Full_column_nameContext *o) { return o->tablename; }; + + std::string func_name; + /* Handles rewrite of geospatial query */ + if(ctx->column_name) func_name = stripQuoteFromId(ctx->column_name); + else if (ctx->geospatial_col()) + { + /* Throwing error similar to TSQL as we do not allow 4-Part name for geospatial function calls */ + if(ctx->schema) throw PGErrorWrapperException(ERROR, ERRCODE_SYNTAX_ERROR, format_errmsg("The multi-part identifier \"%s\" could not be bound.", ::getFullText(ctx).c_str()), getLineAndPos(ctx)); + + /* Rewriting the query as: table.col.STX -> (table.col).STX */ + std::string ctx_str = ::getFullText(ctx); + std::string rewritten_func_name = "(" + ctx_str.substr(0, ctx->column->stop->getStopIndex() - ctx->start->getStartIndex() + 1) + ")." + ctx_str.substr(ctx->geospatial_col()->start->getStartIndex() - ctx->start->getStartIndex()); + rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(ctx_str, rewritten_func_name.c_str()))); + } + + std::string rewritten_name = rewrite_column_name_with_omitted_schema_name(ctx, getSchema, getTablename); + std::string rewritten_schema_name = rewrite_information_schema_to_information_schema_tsql(ctx, getSchema); + if (!rewritten_name.empty()) + rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(::getFullText(ctx), rewritten_name))); + if (pltsql_enable_tsql_information_schema && !rewritten_schema_name.empty()) + rewritten_query_fragment.emplace(std::make_pair(ctx->schema->start->getStartIndex(), std::make_pair(::getFullText(ctx->schema), rewritten_schema_name))); + + if (does_object_name_need_delimiter(ctx->tablename)) + rewritten_query_fragment.emplace(std::make_pair(ctx->tablename->start->getStartIndex(), std::make_pair(::getFullText(ctx->tablename), delimit_identifier(ctx->tablename)))); + + // qualified identifier doesn't need delimiter + if (ctx->DOT().empty() && does_object_name_need_delimiter(ctx->column_name)) + rewritten_query_fragment.emplace(std::make_pair(ctx->column_name->start->getStartIndex(), std::make_pair(::getFullText(ctx->column_name), delimit_identifier(ctx->column_name)))); +} + static bool does_object_name_need_delimiter(TSqlParser::IdContext *id) {