Skip to content

Commit

Permalink
Fix hashbytes behaviour for nvarchar input (#3259)
Browse files Browse the repository at this point in the history
PROBLEM: When hashbytes received nvarchar input, Babelfish hashed the data in the input encoding, whereas T-SQL always uses UTF-16 encoding for nvarchar, irrespective of the input encoding.

RCA: We were treating varchar and nvarchar as the same type, whereas we should use the input encoding for varchar and UTF-16 encoding for nvarchar.

FIX: We need to identify whether the input is nvarchar and, if it is, apply UTF-16 encoding.
In the hashbytes function we were converting varchar/nvarchar to varbinary, so when the input is nvarchar we now first convert the input string to UTF-16 encoding and then convert it to varbinary.
To do this we created overloaded functions and select the best candidate among them depending on the type of the second argument (i.e. varchar/nvarchar/varbinary).

Task: BABEL-4891
Signed-off-by: pranav jain <[email protected]>
  • Loading branch information
pranavJ23 authored Jan 7, 2025
1 parent cbe8fc1 commit be9f87c
Show file tree
Hide file tree
Showing 17 changed files with 876 additions and 346 deletions.
1 change: 1 addition & 0 deletions contrib/babelfishpg_common/src/babelfishpg_common.c
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ get_common_utility_plugin(void)
common_utility_plugin_var.resolve_pg_type_to_tsql = &resolve_pg_type_to_tsql;
common_utility_plugin_var.GetUTF8CodePoint = &GetUTF8CodePoint;
common_utility_plugin_var.TsqlUTF8LengthInUTF16 = &TsqlUTF8LengthInUTF16;
common_utility_plugin_var.TsqlUTF8toUTF16StringInfo = &TsqlUTF8toUTF16StringInfo;
}
return &common_utility_plugin_var;
}
1 change: 1 addition & 0 deletions contrib/babelfishpg_common/src/babelfishpg_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,5 @@ typedef struct common_utility_plugin
const char *(*resolve_pg_type_to_tsql) (Oid oid);
int32_t (*GetUTF8CodePoint) (const unsigned char *in, int len, int *consumed_p);
int (*TsqlUTF8LengthInUTF16) (const void *vin, int len);
void (*TsqlUTF8toUTF16StringInfo) (StringInfo utf16_data, const void *data, size_t len);
} common_utility_plugin;
12 changes: 8 additions & 4 deletions contrib/babelfishpg_tsql/sql/datatype_string_operators.sql
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
CREATE OR REPLACE FUNCTION sys.hashbytes(IN alg VARCHAR, IN data VARCHAR) RETURNS sys.bbf_varbinary
CREATE OR REPLACE FUNCTION sys.hashbytes(IN alg sys.VARCHAR, IN data sys.VARCHAR) RETURNS sys.bbf_varbinary
AS 'babelfishpg_tsql', 'hashbytes' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
GRANT EXECUTE ON FUNCTION sys.hashbytes(IN VARCHAR, IN VARCHAR) TO PUBLIC;
GRANT EXECUTE ON FUNCTION sys.hashbytes(IN alg sys.VARCHAR, IN sys.VARCHAR) TO PUBLIC;

CREATE OR REPLACE FUNCTION sys.hashbytes(IN alg VARCHAR, IN data sys.bbf_varbinary) RETURNS sys.bbf_varbinary
CREATE OR REPLACE FUNCTION sys.hashbytes(IN alg sys.VARCHAR, IN data sys.NVARCHAR) RETURNS sys.bbf_varbinary
AS 'babelfishpg_tsql', 'hashbytes' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
GRANT EXECUTE ON FUNCTION sys.hashbytes(IN VARCHAR, IN sys.bbf_varbinary) TO PUBLIC;
GRANT EXECUTE ON FUNCTION sys.hashbytes(IN alg sys.VARCHAR, IN sys.NVARCHAR) TO PUBLIC;

CREATE OR REPLACE FUNCTION sys.hashbytes(IN alg sys.VARCHAR, IN data sys.bbf_varbinary) RETURNS sys.bbf_varbinary
AS 'babelfishpg_tsql', 'hashbytes' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
GRANT EXECUTE ON FUNCTION sys.hashbytes(IN alg sys.VARCHAR, IN sys.bbf_varbinary) TO PUBLIC;

