Skip to content

Commit

Permalink
TVP as an argument of a procedure does not require TypeName to be spe…
Browse files Browse the repository at this point in the history
…cified (#3288)

The type name of the specified table-valued parameter as an argument for a procedure should be optional. Currently, Babelfish does not support this since we need explicit type to create the temp table.

We are fixing it by doing a look up of the typename of the TVP argument based on procedure name, namespace, db and role and setting the same.

Issues Resolved: BABEL-5311

Signed-off-by: Tanya Gupta <[email protected]>
  • Loading branch information
tanyagupta17 authored Dec 20, 2024
1 parent f0b6706 commit 0eac028
Show file tree
Hide file tree
Showing 14 changed files with 504 additions and 103 deletions.
2 changes: 1 addition & 1 deletion contrib/babelfishpg_tds/src/backend/tds/tdsrpc.c
Original file line number Diff line number Diff line change
Expand Up @@ -1587,7 +1587,7 @@ ReadParameters(TDSRequestSP request, uint64_t offset, StringInfo message, int *p
* Sets the col metadata and also the corresponding row
* data.
*/
SetColMetadataForTvp(temp, message, &offset);
SetColMetadataForTvp(temp, message, &offset, request->name.data);
}
break;
case TDS_TYPE_BINARY:
Expand Down
48 changes: 31 additions & 17 deletions contrib/babelfishpg_tds/src/include/tds_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "src/include/tds_typeio.h"
#include "src/collation.h"


/* Different TDS request types returned by GetTDSRequest() */
typedef enum TDSRequestType
{
Expand Down Expand Up @@ -498,16 +499,16 @@ SetTvpRowData(ParameterToken temp, const StringInfo message, uint64_t *offset)
}

static inline void
SetColMetadataForTvp(ParameterToken temp, const StringInfo message, uint64_t *offset)
SetColMetadataForTvp(ParameterToken temp, const StringInfo message, uint64_t *offset, char *proc_name)
{
uint8_t len;
uint16 colCount;
uint16 isTvpNull;
char *tempString;
int i = 0;
char *messageData = message->data;
StringInfo tempStringInfo = palloc(sizeof(StringInfoData));
uint32_t collation;
uint8_t len;
uint16 colCount;
uint16 isTvpNull;
char *tempString;
int i = 0;
char *messageData = message->data;
StringInfo tempStringInfo = palloc(sizeof(StringInfoData));
uint32_t collation;

/* Database-Name.Schema-Name.TableType-Name */
for (; i < 3; i++)
Expand Down Expand Up @@ -536,21 +537,34 @@ SetColMetadataForTvp(ParameterToken temp, const StringInfo message, uint64_t *of
if (i == 1)
temp->tvpInfo->tvpTypeSchemaName = tempStringInfo->data;
else
{
temp->tvpInfo->tvpTypeName = tempStringInfo->data;

temp->tvpInfo->tableName = tempStringInfo->data;
}
}
else if (i == 2)
{
/* Throw error if TabelType-Name is not provided */
ereport(ERROR,
(errcode(ERRCODE_PROTOCOL_VIOLATION),
errmsg("The incoming tabular data stream (TDS) remote procedure call (RPC) protocol stream is incorrect. "
"Table-valued parameter %d, to a parameterized string has no table type defined.",
temp->paramOrdinal + 1)));
char *tvp_type_name;
char *tvp_type_schema_name;
/*
* Fetch the TVP typeName and schemaName from catalog search
* based on object name and argument name.
*/
pltsql_plugin_handler_ptr->get_tvp_typename_typeschemaname(proc_name,
temp->paramMeta.colName.data,
&tvp_type_name,
&tvp_type_schema_name);
temp->len += strlen(tvp_type_schema_name);
temp->tvpInfo->tvpTypeSchemaName = pstrdup(tvp_type_schema_name);

pfree(tvp_type_schema_name);

temp->len += strlen(tvp_type_name);
temp->tvpInfo->tvpTypeName = tvp_type_name;
temp->tvpInfo->tableName = tvp_type_name;
}
}

temp->tvpInfo->tableName = tempStringInfo->data;
i = 0;

memcpy(&isTvpNull, &messageData[*offset], sizeof(uint16));
Expand Down
68 changes: 14 additions & 54 deletions contrib/babelfishpg_tsql/runtime/functions.c
Original file line number Diff line number Diff line change
Expand Up @@ -1935,7 +1935,6 @@ object_id(PG_FUNCTION_ARGS)
char *physical_schema_name;
char *input;
char *object_type = NULL;
char **splited_object_name;
Oid schema_oid;
Oid user_id = GetUserId();
Oid result = InvalidOid;
Expand Down Expand Up @@ -1975,31 +1974,13 @@ object_id(PG_FUNCTION_ARGS)
(errcode(ERRCODE_STRING_DATA_LENGTH_MISMATCH),
errmsg("input value is too long for object name")));

