Skip to content

Commit

Permalink
addressed comments
Browse files Browse the repository at this point in the history
Signed-off-by: yashneet vinayak <[email protected]>
  • Loading branch information
yashneet vinayak committed Dec 29, 2024
1 parent 30269cd commit 1bd0b82
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 57 deletions.
1 change: 0 additions & 1 deletion contrib/babelfishpg_tds/src/backend/tds/tds_srv.c
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ pe_tds_init(void)
pltsql_plugin_handler_ptr->get_reset_tds_connection_flag = &GetResetTDSConnectionFlag;
pltsql_plugin_handler_ptr->get_numeric_typmod_from_exp = &resolve_numeric_typmod_from_exp;


invalidate_stat_table_hook = invalidate_stat_table;
guc_newval_hook = TdsSetGucStatVariable;

Expand Down
81 changes: 32 additions & 49 deletions contrib/babelfishpg_tds/src/backend/tds/tdsresponse.c
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,6 @@
#define NUMERIC_UPLUS_OID 1915
#define NUMERIC_UMINUS_OID 1771

/* These OIDs for numeric type casts and conversions used in is_numeric_cast */
#define INT4_NUMERIC 1740
#define INT8_NUMERIC 1781
#define INT2_NUMERIC 1782
#define FLOAT4_NUMERIC 1742
#define FLOAT8_NUMERIC 1743
#define INT48 481
#define INT84 480
#define INT28 754
#define INT82 714
#define I2TOI4 313
#define I4TOI2 314

#define Max(x, y) ((x) > (y) ? (x) : (y))
#define Min(x, y) ((x) < (y) ? (x) : (y))
#define ROWVERSION_SIZE 8
Expand Down Expand Up @@ -151,11 +138,13 @@ static Oid tsql_bit_numeric_oid = InvalidOid;
static Oid tsql_int4_bit_oid = InvalidOid;
static Oid tsql_trunc_numeric_to_int2_oid = InvalidOid;
static Oid tsql_trunc_numeric_to_int8_oid = InvalidOid;
static Oid sys_nspoid = InvalidOid;

static void FillTabNameWithNumParts(StringInfo buf, uint8 numParts, TdsRelationMetaDataInfo relMetaDataInfo);
static void FillTabNameWithoutNumParts(StringInfo buf, uint8 numParts, TdsRelationMetaDataInfo relMetaDataInfo);
static void SetTdsEstateErrorData(void);
static void ResetTdsEstateErrorData(void);
static bool is_numeric_cast(Oid func_oid);
static void SetAttributesForColmetada(TdsColumnMetaData *col);
static int32 resolve_numeric_typmod_outer_var(Plan *plan, AttrNumber attno);
static bool is_this_a_vector_datatype(Oid oid);
Expand Down Expand Up @@ -545,18 +534,13 @@ static bool is_tsql_bit_numeric(Oid oid)
{
if (tsql_bit_numeric_oid == InvalidOid)
{
Oid nspoid;
Oid typoid;
Oid funcargtypes[1];

nspoid = get_namespace_oid("sys", true);
if (nspoid == InvalidOid)
return InvalidOid;

typoid = GetSysCacheOid2(TYPENAMENSP, Anum_pg_type_oid, CStringGetDatum("bit"), ObjectIdGetDatum(nspoid));
typoid = GetSysCacheOid2(TYPENAMENSP, Anum_pg_type_oid, CStringGetDatum("bit"), ObjectIdGetDatum(sys_nspoid));

funcargtypes[0] = typoid;
tsql_bit_numeric_oid = LookupFuncName(list_make2(makeString("sys"), makeString("bit2numeric")), 1, funcargtypes, true);
tsql_bit_numeric_oid = LookupFuncName(list_make2(makeString("sys"), makeString("bit2numeric")), 1, funcargtypes, false);
}
return tsql_bit_numeric_oid == oid;
}
Expand All @@ -565,18 +549,13 @@ static bool is_tsql_fixeddecimal_numeric(Oid oid)
{
if (tsql_fixeddecimal_numeric_oid == InvalidOid)
{
Oid nspoid;
Oid typoid;
Oid funcargtypes[1];

nspoid = get_namespace_oid("sys", true);
if (nspoid == InvalidOid)
return InvalidOid;

typoid = GetSysCacheOid2(TYPENAMENSP, Anum_pg_type_oid, CStringGetDatum("fixeddecimal"), ObjectIdGetDatum(nspoid));
typoid = GetSysCacheOid2(TYPENAMENSP, Anum_pg_type_oid, CStringGetDatum("fixeddecimal"), ObjectIdGetDatum(sys_nspoid));

funcargtypes[0] = typoid;
tsql_fixeddecimal_numeric_oid = LookupFuncName(list_make2(makeString("sys"), makeString("fixeddecimal_numeric")), 1, funcargtypes, true);
tsql_fixeddecimal_numeric_oid = LookupFuncName(list_make2(makeString("sys"), makeString("fixeddecimal_numeric")), 1, funcargtypes, false);
}
return tsql_fixeddecimal_numeric_oid == oid;
}
Expand All @@ -585,31 +564,31 @@ static bool is_tsql_numeric_fixeddecimal(Oid oid)
{
Oid funcargtypes[1] = {NUMERICOID};
if (tsql_numeric_fixeddecimal_oid == InvalidOid)
tsql_numeric_fixeddecimal_oid = LookupFuncName(list_make2(makeString("sys"), makeString("numeric_fixeddecimal")), -1, funcargtypes, true);
tsql_numeric_fixeddecimal_oid = LookupFuncName(list_make2(makeString("sys"), makeString("numeric_fixeddecimal")), -1, funcargtypes, false);
return tsql_numeric_fixeddecimal_oid == oid;
}