CREATE OR REPLACE FUNCTION sys.quotename(IN input_string VARCHAR, IN delimiter char default '[') RETURNS
sys.nvarchar AS 'babelfishpg_tsql', 'quotename' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,49 @@ LANGUAGE plpgsql;
* So make sure that any SQL statement (DDL/DML) being added here can be executed multiple times without affecting
* final behaviour.
*/

-- Retire the old sys.hashbytes(pg_catalog.VARCHAR, pg_catalog.VARCHAR) overload.
-- Step 1: rename it to hashbytes_varchar_deprecated_4_5_0. Any error raised by the
-- rename is caught and reported as a WARNING, so the script continues past it.
DO $$
DECLARE
exception_message text;
BEGIN
ALTER FUNCTION sys.hashbytes(IN alg pg_catalog.VARCHAR, IN data pg_catalog.VARCHAR) RENAME TO hashbytes_varchar_deprecated_4_5_0;

EXCEPTION WHEN OTHERS THEN
GET STACKED DIAGNOSTICS
exception_message = MESSAGE_TEXT;
RAISE WARNING '%', exception_message;
END;
$$;

-- Step 2: hand the renamed function to the deprecated-object helper.
-- NOTE(review): presumably this drops the function when nothing depends on it; confirm against the helper's definition.
CALL sys.babelfish_drop_deprecated_object('function', 'sys', 'hashbytes_varchar_deprecated_4_5_0');

-- Retire the old sys.hashbytes(pg_catalog.VARCHAR, sys.bbf_varbinary) overload.
-- Step 1: rename it to hashbytes_varbinary_deprecated_4_5_0. Any error raised by the
-- rename is caught and reported as a WARNING, so the script continues past it.
DO $$
DECLARE
exception_message text;
BEGIN
ALTER FUNCTION sys.hashbytes(IN alg pg_catalog.VARCHAR, IN data sys.bbf_varbinary) RENAME TO hashbytes_varbinary_deprecated_4_5_0;

EXCEPTION WHEN OTHERS THEN
GET STACKED DIAGNOSTICS
exception_message = MESSAGE_TEXT;
RAISE WARNING '%', exception_message;
END;
$$;

-- Step 2: hand the renamed function to the deprecated-object helper.
-- NOTE(review): presumably this drops the function when nothing depends on it; confirm against the helper's definition.
CALL sys.babelfish_drop_deprecated_object('function', 'sys', 'hashbytes_varbinary_deprecated_4_5_0');

-- hashbytes overload taking sys.VARCHAR data.
-- Bound to the C symbol 'hashbytes' in the babelfishpg_tsql library; returns sys.bbf_varbinary.
-- Declared STRICT, so a NULL argument yields NULL without calling the C function.
CREATE OR REPLACE FUNCTION sys.hashbytes(IN alg sys.VARCHAR, IN data sys.VARCHAR) RETURNS sys.bbf_varbinary
AS 'babelfishpg_tsql', 'hashbytes' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
GRANT EXECUTE ON FUNCTION sys.hashbytes(IN alg sys.VARCHAR, IN sys.VARCHAR) TO PUBLIC;

-- hashbytes overload taking sys.NVARCHAR data.
-- Bound to the same C symbol 'hashbytes' in babelfishpg_tsql as the other overloads; returns sys.bbf_varbinary.
-- Declared STRICT, so a NULL argument yields NULL without calling the C function.
CREATE OR REPLACE FUNCTION sys.hashbytes(IN alg sys.VARCHAR, IN data sys.NVARCHAR) RETURNS sys.bbf_varbinary
AS 'babelfishpg_tsql', 'hashbytes' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
GRANT EXECUTE ON FUNCTION sys.hashbytes(IN alg sys.VARCHAR, IN sys.NVARCHAR) TO PUBLIC;

-- hashbytes overload taking sys.bbf_varbinary data.
-- Bound to the same C symbol 'hashbytes' in babelfishpg_tsql as the other overloads; returns sys.bbf_varbinary.
-- Declared STRICT, so a NULL argument yields NULL without calling the C function.
CREATE OR REPLACE FUNCTION sys.hashbytes(IN alg sys.VARCHAR, IN data sys.bbf_varbinary) RETURNS sys.bbf_varbinary
AS 'babelfishpg_tsql', 'hashbytes' LANGUAGE C IMMUTABLE STRICT PARALLEL SAFE;
GRANT EXECUTE ON FUNCTION sys.hashbytes(IN alg sys.VARCHAR, IN sys.bbf_varbinary) TO PUBLIC;