/* resolve the three part name */
splited_object_name = split_object_name(input);
db_name = splited_object_name[1];
schema_name = splited_object_name[2];
object_name = splited_object_name[3];

/* downcase identifier if needed */
if (pltsql_case_insensitive_identifiers)
{
db_name = downcase_identifier(db_name, strlen(db_name), false, false);
schema_name = downcase_identifier(schema_name, strlen(schema_name), false, false);
object_name = downcase_identifier(object_name, strlen(object_name), false, false);
for (int i = 0; i < 4; i++)
pfree(splited_object_name[i]);
}
else
pfree(splited_object_name[0]);
/*
* Split the input string, downcase and truncate if needed
* and return the db_name, schema_name and object_name.
*/
downcase_truncate_split_object_name(input, NULL, &db_name, &schema_name, &object_name);

pfree(input);
pfree(splited_object_name);

/* truncate identifiers if needed */
truncate_tsql_identifier(db_name);
truncate_tsql_identifier(schema_name);
truncate_tsql_identifier(object_name);

if (!strcmp(db_name, ""))
db_name = get_cur_db_name();
Expand Down Expand Up @@ -2417,7 +2398,6 @@ type_id(PG_FUNCTION_ARGS)
*object_name;
char *physical_schema_name;
char *input;
char **splitted_object_name;
Oid schema_oid = InvalidOid;
Oid user_id = GetUserId();
Oid result = InvalidOid;
Expand All @@ -2442,40 +2422,20 @@ type_id(PG_FUNCTION_ARGS)
(errcode(ERRCODE_STRING_DATA_LENGTH_MISMATCH),
errmsg("input value is too long for object name")));

/* resolve the two part name */
splitted_object_name = split_object_name(input);
/*
* Split the input string, downcase and truncate if needed
* and return the db_name, schema_name and object_name.
*/
downcase_truncate_split_object_name(input, NULL, &db_name, &schema_name, &object_name);

pfree(input);

/* If three part name(db_name also included in input) then return null */
if(pg_mbstrlen(splitted_object_name[1]) != 0)
if(pg_mbstrlen(db_name) != 0)
{
pfree(input);
for (int i = 0; i < 4; i++)
pfree(splitted_object_name[i]);
pfree(splitted_object_name);
PG_RETURN_NULL();
}
db_name = get_cur_db_name();
schema_name = splitted_object_name[2];
object_name = splitted_object_name[3];

/* downcase identifier if needed */
if (pltsql_case_insensitive_identifiers)
{
db_name = downcase_identifier(db_name, strlen(db_name), false, false);
schema_name = downcase_identifier(schema_name, strlen(schema_name), false, false);
object_name = downcase_identifier(object_name, strlen(object_name), false, false);
for (int i = 0; i < 4; i++)
pfree(splitted_object_name[i]);
}
else
pfree(splitted_object_name[0]);

pfree(input);
pfree(splitted_object_name);

/* truncate identifiers if needed */
truncate_tsql_identifier(db_name);
truncate_tsql_identifier(schema_name);
truncate_tsql_identifier(object_name);

if (!strcmp(schema_name, ""))
{
Expand Down
153 changes: 153 additions & 0 deletions contrib/babelfishpg_tsql/src/catalog.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
#include "catalog/pg_proc.h"
#include "catalog/pg_foreign_server.h"
#include "catalog/namespace.h"
#include "catalog/pg_type.h"
#include "commands/extension.h"
#include "parser/parse_relation.h"
#include "parser/scansup.h"
#include "tcop/utility.h"
#include "utils/acl.h"
#include "utils/builtins.h"
#include "utils/catcache.h"
#include "utils/fmgroids.h"
#include "utils/formatting.h"
#include "utils/lsyscache.h"
Expand Down Expand Up @@ -4108,3 +4111,153 @@ user_exists_for_db(const char *db_name, const char *user_name)
return user_exists;
}