static bool is_tsql_int4_bit(Oid oid)
{
Oid funcargtypes[1] = {INT4OID};
if (tsql_int4_bit_oid == InvalidOid)
tsql_int4_bit_oid = LookupFuncName(list_make2(makeString("sys"), makeString("int4bit")), -1, funcargtypes, true);
tsql_int4_bit_oid = LookupFuncName(list_make2(makeString("sys"), makeString("int4bit")), -1, funcargtypes, false);
return tsql_int4_bit_oid == oid;
}

static bool is_tsql_trunc_numeric_to_int2(Oid oid)
{
Oid funcargtypes[1] = {NUMERICOID};
if (tsql_trunc_numeric_to_int2_oid == InvalidOid)
tsql_trunc_numeric_to_int2_oid = LookupFuncName(list_make2(makeString("sys"), makeString("_trunc_numeric_to_int2")), -1, funcargtypes, true);
tsql_trunc_numeric_to_int2_oid = LookupFuncName(list_make2(makeString("sys"), makeString("_trunc_numeric_to_int2")), -1, funcargtypes, false);
return tsql_trunc_numeric_to_int2_oid == oid;
}

static bool is_tsql_trunc_numeric_to_int8(Oid oid)
{
Oid funcargtypes[1] = {NUMERICOID};
if (tsql_trunc_numeric_to_int8_oid == InvalidOid)
tsql_trunc_numeric_to_int8_oid = LookupFuncName(list_make2(makeString("sys"), makeString("_trunc_numeric_to_int8")), -1, funcargtypes, true);
tsql_trunc_numeric_to_int8_oid = LookupFuncName(list_make2(makeString("sys"), makeString("_trunc_numeric_to_int8")), -1, funcargtypes, false);
return tsql_trunc_numeric_to_int8_oid == oid;
}