CREATE OR REPLACE FUNCTION sys.babelfish_update_server_collation_name() RETURNS VOID
LANGUAGE C
AS 'babelfishpg_common', 'babelfish_update_server_collation_name';
Expand Down
66 changes: 53 additions & 13 deletions contrib/babelfishpg_tsql/src/pltsql_coerce.c
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ tsql_special_function_t tsql_special_function_list[] =
{"sys", "stuff", "stuff", false, 4},
{"sys", "translate", "translate", false, 3},
{"sys", "trim", "Trim", false, 1},
{"sys", "hashbytes", "hashbytes", false, 2},
{"sys", "trim", "Trim", false, 2},
{"sys", "ltrim", "ltrim", false, 1},
{"sys", "rtrim", "rtrim", false, 1},
Expand Down Expand Up @@ -1142,15 +1143,19 @@ validate_special_function(char *func_nsname, char *func_name, int nargs, bool nu
static FuncCandidateList
tsql_func_select_candidate_for_special_func(List *names, int nargs, Oid *input_typeids, FuncCandidateList candidates)
{
FuncCandidateList current_candidate, best_candidate;
Oid expr_result_type;
char *proc_nsname;
char *proc_name;
bool is_func_validated;
int ncandidates;
Oid rettype;
Oid sys_oid = get_namespace_oid("sys", false);
Oid *new_input_typeids;
FuncCandidateList current_candidate, best_candidate;
Oid expr_result_type;
char *proc_nsname;
char *proc_name;
bool is_func_validated;
int ncandidates;
Oid rettype;
Oid sys_oid = get_namespace_oid("sys", false);
Oid *new_input_typeids;
Oid *argtypes;
int nargs_func;
Oid second_arg_type = InvalidOid;
Oid expr_arg_type;

DeconstructQualifiedName(names, &proc_nsname, &proc_name);

Expand Down Expand Up @@ -1181,6 +1186,7 @@ tsql_func_select_candidate_for_special_func(List *names, int nargs, Oid *input_t

/* function based logic to decide return type */
expr_result_type = InvalidOid;
expr_arg_type = InvalidOid;
if (strlen(proc_name) == 4 && strncmp(proc_name,"trim", 4) == 0 && nargs == 2)
{
if ((*common_utility_plugin_ptr->is_tsql_nvarchar_datatype)(new_input_typeids[1])
Expand Down Expand Up @@ -1329,15 +1335,38 @@ tsql_func_select_candidate_for_special_func(List *names, int nargs, Oid *input_t
expr_result_type = get_sys_varcharoid();
}
}
else if (strlen(proc_name) == 9 && strncmp(proc_name,"hashbytes", 9) == 0 && nargs == 2)
{
if ((*common_utility_plugin_ptr->is_tsql_varchar_datatype) (new_input_typeids[1])
|| (*common_utility_plugin_ptr->is_tsql_bpchar_datatype) (new_input_typeids[1])
|| (*common_utility_plugin_ptr->is_tsql_text_datatype) (new_input_typeids[1])
|| new_input_typeids[1] == UNKNOWNOID)
{
expr_arg_type = get_sys_varcharoid();
}
else if((*common_utility_plugin_ptr->is_tsql_nvarchar_datatype) (new_input_typeids[1])
|| (*common_utility_plugin_ptr->is_tsql_nchar_datatype) (new_input_typeids[1])
|| (*common_utility_plugin_ptr->is_tsql_ntext_datatype) (new_input_typeids[1]))
{
expr_arg_type = (*common_utility_plugin_ptr->lookup_tsql_datatype_oid) ("nvarchar");
}
else if(is_tsql_binary_family_datatype(new_input_typeids[1]))
{
expr_arg_type = (*common_utility_plugin_ptr->lookup_tsql_datatype_oid) ("bbf_varbinary");
}
}

/* free new_input_typeids, as they are no longer needed */
if (new_input_typeids)
pfree(new_input_typeids);

if (!OidIsValid(expr_result_type))
if (!OidIsValid(expr_result_type) && !OidIsValid(expr_arg_type))
return NULL;

/* Get the candidate with matching return type */
/*
* Get the candidate with matching return type or
* second argument type(specifically for hashbytes function)
*/
ncandidates = 0;
best_candidate = NULL;
for (current_candidate = candidates;
Expand All @@ -1349,13 +1378,24 @@ tsql_func_select_candidate_for_special_func(List *names, int nargs, Oid *input_t
continue;

rettype = get_func_rettype(current_candidate->oid);
/* get the function second argument if we have hashbytes function */
if(strlen(proc_name) == 9 && strncmp(proc_name,"hashbytes", 9) == 0 && nargs == 2)
{
get_func_signature(current_candidate->oid, &argtypes, &nargs_func);
second_arg_type = argtypes[1];
}

/* Ignore following definitions as these are used when no other potential definition can be used. */
if ((current_candidate->args[0] == TEXTOID && rettype == get_sys_varcharoid())
|| (current_candidate->args[0] == BYTEAOID && rettype == BYTEAOID))
continue;

if (expr_result_type == rettype)
/*
* Find the best candidate. For the hashbytes function, match on the type of the
* second argument (second_arg_type is only valid for hashbytes). For other special
* functions, select the best candidate on the basis of the return type.
*/
if ((OidIsValid(expr_result_type) && expr_result_type == rettype)
|| (OidIsValid(expr_arg_type) && OidIsValid(second_arg_type) && expr_arg_type == second_arg_type))
{
best_candidate = current_candidate;
ncandidates++;
Expand Down
28 changes: 23 additions & 5 deletions contrib/babelfishpg_tsql/src/string.c
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,31 @@ static Datum return_varchar_pointer(char *buf, int size);
Datum
hashbytes(PG_FUNCTION_ARGS)
{
const char *algorithm = text_to_cstring(PG_GETARG_TEXT_P(0));
bytea *in = PG_GETARG_BYTEA_PP(1);
size_t len = VARSIZE_ANY_EXHDR(in);
const uint8 *data = (unsigned char *) VARDATA_ANY(in);
bytea *result;
Oid input_type = get_fn_expr_argtype(fcinfo->flinfo,1);
const char *algorithm = text_to_cstring(PG_GETARG_TEXT_P(0));
bytea *in = PG_GETARG_BYTEA_PP(1);
size_t len = VARSIZE_ANY_EXHDR(in);
const uint8 *data = (unsigned char *) VARDATA_ANY(in);
bytea *result;
StringInfoData utf16_data;

/* If the input_type is nvarchar then we convert it to UTF-16 encoding */
initStringInfo(&utf16_data);
if(((*common_utility_plugin_ptr->is_tsql_nvarchar_datatype)(input_type)))
{
(common_utility_plugin_ptr->TsqlUTF8toUTF16StringInfo)(&utf16_data, data, len);
len = utf16_data.len;
data = (const uint8 *)utf16_data.data;
}

if (strcasecmp(algorithm, "MD2") == 0)
{
pfree(utf16_data.data);
PG_RETURN_NULL();
}
else if (strcasecmp(algorithm, "MD4") == 0)
{
pfree(utf16_data.data);
PG_RETURN_NULL();
}
else if (strcasecmp(algorithm, "MD5") == 0)
Expand All @@ -87,6 +100,7 @@ hashbytes(PG_FUNCTION_ARGS)

SET_VARSIZE(result, sizeof(buf) + VARHDRSZ);
memcpy(VARDATA(result), buf, sizeof(buf));
pfree(utf16_data.data);

PG_RETURN_BYTEA_P(result);
}
Expand All @@ -101,6 +115,7 @@ hashbytes(PG_FUNCTION_ARGS)

SET_VARSIZE(result, sizeof(buf) + VARHDRSZ);
memcpy(VARDATA(result), buf, sizeof(buf));
pfree(utf16_data.data);

PG_RETURN_BYTEA_P(result);
}
Expand All @@ -121,6 +136,7 @@ hashbytes(PG_FUNCTION_ARGS)

SET_VARSIZE(result, sizeof(buf) + VARHDRSZ);
memcpy(VARDATA(result), buf, sizeof(buf));
pfree(utf16_data.data);

PG_RETURN_BYTEA_P(result);
}
Expand All @@ -134,11 +150,13 @@ hashbytes(PG_FUNCTION_ARGS)

SET_VARSIZE(result, sizeof(buf) + VARHDRSZ);
memcpy(VARDATA(result), buf, sizeof(buf));
pfree(utf16_data.data);

PG_RETURN_BYTEA_P(result);
}
else
{
pfree(utf16_data.data);
PG_RETURN_NULL();
}
}
Expand Down
18 changes: 15 additions & 3 deletions test/JDBC/expected/cast_nvarchar_test-vu-cleanup.out
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
-- to do in hashbytes PR
-- DROP TABLE IF EXISTS TestHash;
-- GO
DROP TABLE IF EXISTS TestHash;
GO
DROP type user_defined_varbinary;
GO
DROP type user_defined_nvarchar;
Expand All @@ -17,4 +16,17 @@ DROP FUNCTION IF EXISTS dbo.CastbinaryToNVarchar;
GO
DROP TABLE IF EXISTS casttable;
GO
DROP FUNCTION IF EXISTS dbo.HashMultipleTypes;
GO
DROP PROCEDURE IF EXISTS dbo.PrintHashResults;
GO
DROP VIEW IF EXISTS dbo.HashDemoView;
GO
drop view if exists hasheddataview;
GO
drop view if exists hasheddataview1;
GO
drop table IF EXISTS hashbytes_table;
GO


44 changes: 34 additions & 10 deletions test/JDBC/expected/cast_nvarchar_test-vu-prepare.out
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
CREATE TABLE TestHash(
nvarchar_data nvarchar(32) NOT NULL,
varchar_data varchar(32) NOT NULL,
cast_hashbytes_nvarchar_data AS (cast ( hashbytes('sha1', nvarchar_data) AS varbinary(20) )) PERSISTED NOT NULL,
convert_hashbytes_nvarchar_data AS (convert( varbinary(20), hashbytes('sha1',nvarchar_data))) PERSISTED NOT NULL,
cast_hashbytes_varchar_data AS (cast ( hashbytes('sha1', varchar_data) AS varbinary(20) )) PERSISTED NOT NULL,
convert_hashbytes_varchar_data AS (convert( varbinary(20), hashbytes('sha1',varchar_data))) PERSISTED NOT NULL
);
GO

-- to do in hashbytes PR
-- CREATE TABLE TestHash(
-- nvarchar_data nvarchar(32) NOT NULL,
-- varchar_data varchar(32) NOT NULL,
-- cast_hashbytes_nvarchar_data AS (cast ( hashbytes('sha1', nvarchar_data) AS varbinary(20) )) PERSISTED NOT NULL,
-- convert_hashbytes_nvarchar_data AS (convert( varbinary(20), hashbytes('sha1',nvarchar_data))) PERSISTED NOT NULL,
-- cast_hashbytes_varchar_data AS (cast ( hashbytes('sha1', varchar_data) AS varbinary(20) )) PERSISTED NOT NULL,
-- convert_hashbytes_varchar_data AS (convert( varbinary(20), hashbytes('sha1',varchar_data))) PERSISTED NOT NULL
-- );
-- GO
-- Function to cast NVARCHAR to VARBINARY
CREATE FUNCTION dbo.CastNVarcharToVarbinary
(
Expand Down Expand Up @@ -71,6 +70,31 @@ SELECT
FROM casttable
GO

-- Function hashbytes different input types
CREATE FUNCTION dbo.HashMultipleTypes
(
@VarcharInput VARCHAR(MAX),
@NVarcharInput NVARCHAR(MAX),
@VarbinaryInput VARBINARY(MAX)
)
RETURNS TABLE
AS
RETURN
(
SELECT
HASHBYTES('SHA2_256', @VarcharInput) AS VarcharHash,
HASHBYTES('SHA2_256', @NVarcharInput) AS NVarcharHash,
HASHBYTES('SHA2_256', @VarbinaryInput) AS VarbinaryHash
)
GO


-- View to demonstrate hashing
CREATE VIEW dbo.HashDemoView
AS
SELECT
HASHBYTES('SHA2_256', CAST('Hello' AS VARCHAR(MAX))) AS VarcharHash,
HASHBYTES('SHA2_256', CAST(N'Hello' AS NVARCHAR(MAX))) AS NVarcharHash,
HASHBYTES('SHA2_256', CAST(0x48656C6C6F AS VARBINARY(MAX))) AS VarbinaryHash
GO

Loading

0 comments on commit be9f87c

Please sign in to comment.