diff --git a/contrib/babelfishpg_tsql/antlr/TSqlParser.g4 b/contrib/babelfishpg_tsql/antlr/TSqlParser.g4 index 5e7c71f6b0..ebc1fbf719 100644 --- a/contrib/babelfishpg_tsql/antlr/TSqlParser.g4 +++ b/contrib/babelfishpg_tsql/antlr/TSqlParser.g4 @@ -3436,10 +3436,7 @@ constant_LOCAL_ID // https://docs.microsoft.com/en-us/sql/t-sql/language-elements/expressions-transact-sql // Operator precendence: https://docs.microsoft.com/en-us/sql/t-sql/language-elements/operator-precedence-transact-sql expression - : local_id (DOT calls+=method_call)* #local_id_expr - | subquery (DOT calls+=method_call)* #subquery_expr - | LR_BRACKET expression RR_BRACKET (DOT calls+=method_call)* #bracket_expr - | function_call (DOT calls+=method_call)* #func_call_expr + : clr_udt_func_call #clr_udt_expr | expression collation #collate_expr | expression AT_KEYWORD TIME ZONE expression #time_zone_expr | op=(MINUS | PLUS | BIT_NOT) expression #unary_op_expr @@ -3456,6 +3453,13 @@ expression | DOLLAR_ACTION #dollar_action_expr ; +clr_udt_func_call + : local_id (DOT calls+=method_call)* + | subquery (DOT calls+=method_call)* + | LR_BRACKET expression RR_BRACKET (DOT calls+=method_call)* + | function_call (DOT calls+=method_call)* + ; + method_call : xml_methods | hierarchyid_methods diff --git a/contrib/babelfishpg_tsql/src/tsqlIface.cpp b/contrib/babelfishpg_tsql/src/tsqlIface.cpp index 2967472336..2dbf21c32d 100644 --- a/contrib/babelfishpg_tsql/src/tsqlIface.cpp +++ b/contrib/babelfishpg_tsql/src/tsqlIface.cpp @@ -179,6 +179,10 @@ template static void rewrite_geospatial_func_ref_no_arg_query_helper(T template static void rewrite_geospatial_func_ref_args_query_helper(T ctx, TSqlParser::Method_callContext *method, size_t geospatial_start_index); template static void rewrite_function_call_geospatial_func_ref_args(T ctx); template static void rewrite_function_call_geospatial_func_ref_no_arg(T ctx); +static void handleGeospatialFunctionsInFunctionCall(TSqlParser::Function_callContext *ctx); +static void handleClrUdtFuncCall(TSqlParser::Clr_udt_func_callContext *ctx); +static void handleFullColumnNameCtx(TSqlParser::Full_column_nameContext *ctx); +template static void handleLocalIdQuotingFuncRefNoArg(T ctx, size_t geospatial_start_index, int &offset1, std::string &expr, std::vector keysToRemove); static bool does_object_name_need_delimiter(TSqlParser::IdContext *id); static std::string delimit_identifier(TSqlParser::IdContext *id); static bool does_msg_exceeds_params_limit(const std::string& msg); @@ -897,160 +901,20 @@ class tsqlCommonMutator : public TSqlParserBaseListener void exitFunction_call(TSqlParser::Function_callContext *ctx) override { - /* Handles rewrite of geospatial function calls but inside body of CREATE/ALTER View, Procedure, Function */ - if (ctx->spatial_proc_name_server_database_schema()) - { - if (ctx->spatial_proc_name_server_database_schema()->schema) throw PGErrorWrapperException(ERROR, ERRCODE_FEATURE_NOT_SUPPORTED, "Remote procedure/function reference with 4-part object name is not currently supported in Babelfish", getLineAndPos(ctx)); - - /* This if-elseIf clause rewrites the query in case of Geospatial function Call */ - if (ctx->spatial_proc_name_server_database_schema()->geospatial_func_arg() && ctx->function_arg_list()) - rewrite_function_call_geospatial_func_ref_args(ctx); - else if (ctx->spatial_proc_name_server_database_schema()->geospatial_func_no_arg() && !ctx->function_arg_list()) - rewrite_function_call_geospatial_func_ref_no_arg(ctx); - } - } - - /* We are adding handling for Spatial Types in: - * tsqlCommonMutator: for CREATE/ALTER View, Procedure, Function - * tsqlBuilder: for other cases handling - */ - /* Here we are Rewriting Geospatial query query: Func_call DOT Geospatial_Func: - * Func_call DOT Geospatial_col -> ( Func_call ) DOT Geospatial_col - * Func_call DOT Geospatial_func (arg_list) -> Geospatial_func (agr_list, Func_call) - */ - void exitFunc_call_expr(TSqlParser::Func_call_exprContext *ctx) override - { - if(ctx != NULL && !ctx->DOT().empty()) - { - std::vector method_calls = ctx->method_call(); - for (size_t i = 0; i < method_calls.size(); ++i) - { - TSqlParser::Method_callContext *method = method_calls[i]; - /* rewriting the query in case of Geospatial function Call -> function_call.method_call */ - if(method->spatial_methods()) - { - size_t ind; - if (i == 0) ind = ctx->function_call()->stop->getStopIndex(); - else ind = method_calls[i-1]->stop->getStopIndex(); - rewrite_geospatial_query_helper(ctx, method, ind); - } - } - } - } - - /* Handles rewrite of geospatial query but inside body of CREATE/ALTER View, Procedure, Function: - * local_id DOT Geospatial_col -> ( local_id ) DOT Geospatial_col - * local_id DOT Geospatial_func (arg_list) -> Geospatial_func (agr_list, local_id) - */ - void exitLocal_id_expr(TSqlParser::Local_id_exprContext *ctx) override - { - if(ctx != NULL && !ctx->DOT().empty()) - { - std::vector method_calls = ctx->method_call(); - for (size_t i = 0; i < method_calls.size(); ++i) - { - TSqlParser::Method_callContext *method = method_calls[i]; - /* rewriting the query in case of Geospatial function Call -> local_id.method_call */ - if(method->spatial_methods()) - { - size_t ind; - if (i == 0) ind = ctx->local_id()->stop->getStopIndex(); - else ind = method_calls[i-1]->stop->getStopIndex(); - rewrite_geospatial_query_helper(ctx, method, ind); - } - } - } - } - - /* Handles rewrite of geospatial query but inside body of CREATE/ALTER View, Procedure, Function: - * bracket_expr DOT Geospatial_col -> ( bracket_expr ) DOT Geospatial_col - * bracket_expr DOT Geospatial_func (arg_list) -> Geospatial_func (agr_list, bracket_expr) - */ - void exitBracket_expr(TSqlParser::Bracket_exprContext *ctx) override - { - if(ctx != NULL && !ctx->DOT().empty()) - { - std::vector method_calls = ctx->method_call(); - for (size_t i = 0; i < method_calls.size(); ++i) - { - TSqlParser::Method_callContext *method = method_calls[i]; - /* rewriting the query in case of Geospatial function Call -> LR_BRACKET expression RR_BRACKET.method_call */ - if(method->spatial_methods()) - { - size_t ind; - std::string context = ::getFullText(ctx); - size_t spaces = 0; - for (size_t x = ctx->expression()->stop->getStopIndex() + 1 - ctx->start->getStartIndex(); x <= ctx->stop->getStopIndex() - ctx->start->getStartIndex(); x++) - { - if(context[x] == ' ') spaces++; - else if(context[x] == ')') break; - } - if (i == 0) ind = ctx->expression()->stop->getStopIndex() + 1 + spaces; - else ind = method_calls[i-1]->stop->getStopIndex(); - rewrite_geospatial_query_helper(ctx, method, ind); - } - } - } + handleGeospatialFunctionsInFunctionCall(ctx); } - /* Handles rewrite of geospatial query but inside body of CREATE/ALTER View, Procedure, Function: - * Subquery_expr DOT Geospatial_col -> ( Subquery_expr ) DOT Geospatial_col - * Subquery_expr DOT Geospatial_func (arg_list) -> Geospatial_func (agr_list, Subquery_expr) + /* We are adding handling for CLR_UDT Types in: + * tsqlCommonMutator: for cases CREATE/ALTER View, Procedure, Function */ - void exitSubquery_expr(TSqlParser::Subquery_exprContext *ctx) override + void exitClr_udt_func_call(TSqlParser::Clr_udt_func_callContext *ctx) override { - if(ctx != NULL && !ctx->DOT().empty()) - { - std::vector method_calls = ctx->method_call(); - for (size_t i = 0; i < method_calls.size(); ++i) - { - TSqlParser::Method_callContext *method = method_calls[i]; - /* rewriting the query in case of Geospatial function Call -> subquery.method_call */ - if(method->spatial_methods()) - { - size_t ind; - if (i == 0) ind = ctx->subquery()->stop->getStopIndex(); - else ind = method_calls[i-1]->stop->getStopIndex(); - rewrite_geospatial_query_helper(ctx, method, ind); - } - } - } + handleClrUdtFuncCall(ctx); } void exitFull_column_name(TSqlParser::Full_column_nameContext *ctx) override { - GetCtxFunc getSchema = [](TSqlParser::Full_column_nameContext *o) { return o->schema; }; - GetCtxFunc getTablename = [](TSqlParser::Full_column_nameContext *o) { return o->tablename; }; - - /* This if clause rewrites the query in case of Geospatial function Call */ - std::string func_name; - /* Handles rewrite of geospatial query but inside body of CREATE/ALTER View, Procedure, Function: */ - if(ctx->column_name) func_name = stripQuoteFromId(ctx->column_name); - else if (ctx->geospatial_col()) - { - /* Throwing error similar to TSQL as we do not allow 4-Part name for geospatial function call */ - if(ctx->schema) throw PGErrorWrapperException(ERROR, ERRCODE_SYNTAX_ERROR, format_errmsg("The multi-part identifier \"%s\" could not be bound.", ::getFullText(ctx).c_str()), getLineAndPos(ctx)); - - /* Rewriting the query as: table.col.STX -> (table.col).STX */ - std::string rewritten_func_name = "("; - if(ctx->table) rewritten_func_name += stripQuoteFromId(ctx->table) + "."; - rewritten_func_name += stripQuoteFromId(ctx->column) + ")." + ::getFullText(ctx->geospatial_col()); - rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(::getFullText(ctx), rewritten_func_name.c_str()))); - } - - std::string rewritten_name = rewrite_column_name_with_omitted_schema_name(ctx, getSchema, getTablename); - std::string rewritten_schema_name = rewrite_information_schema_to_information_schema_tsql(ctx, getSchema); - if (!rewritten_name.empty()) - rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(::getFullText(ctx), rewritten_name))); - if (pltsql_enable_tsql_information_schema && !rewritten_schema_name.empty()) - rewritten_query_fragment.emplace(std::make_pair(ctx->schema->start->getStartIndex(), std::make_pair(::getFullText(ctx->schema), rewritten_schema_name))); - - if (does_object_name_need_delimiter(ctx->tablename)) - rewritten_query_fragment.emplace(std::make_pair(ctx->tablename->start->getStartIndex(), std::make_pair(::getFullText(ctx->tablename), delimit_identifier(ctx->tablename)))); - - // qualified identifier doesn't need delimiter - if (ctx->DOT().empty() && does_object_name_need_delimiter(ctx->column_name)) - rewritten_query_fragment.emplace(std::make_pair(ctx->column_name->start->getStartIndex(), std::make_pair(::getFullText(ctx->column_name), delimit_identifier(ctx->column_name)))); + handleFullColumnNameCtx(ctx); } /* Object Name */ @@ -2373,128 +2237,15 @@ class tsqlBuilder : public tsqlCommonMutator void exitFull_column_name(TSqlParser::Full_column_nameContext *ctx) override { - GetCtxFunc getSchema = [](TSqlParser::Full_column_nameContext *o) { return o->schema; }; - GetCtxFunc getTablename = [](TSqlParser::Full_column_nameContext *o) { return o->tablename; }; - - /* Handles rewrite of geospatial query except inside body of CREATE/ALTER View, Procedure, Function: */ - std::string func_name; - if(ctx->column_name) func_name = stripQuoteFromId(ctx->column_name); - else if (ctx->geospatial_col()) - { - /* Throwing error similar to TSQL as we do not allow 4-Part name for geospatial function call */ - if(ctx->schema) throw PGErrorWrapperException(ERROR, ERRCODE_SYNTAX_ERROR, format_errmsg("The multi-part identifier \"%s\" could not be bound.", ::getFullText(ctx).c_str()), getLineAndPos(ctx)); - - /* Rewriting the query as: table.col.STX -> (table.col).STX */ - std::string rewritten_func_name = "("; - if(ctx->table) rewritten_func_name += stripQuoteFromId(ctx->table) + "."; - rewritten_func_name += stripQuoteFromId(ctx->column) + ")." + ::getFullText(ctx->geospatial_col()); - rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(::getFullText(ctx), rewritten_func_name.c_str()))); - } - - std::string rewritten_name = rewrite_column_name_with_omitted_schema_name(ctx, getSchema, getTablename); - std::string rewritten_schema_name = rewrite_information_schema_to_information_schema_tsql(ctx, getSchema); - if (!rewritten_name.empty()) - rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(::getFullText(ctx), rewritten_name))); - if (pltsql_enable_tsql_information_schema && !rewritten_schema_name.empty()) - rewritten_query_fragment.emplace(std::make_pair(ctx->schema->start->getStartIndex(), std::make_pair(::getFullText(ctx->schema), rewritten_schema_name))); - - if (does_object_name_need_delimiter(ctx->tablename)) - rewritten_query_fragment.emplace(std::make_pair(ctx->tablename->start->getStartIndex(), std::make_pair(::getFullText(ctx->tablename), delimit_identifier(ctx->tablename)))); - - /* qualified identifier doesn't need delimiter */ - if (ctx->DOT().empty() && does_object_name_need_delimiter(ctx->column_name)) - rewritten_query_fragment.emplace(std::make_pair(ctx->column_name->start->getStartIndex(), std::make_pair(::getFullText(ctx->column_name), delimit_identifier(ctx->column_name)))); - } - - /* Handles rewrite of geospatial query except inside body of CREATE/ALTER View, Procedure, Function: */ - void exitFunc_call_expr(TSqlParser::Func_call_exprContext *ctx) override - { - if(ctx != NULL && !ctx->DOT().empty()) - { - std::vector method_calls = ctx->method_call(); - for (size_t i = 0; i < method_calls.size(); ++i) - { - TSqlParser::Method_callContext *method = method_calls[i]; - /* rewriting the query in case of Geospatial function Call -> function_call.method_call */ - if(method->spatial_methods()) - { - size_t ind; - if (i == 0) ind = ctx->function_call()->stop->getStopIndex(); - else ind = method_calls[i-1]->stop->getStopIndex(); - rewrite_geospatial_query_helper(ctx, method, ind); - } - } - } - } - - /* Handles rewrite of geospatial query except inside body of CREATE/ALTER View, Procedure, Function: */ - void exitLocal_id_expr(TSqlParser::Local_id_exprContext *ctx) override - { - if(ctx != NULL && !ctx->DOT().empty()) - { - std::vector method_calls = ctx->method_call(); - for (size_t i = 0; i < method_calls.size(); ++i) - { - TSqlParser::Method_callContext *method = method_calls[i]; - /* rewriting the query in case of Geospatial function Call -> local_id.method_call */ - if(method->spatial_methods()) - { - size_t ind; - if (i == 0) ind = ctx->local_id()->stop->getStopIndex(); - else ind = method_calls[i-1]->stop->getStopIndex(); - rewrite_geospatial_query_helper(ctx, method, ind); - } - } - } + handleFullColumnNameCtx(ctx); } - /* Handles rewrite of geospatial query except inside body of CREATE/ALTER View, Procedure, Function: */ - void exitBracket_expr(TSqlParser::Bracket_exprContext *ctx) override - { - if(ctx != NULL && !ctx->DOT().empty()) - { - std::vector method_calls = ctx->method_call(); - for (size_t i = 0; i < method_calls.size(); ++i) - { - TSqlParser::Method_callContext *method = method_calls[i]; - /* rewriting the query in case of Geospatial function Call -> LR_BRACKET expression RR_BRACKET.method_call */ - if(method->spatial_methods()) - { - size_t ind; - std::string context = ::getFullText(ctx); - size_t spaces = 0; - for (size_t x = ctx->expression()->stop->getStopIndex() + 1 - ctx->start->getStartIndex(); x <= ctx->stop->getStopIndex() - ctx->start->getStartIndex(); x++) - { - if(context[x] == ' ') spaces++; - else if(context[x] == ')') break; - } - if (i == 0) ind = ctx->expression()->stop->getStopIndex() + 1 + spaces; - else ind = method_calls[i-1]->stop->getStopIndex(); - rewrite_geospatial_query_helper(ctx, method, ind); - } - } - } - } - - /* Handles rewrite of geospatial query except inside body of CREATE/ALTER View, Procedure, Function: */ - void exitSubquery_expr(TSqlParser::Subquery_exprContext *ctx) override + /* We are adding handling for CLR_UDT Types in: + * tsqlBuilder: for cases other than inside CREATE/ALTER View, Procedure, Function + */ + void exitClr_udt_func_call(TSqlParser::Clr_udt_func_callContext *ctx) override { - if(ctx != NULL && !ctx->DOT().empty()) - { - std::vector method_calls = ctx->method_call(); - for (size_t i = 0; i < method_calls.size(); ++i) - { - TSqlParser::Method_callContext *method = method_calls[i]; - /* rewriting the query in case of Geospatial function Call -> subquery.method_call */ - if(method->spatial_methods()) - { - size_t ind; - if (i == 0) ind = ctx->subquery()->stop->getStopIndex(); - else ind = method_calls[i-1]->stop->getStopIndex(); - rewrite_geospatial_query_helper(ctx, method, ind); - } - } - } + handleClrUdtFuncCall(ctx); } ////////////////////////////////////////////////////////////////////////////// @@ -2570,17 +2321,7 @@ class tsqlBuilder : public tsqlCommonMutator rewritten_query_fragment.emplace(std::make_pair(bctx->bif_no_brackets->getStartIndex(), std::make_pair(::getFullText(bctx->SESSION_USER()), "sys.session_user()"))); } - /* Handles rewrite of geospatial query except inside body of CREATE/ALTER View, Procedure, Function: */ - if (ctx->spatial_proc_name_server_database_schema()) - { - if (ctx->spatial_proc_name_server_database_schema()->schema) throw PGErrorWrapperException(ERROR, ERRCODE_FEATURE_NOT_SUPPORTED, "Remote procedure/function reference with 4-part object name is not currently supported in Babelfish", getLineAndPos(ctx)); - - /* This if-elseIf clause rewrites the query in case of Geospatial function Call */ - if (ctx->spatial_proc_name_server_database_schema()->geospatial_func_arg() && ctx->function_arg_list()) - rewrite_function_call_geospatial_func_ref_args(ctx); - else if (ctx->spatial_proc_name_server_database_schema()->geospatial_func_no_arg() && !ctx->function_arg_list()) - rewrite_function_call_geospatial_func_ref_no_arg(ctx); - } + handleGeospatialFunctionsInFunctionCall(ctx); /* analyze scalar function call */ if (ctx->func_proc_name_server_database_schema()) @@ -7970,42 +7711,31 @@ rewrite_geospatial_col_ref_query_helper(T ctx, TSqlParser::Method_callContext *m { std::vector keysToRemove; std::string ctx_str = ::getFullText(ctx); + ctx_str = ctx_str.substr(0, method->stop->getStopIndex() - ctx->start->getStartIndex() + 1); int func_call_len = (int)geospatial_start_index - ctx->start->getStartIndex(); int method_len = (int)method->stop->getStopIndex() - method->start->getStartIndex(); std::string expr = ""; int index = 0; int offset1 = 0; - int offset2 = 0; /* writting the previously rewritten Geospatial context */ for (auto &entry : rewritten_query_fragment) { - if(entry.first >= ctx->start->getStartIndex() && entry.first <= ctx->stop->getStopIndex()) + if(entry.first >= ctx->start->getStartIndex() && entry.first <= method->stop->getStopIndex()) { expr += ctx_str.substr(index, (int)entry.first - ctx->start->getStartIndex() - index) + entry.second.second; index = (int)entry.first - ctx->start->getStartIndex() + entry.second.first.size(); keysToRemove.push_back(entry.first); if(entry.first <= geospatial_start_index) offset1 += (int)entry.second.second.size() - entry.second.first.size(); - else offset2 += (int)entry.second.second.size() - entry.second.first.size(); } } for (const auto &key : keysToRemove) rewritten_query_fragment.erase(key); keysToRemove.clear(); - - /* shifting the local id positions to new positions after rewriting the query since they will be quoted later */ - for (auto &entry : local_id_positions) - { - if(entry.first >= ctx->start->getStartIndex() && entry.first <= geospatial_start_index) - { - keysToRemove.push_back(entry.first); - local_id_positions.emplace(std::make_pair(entry.first + 1, entry.second)); - } - } - for (const auto &key : keysToRemove) local_id_positions.erase(key); - keysToRemove.clear(); expr += ctx_str.substr(index); - std::string rewritten_exp = "(" + expr.substr(0, func_call_len + offset1 + 1) + ")." + expr.substr((int)method->start->getStartIndex() - ctx->start->getStartIndex() + offset1, method_len + offset2 + 1); - if ((int)method->stop->getStopIndex() - ctx->start->getStartIndex() + 1 < ctx_str.size()) rewritten_exp += expr.substr(method->stop->getStopIndex() + offset1 - ctx->start->getStartIndex() + 1); + + handleLocalIdQuotingFuncRefNoArg(ctx, geospatial_start_index, offset1, expr, keysToRemove); + + std::string rewritten_exp = "(" + expr.substr(0, func_call_len + offset1 + 1) + ")." + expr.substr((int)method->start->getStartIndex() - ctx->start->getStartIndex() + offset1, method_len + 1); rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(ctx_str.c_str(), rewritten_exp.c_str()))); } @@ -8019,42 +7749,31 @@ rewrite_geospatial_func_ref_no_arg_query_helper(T ctx, TSqlParser::Method_callCo { std::vector keysToRemove; std::string ctx_str = ::getFullText(ctx); + ctx_str = ctx_str.substr(0, method->stop->getStopIndex() - ctx->start->getStartIndex() + 1); int func_call_len = (int)geospatial_start_index - ctx->start->getStartIndex(); int method_len = (int)method->stop->getStopIndex() - method->start->getStartIndex(); std::string expr = ""; int index = 0; int offset1 = 0; - int offset2 = 0; /* writting the previously rewritten Geospatial context */ for (auto &entry : rewritten_query_fragment) { - if(entry.first >= ctx->start->getStartIndex() && entry.first <= ctx->stop->getStopIndex()) + if(entry.first >= ctx->start->getStartIndex() && entry.first <= method->stop->getStopIndex()) { expr += ctx_str.substr(index, (int)entry.first - ctx->start->getStartIndex() - index) + entry.second.second; index = (int)entry.first - ctx->start->getStartIndex() + entry.second.first.size(); keysToRemove.push_back(entry.first); if(entry.first <= geospatial_start_index) offset1 += (int)entry.second.second.size() - entry.second.first.size(); - else offset2 += (int)entry.second.second.size() - entry.second.first.size(); } } for (const auto &key : keysToRemove) rewritten_query_fragment.erase(key); keysToRemove.clear(); - - /* shifting the local id positions to new positions after rewriting the query since they will be quoted later */ - for (auto &entry : local_id_positions) - { - if(entry.first >= ctx->start->getStartIndex() && entry.first <= geospatial_start_index) - { - keysToRemove.push_back(entry.first); - local_id_positions.emplace(std::make_pair(entry.first + method->spatial_methods()->geospatial_func_no_arg()->stop->getStopIndex() - method->spatial_methods()->geospatial_func_no_arg()->start->getStartIndex() + 1, entry.second)); - } - } - for (const auto &key : keysToRemove) local_id_positions.erase(key); - keysToRemove.clear(); expr += ctx_str.substr(index); - std::string rewritten_exp = expr.substr((int)method->start->getStartIndex() - ctx->start->getStartIndex() + offset1, method_len + offset2) + expr.substr(0, func_call_len + offset1 + 1) + ")"; - if ((int)method->stop->getStopIndex() - ctx->start->getStartIndex() + 1 < ctx_str.size()) rewritten_exp += expr.substr(method->stop->getStopIndex() + offset1 - ctx->start->getStartIndex() + 1); + + handleLocalIdQuotingFuncRefNoArg(ctx, geospatial_start_index, offset1, expr, keysToRemove); + + std::string rewritten_exp = expr.substr((int)method->start->getStartIndex() - ctx->start->getStartIndex() + offset1, method_len) + expr.substr(0, func_call_len + offset1 + 1) + ")"; rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(ctx_str.c_str(), rewritten_exp.c_str()))); } @@ -8068,50 +7787,71 @@ rewrite_geospatial_func_ref_args_query_helper(T ctx, TSqlParser::Method_callCont { std::vector keysToRemove; std::string ctx_str = ::getFullText(ctx); + ctx_str = ctx_str.substr(0, method->stop->getStopIndex() - ctx->start->getStartIndex() + 1); int func_call_len = (int)geospatial_start_index - ctx->start->getStartIndex(); int method_len = (int)method->stop->getStopIndex() - method->start->getStartIndex(); std::string expr = ""; int index = 0; int offset1 = 0; int offset2 = 0; + std::vector> arg_offset_list; + int local_id_end_offset = 0; /* writting the previously rewritten Geospatial context */ for (auto &entry : rewritten_query_fragment) { - if(entry.first >= ctx->start->getStartIndex() && entry.first <= ctx->stop->getStopIndex()) + if(entry.first >= ctx->start->getStartIndex() && entry.first <= method->stop->getStopIndex()) { expr += ctx_str.substr(index, (int)entry.first - ctx->start->getStartIndex() - index) + entry.second.second; index = (int)entry.first - ctx->start->getStartIndex() + entry.second.first.size(); keysToRemove.push_back(entry.first); if(entry.first <= geospatial_start_index) offset1 += (int)entry.second.second.size() - entry.second.first.size(); - else offset2 += (int)entry.second.second.size() - entry.second.first.size(); + else if(entry.first > geospatial_start_index && entry.first <= method->stop->getStopIndex()) + { + offset2 += (int)entry.second.second.size() - entry.second.first.size(); + /* storing these values in a list so that we could correctly calculate the offset for local_id argument rewrites */ + arg_offset_list.push_back(std::make_pair((int)entry.first, (int)entry.second.second.size() - entry.second.first.size())); + } } } for (const auto &key : keysToRemove) rewritten_query_fragment.erase(key); keysToRemove.clear(); + expr += ctx_str.substr(index); - /* shifting the local id positions to new positions after rewriting the query since they will be quoted later */ + /* quoting local_id here so as to remove possibility of multiple rewrites in a single context */ for (auto &entry : local_id_positions) { if(entry.first >= ctx->start->getStartIndex() && entry.first <= geospatial_start_index) { - keysToRemove.push_back(entry.first); - local_id_positions.emplace(std::make_pair(entry.first + method->spatial_methods()->expression_list()->stop->getStopIndex() - method->spatial_methods()->geospatial_func_arg()->start->getStartIndex() + 2, entry.second)); + /* Here we are quoting local_id which are before the function name */ + int local_index = (int)entry.first - ctx->start->getStartIndex() + offset1; + if(expr.substr(local_index, entry.second.size()) == entry.second) + { + keysToRemove.push_back(entry.first); + expr = expr.substr(0, local_index) + "\"" + entry.second + "\"" + expr.substr(local_index + entry.second.size()); + offset1 += 2; + } } else if(entry.first >= method->spatial_methods()->expression_list()->start->getStartIndex() && entry.first <= method->spatial_methods()->expression_list()->stop->getStopIndex()) { - size_t pos = entry.first; - size_t offset = method->start->getStartIndex() - ctx->start->getStartIndex(); - pos -= offset; - keysToRemove.push_back(entry.first); - local_id_positions.emplace(std::make_pair(pos, entry.second)); + /* Here we are quoting local_id which are within the argument list of the function */ + int local_index = (int)entry.first - ctx->start->getStartIndex() + offset1 + local_id_end_offset; + for (size_t i = 0; i < arg_offset_list.size(); i++) + { + if((size_t)arg_offset_list[i].first < entry.first) local_index += arg_offset_list[i].second; + } + if(expr.substr(local_index, entry.second.size()) == entry.second) + { + keysToRemove.push_back(entry.first); + expr = expr.substr(0, local_index) + "\"" + entry.second + "\"" + expr.substr(local_index + entry.second.size()); + offset2 += 2; + local_id_end_offset += 2; + } } } for (const auto &key : keysToRemove) local_id_positions.erase(key); keysToRemove.clear(); - expr += ctx_str.substr(index); std::string rewritten_exp = expr.substr((int)method->start->getStartIndex() - ctx->start->getStartIndex() + offset1, method_len + offset2) + "," + expr.substr(0, func_call_len + offset1 + 1) + ")"; - if ((int)method->stop->getStopIndex() - ctx->start->getStartIndex() + 1 < ctx_str.size()) rewritten_exp += expr.substr(method->stop->getStopIndex() + offset1 - ctx->start->getStartIndex() + 1); rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(ctx_str.c_str(), rewritten_exp.c_str()))); } @@ -8181,6 +7921,8 @@ rewrite_function_call_geospatial_func_ref_args(T ctx) int index = 0; int offset1 = 0; int offset2 = 0; + std::vector> arg_offset_list; + int local_id_end_offset = 0; /* writting the previously rewritten Geospatial context */ for (auto &entry : rewritten_query_fragment) @@ -8191,22 +7933,36 @@ rewrite_function_call_geospatial_func_ref_args(T ctx) index = (int)entry.first - ctx->start->getStartIndex() + entry.second.first.size(); keysToRemove.push_back(entry.first); if(entry.first <= ctx->spatial_proc_name_server_database_schema()->column->stop->getStopIndex()) offset1 += (int)entry.second.second.size() - entry.second.first.size(); - else offset2 += (int)entry.second.second.size() - entry.second.first.size(); + else + { + offset2 += (int)entry.second.second.size() - entry.second.first.size(); + /* storing these values in a list so that we could correctly calculate the offset for local_id argument rewrites */ + arg_offset_list.push_back(std::make_pair((int)entry.first, (int)entry.second.second.size() - entry.second.first.size())); + } } } for (const auto &key : keysToRemove) rewritten_query_fragment.erase(key); keysToRemove.clear(); + expr += func_ctx.substr(index); - /* Shifting the local id positions to new positions after rewriting the query since they will be quoted later */ + /* quoting local_id here so as to remove possibility of multiple rewrites in a single context */ for (auto &entry : local_id_positions) { if(entry.first >= ctx->function_arg_list()->start->getStartIndex() && entry.first <= ctx->function_arg_list()->stop->getStopIndex()) { - size_t pos = entry.first; - size_t offset = ctx->spatial_proc_name_server_database_schema()->geospatial_func_arg()->start->getStartIndex() - ctx->spatial_proc_name_server_database_schema()->start->getStartIndex(); - pos -= offset; - keysToRemove.push_back(entry.first); - local_id_positions.emplace(std::make_pair(pos, entry.second)); + /* Here we are quoting local_id which are within the argument list of the function */ + int local_index = (int)entry.first - ctx->start->getStartIndex() + offset1 + local_id_end_offset; + for (size_t i = 0; i < arg_offset_list.size(); i++) + { + if((size_t)arg_offset_list[i].first < entry.first) local_index += arg_offset_list[i].second; + } + if(expr.substr(local_index, entry.second.size()) == entry.second) + { + keysToRemove.push_back(entry.first); + expr = expr.substr(0, local_index) + "\"" + entry.second + "\"" + expr.substr(local_index + entry.second.size()); + offset2 += 2; + local_id_end_offset += 2; + } } } for (const auto &key : keysToRemove) local_id_positions.erase(key); @@ -8215,7 +7971,6 @@ rewrite_function_call_geospatial_func_ref_args(T ctx) /* * Rewriting the query as: table.col.STDistance(arg) -> STDistance(arg, table.col) */ - expr += func_ctx.substr(index); std::string rewritten_func = expr.substr((int)ctx->spatial_proc_name_server_database_schema()->geospatial_func_arg()->start->getStartIndex() - ctx->start->getStartIndex() + offset1, method_len + offset2) + "," + expr.substr(0, col_len + offset1 + 1) + ")"; rewritten_query_fragment.emplace(std::make_pair(ctx->spatial_proc_name_server_database_schema()->start->getStartIndex(), std::make_pair(::getFullText(ctx), rewritten_func.c_str()))); } @@ -8224,6 +7979,108 @@ rewrite_function_call_geospatial_func_ref_args(T ctx) // End of Spatial Query Helper for Function Calls //////////////////////////////////////////////////////////////////////////////// +template +static void +handleLocalIdQuotingFuncRefNoArg(T ctx, size_t geospatial_start_index, int &offset1, std::string &expr, std::vector keysToRemove) +{ + /* quoting local_id here so as to remove possibility of multiple rewrites in a single context */ + for (auto &entry : local_id_positions) + { + if(entry.first >= ctx->start->getStartIndex() && entry.first <= geospatial_start_index) + { + /* Here we are quoting local_id which are before the function name */ + int local_index = (int)entry.first - ctx->start->getStartIndex() + offset1; + if(expr.substr(local_index, entry.second.size()) == entry.second) + { + keysToRemove.push_back(entry.first); + expr = expr.substr(0, local_index) + "\"" + entry.second + "\"" + expr.substr(local_index + entry.second.size()); + offset1 += 2; + } + } + } + for (const auto &key : keysToRemove) local_id_positions.erase(key); + keysToRemove.clear(); +} + +static void +handleGeospatialFunctionsInFunctionCall(TSqlParser::Function_callContext *ctx) +{ + /* Handles rewrite of geospatial function calls */ + if (ctx->spatial_proc_name_server_database_schema()) + { + if (ctx->spatial_proc_name_server_database_schema()->schema) throw PGErrorWrapperException(ERROR, ERRCODE_FEATURE_NOT_SUPPORTED, "Remote procedure/function reference with 4-part object name is not currently supported in Babelfish", getLineAndPos(ctx)); + + /* This if-elseIf clause rewrites the query in case of geospatial function calls */ + if (ctx->spatial_proc_name_server_database_schema()->geospatial_func_arg() && ctx->function_arg_list()) + rewrite_function_call_geospatial_func_ref_args(ctx); + else if (ctx->spatial_proc_name_server_database_schema()->geospatial_func_no_arg() && !ctx->function_arg_list()) + rewrite_function_call_geospatial_func_ref_no_arg(ctx); + } +} + +static void +handleClrUdtFuncCall(TSqlParser::Clr_udt_func_callContext *ctx) +{ + /* checking if CLR_UDT types */ + if(ctx != NULL && !ctx->DOT().empty()) + { + std::vector method_calls = ctx->method_call(); + for (size_t i = 0; i < method_calls.size(); ++i) + { + TSqlParser::Method_callContext *method = method_calls[i]; + /* rewriting the query in case of geospatial function calls */ + if(method->spatial_methods()) + { + size_t ind = -1; + if (i == 0) + { + if(ctx->local_id()) ind = ctx->local_id()->stop->getStopIndex(); + else if(ctx->subquery()) ind = ctx->subquery()->stop->getStopIndex(); + else if(ctx->function_call()) ind = ctx->function_call()->stop->getStopIndex(); + else if(ctx->RR_BRACKET()) ind = ctx->RR_BRACKET()->getSymbol()->getStartIndex(); + } + else ind = method_calls[i-1]->stop->getStopIndex(); + rewrite_geospatial_query_helper(ctx, method, ind); + } + } + } +} + +static void +handleFullColumnNameCtx(TSqlParser::Full_column_nameContext *ctx) +{ + GetCtxFunc getSchema = [](TSqlParser::Full_column_nameContext *o) { return o->schema; }; + GetCtxFunc getTablename = [](TSqlParser::Full_column_nameContext *o) { return o->tablename; }; + + std::string func_name; + /* Handles rewrite of geospatial query */ + if(ctx->column_name) func_name = stripQuoteFromId(ctx->column_name); + else if (ctx->geospatial_col()) + { + /* Throwing error similar to TSQL as we do not allow 4-Part name for geospatial function calls */ + if(ctx->schema) throw PGErrorWrapperException(ERROR, ERRCODE_SYNTAX_ERROR, format_errmsg("The multi-part identifier \"%s\" could not be bound.", ::getFullText(ctx).c_str()), getLineAndPos(ctx)); + + /* Rewriting the query as: table.col.STX -> (table.col).STX */ + std::string ctx_str = ::getFullText(ctx); + std::string rewritten_func_name = "(" + ctx_str.substr(0, ctx->column->stop->getStopIndex() - ctx->start->getStartIndex() + 1) + ")." + ctx_str.substr(ctx->geospatial_col()->start->getStartIndex() - ctx->start->getStartIndex()); + rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(ctx_str, rewritten_func_name.c_str()))); + } + + std::string rewritten_name = rewrite_column_name_with_omitted_schema_name(ctx, getSchema, getTablename); + std::string rewritten_schema_name = rewrite_information_schema_to_information_schema_tsql(ctx, getSchema); + if (!rewritten_name.empty()) + rewritten_query_fragment.emplace(std::make_pair(ctx->start->getStartIndex(), std::make_pair(::getFullText(ctx), rewritten_name))); + if (pltsql_enable_tsql_information_schema && !rewritten_schema_name.empty()) + rewritten_query_fragment.emplace(std::make_pair(ctx->schema->start->getStartIndex(), std::make_pair(::getFullText(ctx->schema), rewritten_schema_name))); + + if (does_object_name_need_delimiter(ctx->tablename)) + rewritten_query_fragment.emplace(std::make_pair(ctx->tablename->start->getStartIndex(), std::make_pair(::getFullText(ctx->tablename), delimit_identifier(ctx->tablename)))); + + // qualified identifier doesn't need delimiter + if (ctx->DOT().empty() && does_object_name_need_delimiter(ctx->column_name)) + rewritten_query_fragment.emplace(std::make_pair(ctx->column_name->start->getStartIndex(), std::make_pair(::getFullText(ctx->column_name), delimit_identifier(ctx->column_name)))); +} + static bool does_object_name_need_delimiter(TSqlParser::IdContext *id) {