Expand All @@ -619,20 +598,22 @@ static bool is_tsql_trunc_numeric_to_int8(Oid oid)
* if resolve_numeric_typmod_from_exp should be called recursively.
* This ensures proper typmod resolution for nested numeric conversions.
*/
bool
static bool
is_numeric_cast(Oid func_oid)
{
if (func_oid == INT4_NUMERIC ||
func_oid == INT8_NUMERIC ||
func_oid == INT2_NUMERIC ||
func_oid == FLOAT4_NUMERIC ||
func_oid == FLOAT8_NUMERIC ||
func_oid == INT48 ||
func_oid == INT84 ||
func_oid == INT28 ||
func_oid == INT82 ||
func_oid == I2TOI4 ||
func_oid == I4TOI2 ||
sys_nspoid = get_namespace_oid("sys", false);

if (func_oid == F_NUMERIC_INT4 ||
func_oid == F_NUMERIC_INT8 ||
func_oid == F_NUMERIC_INT2 ||
func_oid == F_NUMERIC_FLOAT4 ||
func_oid == F_NUMERIC_FLOAT8 ||
func_oid == F_INT8_INT4 ||
func_oid == F_INT4_INT8 ||
func_oid == F_INT8_INT2 ||
func_oid == F_INT2_INT8 ||
func_oid == F_INT4_INT2 ||
func_oid == F_INT2_INT4 ||
is_tsql_bit_numeric(func_oid) ||
is_tsql_int4_bit(func_oid) ||
is_tsql_fixeddecimal_numeric(func_oid) ||
Expand Down Expand Up @@ -704,12 +685,14 @@ resolve_numeric_typmod_from_exp(Plan *plan, Node *expr)
* the appropriate typmod. This process ensures correct
* numeric precision handling in Babelfish TSQL operations.
*/
if (!plan && con->consttype == INT4OID)
if (plan == NULL && con->consttype == INT4OID)
{
val = con->constvalue;
num = int64_to_numeric(val);
return numeric_get_typmod(num);
}
else if (plan != NULL && con->consttype == INT4OID)
return -1;
num = (Numeric) con->constvalue;
return numeric_get_typmod(num);
}
Expand All @@ -719,7 +702,7 @@ resolve_numeric_typmod_from_exp(Plan *plan, Node *expr)
Var *var = (Var *) expr;

/* If this var referes to tuple returned by its outer plan then find the original tle from it */
if (plan && var->varno == OUTER_VAR)
if (plan != NULL && var->varno == OUTER_VAR)
{
Assert(plan);
return (resolve_numeric_typmod_outer_var(plan, var->varattno));
Expand Down Expand Up @@ -929,11 +912,11 @@ resolve_numeric_typmod_from_exp(Plan *plan, Node *expr)
func->funcresulttype);

/*
* !plan means we are not invoking resolve_numeric_typmod_from_exp
* from tds side. Here we make resursive call so as to calculate
* typmod from other base nodes in parse tree.
* 1) plan == NULL means we are invoking this function from babelfishtsql_extension.
* 2) rettypmod == -1 means unable to find typmod till now.
* 3) check if only one args and then is that castable to numeric.
*/
if (!plan &&
if (plan == NULL &&
rettypmod == -1 &&
list_length(func->args) == 1 &&
is_numeric_cast(func_oid))
Expand Down
1 change: 0 additions & 1 deletion contrib/babelfishpg_tds/src/include/tds_response.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,5 @@ extern void TDSStatementExceptionCallback(PLtsql_execstate *estate, PLtsql_stmt
extern void SendColumnMetadata(TupleDesc typeinfo, List *targetlist, int16 *formats);
extern bool GetTdsEstateErrorData(int *number, int *severity, int *state);
extern int32 resolve_numeric_typmod_from_exp(Plan *plan, Node *expr);
extern bool is_numeric_cast(Oid func_oid);

#endif /* TDS_H */
1 change: 0 additions & 1 deletion contrib/babelfishpg_tsql/src/pltsql.h
Original file line number Diff line number Diff line change
Expand Up @@ -1804,7 +1804,6 @@ typedef struct PLtsql_protocol_plugin
void (*get_tvp_typename_typeschemaname) (char *proc_name, char *target_arg_name,
char **tvp_type_name, char **tvp_type_schema_name);
int32 (*get_numeric_typmod_from_exp) (Plan *plan, Node *expr);

/* Session level GUCs */
bool quoted_identifier;
bool arithabort;
Expand Down
5 changes: 0 additions & 5 deletions contrib/babelfishpg_tsql/src/pltsql_coerce.c
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
#define BPCHAR_MAX_TYPMOD 8000

#define TDS_MAX_NUM_PRECISION 38

/* Hooks for engine*/
extern find_coercion_pathway_hook_type find_coercion_pathway_hook;
extern determine_datatype_precedence_hook_type determine_datatype_precedence_hook;
Expand Down Expand Up @@ -2043,7 +2042,6 @@ tsql_select_common_typmod_hook(ParseState *pstate, List *exprs, Oid common_type)
int32 typmod = exprTypmod(expr);
Oid type = exprType(expr);
Oid immediate_base_type = get_immediate_base_type_of_UDT_internal(type);


if (common_type == NUMERICOID ||
getBaseType(common_type) == NUMERICOID)
Expand Down Expand Up @@ -2138,15 +2136,12 @@ tsql_select_common_typmod_hook(ParseState *pstate, List *exprs, Oid common_type)
else
max_typmods = Max(max_typmods, typmod);
}

}

if (common_type == NUMERICOID || getBaseType(common_type) == NUMERICOID)
return numeric_result_typmod;

return max_typmods;


}

/*
Expand Down

0 comments on commit 1bd0b82

Please sign in to comment.