/*
* get_proc_namespace_oid:
* Find namespace oid of a procedure based on proc name.
*/
static Oid
get_proc_namespace_oid(char **proc_name, char *curr_db)
{
char *physical_sch_name;
char *db_name;
char *schema_name;
char *object_name;
Oid obj_schema_oid = InvalidOid;

if (*proc_name == NULL)
ereport(ERROR,
(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
errmsg("procedure name cannot be NULL")));

/*
* Split the proc name, downcase and truncate if needed
* and return the db_name, schema_name and object_name.
*/
downcase_truncate_split_object_name(*proc_name, NULL, &db_name, &schema_name, &object_name);
*proc_name = object_name;

if (!strcmp(db_name, ""))
db_name = curr_db;

if (!strcmp(schema_name, ""))
{
/* Find the default schema for current user. */
char *user = get_user_for_database(db_name);
schema_name = get_authid_user_ext_schema_name((const char *) db_name, (const char *) user);
}

/* Get physical schema name from logical schema name. */
physical_sch_name = get_physical_schema_name(db_name, schema_name);
/* Get namespace oid from physical schema name. */
obj_schema_oid = get_namespace_oid(physical_sch_name, false);

pfree(db_name);
pfree(schema_name);
pfree(physical_sch_name);

return obj_schema_oid;
}

/*
* get_proargtypes_oid:
* Given a procedure name, namespace, user ID, and target argument name
* return the OID of the argument's type in the procedure.
*
* Returns InvalidOid if no matching procedure argument is found.
*/
static Oid
get_proargtypes_oid(char *proname, Oid pronamespace, Oid user_id, char *targeted_arg_name)
{
HeapTuple tuple;
CatCList *catlist;
Oid matched_type = InvalidOid;

/* Downcase and truncate identifier if needed. */
targeted_arg_name = downcase_truncate_identifier(targeted_arg_name, strlen(targeted_arg_name), true);

/* First search in pg_proc by name. */
catlist = SearchSysCacheList1(PROCNAMEARGSNSP, CStringGetDatum(proname));

for (int i = 0; i < catlist->n_members; i++)
{
Form_pg_proc procform;

tuple = &catlist->members[i]->tuple;
procform = (Form_pg_proc) GETSTRUCT(tuple);

/* Then consider only procs in specified namespace. */
if (procform->pronamespace == pronamespace &&
pg_proc_aclcheck(procform->oid, user_id, ACL_EXECUTE) == ACLCHECK_OK)
{
/* Get the list of proargames and corresponding proargtypes oids. */
char **proargnames = fetch_func_input_arg_names(tuple);
Oid *proargtypes = procform->proargtypes.values;

/* Find the typeoid corresponding to target TVP argument. */
for (int j = 0; j < procform->pronargs; j++)
{
if (strcmp(proargnames[j], targeted_arg_name) == 0)
{
matched_type = proargtypes[j];
break;
}
}
}
}
ReleaseSysCacheList(catlist);
if (matched_type == InvalidOid)
{
ereport(ERROR,
(errcode(ERRCODE_UNDEFINED_FUNCTION),
errmsg("No procedure found with name \"%s\" that has an argument named \"%s\"",
proname, targeted_arg_name)));
}
return matched_type;
}

/*
* get_tvp_typename_typeschemaname:
* Retrieves the type name and schema name of a Table-Valued Parameter (TVP)
* for a given stored procedure and argument name.
*/
void
get_tvp_typename_typeschemaname(char *proc_name, char *target_arg_name, char **tvp_type_name, char **tvp_type_schema_name)
{
bool xactStarted = IsTransactionOrTransactionBlock();
Oid tvp_proargtype = InvalidOid;
Oid user_id = InvalidOid;
Oid obj_schema_oid = InvalidOid;
HeapTuple tuple;
char *typnamespace;
char *curr_db;
MemoryContext oldContext;

if (!xactStarted)
StartTransactionCommand();
user_id = GetUserId();
curr_db = get_cur_db_name();

/* Get procedure namespaceid. */
obj_schema_oid = get_proc_namespace_oid(&proc_name, curr_db);

/* Fetch proargtype value of our targeted variable. */
tvp_proargtype = get_proargtypes_oid(proc_name, obj_schema_oid, user_id, target_arg_name);

/* Search in pg_type by object_id and fetch tvpTypeName and tvpTypeSchemaName. */
tuple = SearchSysCache1(TYPEOID, ObjectIdGetDatum(tvp_proargtype));
/* Check if user have right permission on object. */
if (HeapTupleIsValid(tuple) && pg_type_aclcheck(tvp_proargtype, user_id, ACL_USAGE) == ACLCHECK_OK)
{
Form_pg_type pg_type = (Form_pg_type) GETSTRUCT(tuple);
*tvp_type_name = NameStr(pg_type->typname);
typnamespace = get_namespace_name(pg_type->typnamespace);

oldContext = MemoryContextSwitchTo(TopMemoryContext);
*tvp_type_schema_name = pstrdup((char *) get_logical_schema_name(typnamespace, true));
MemoryContextSwitchTo(oldContext);
ReleaseSysCache(tuple);
}

if(!xactStarted)
CommitTransactionCommand();
}
2 changes: 2 additions & 0 deletions contrib/babelfishpg_tsql/src/catalog.h
Original file line number Diff line number Diff line change
Expand Up @@ -459,4 +459,6 @@ typedef struct Rule
RelData *tbldata; /* extra catalog info */
} Rule;

extern void get_tvp_typename_typeschemaname(char *proc_name, char *target_arg_name, char **tvp_type_name, char **tvp_type_schema_name);

#endif
4 changes: 2 additions & 2 deletions contrib/babelfishpg_tsql/src/hooks.c
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ static bool match_pltsql_func_call(HeapTuple proctup, int nargs, List *argnames,
static ObjectAddress get_trigger_object_address(List *object, Relation *relp, bool missing_ok, bool object_from_input);
Oid get_tsql_trigger_oid(List *object, const char *tsql_trigger_name, bool object_from_input);
static Node *transform_like_in_add_constraint(Node *node);
static char** fetch_func_input_arg_names(HeapTuple func_tuple);
char** fetch_func_input_arg_names(HeapTuple func_tuple);

/*****************************************
* Analyzer Hooks
Expand Down Expand Up @@ -3930,7 +3930,7 @@ static int getDefaultPosition(const List *default_positions, const ListCell *def
* @param func_tuple or proc_tuple
* @return char** list of input arg names
*/
static char** fetch_func_input_arg_names(HeapTuple func_tuple)
char** fetch_func_input_arg_names(HeapTuple func_tuple)
{
Datum proargnames;
Datum proargmodes;
Expand Down
1 change: 1 addition & 0 deletions contrib/babelfishpg_tsql/src/hooks.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ extern Oid get_tsql_trigger_oid(List *object,
bool object_from_input);
extern void pltsql_bbfSelectIntoUtility(ParseState *pstate, PlannedStmt *pstmt, const char *queryString,
QueryEnvironment *queryEnv, ParamListInfo params, QueryCompletion *qc, ObjectAddress *address);
extern char** fetch_func_input_arg_names(HeapTuple func_tuple);

extern char *update_delete_target_alias;
extern bool sp_describe_first_result_set_inprogress;
Expand Down
1 change: 1 addition & 0 deletions contrib/babelfishpg_tsql/src/pl_handler.c
Original file line number Diff line number Diff line change
Expand Up @@ -4738,6 +4738,7 @@ _PG_init(void)
(*pltsql_protocol_plugin_ptr)->tsql_char_input = common_utility_plugin_ptr->tsql_bpchar_input;
(*pltsql_protocol_plugin_ptr)->get_cur_db_name = &get_cur_db_name;
(*pltsql_protocol_plugin_ptr)->get_physical_schema_name = &get_physical_schema_name;
(*pltsql_protocol_plugin_ptr)->get_tvp_typename_typeschemaname = &get_tvp_typename_typeschemaname;

(*pltsql_protocol_plugin_ptr)->quoted_identifier = pltsql_quoted_identifier;
(*pltsql_protocol_plugin_ptr)->arithabort = pltsql_arithabort;
Expand Down
Loading

0 comments on commit 0eac028

Please sign in to comment.