Skip to content

Commit

Permalink
Handle deparse of INTERVAL correctly when used in SET statements (#184)
Browse files Browse the repository at this point in the history
Co-authored-by: Lukas Fittl <[email protected]>
  • Loading branch information
coderdan and lfittl authored Mar 25, 2023
1 parent ff32f92 commit 43b116b
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 72 deletions.
189 changes: 119 additions & 70 deletions src/pg_query_deparse.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ typedef enum DeparseNodeContext {
DEPARSE_NODE_CONTEXT_XMLNAMESPACES,
DEPARSE_NODE_CONTEXT_CREATE_TYPE,
DEPARSE_NODE_CONTEXT_ALTER_TYPE,
DEPARSE_NODE_CONTEXT_SET_STATEMENT,
// Identifier vs constant context
DEPARSE_NODE_CONTEXT_IDENTIFIER,
DEPARSE_NODE_CONTEXT_CONSTANT
Expand Down Expand Up @@ -159,8 +160,9 @@ static void deparseRangeSubselect(StringInfo str, RangeSubselect *range_subselec
static void deparseRangeFunction(StringInfo str, RangeFunction *range_func);
static void deparseAArrayExpr(StringInfo str, A_ArrayExpr * array_expr);
static void deparseRowExpr(StringInfo str, RowExpr *row_expr);
static void deparseTypeCast(StringInfo str, TypeCast *type_cast);
static void deparseTypeCast(StringInfo str, TypeCast *type_cast, DeparseNodeContext context);
static void deparseTypeName(StringInfo str, TypeName *type_name);
static void deparseIntervalTypmods(StringInfo str, TypeName *type_name);
static void deparseNullTest(StringInfo str, NullTest *null_test);
static void deparseCaseExpr(StringInfo str, CaseExpr *case_expr);
static void deparseCaseWhen(StringInfo str, CaseWhen *case_when);
Expand Down Expand Up @@ -259,7 +261,7 @@ static void deparseExpr(StringInfo str, Node *node)
deparseXmlExpr(str, castNode(XmlExpr, node));
break;
case T_TypeCast:
deparseTypeCast(str, castNode(TypeCast, node));
deparseTypeCast(str, castNode(TypeCast, node), DEPARSE_NODE_CONTEXT_NONE);
break;
case T_A_Const:
deparseAConst(str, castNode(A_Const, node));
Expand Down Expand Up @@ -340,7 +342,7 @@ static void deparseCExpr(StringInfo str, Node *node)
deparseAConst(str, castNode(A_Const, node));
break;
case T_TypeCast:
deparseTypeCast(str, castNode(TypeCast, node));
deparseTypeCast(str, castNode(TypeCast, node), DEPARSE_NODE_CONTEXT_NONE);
break;
case T_A_Expr:
appendStringInfoChar(str, '(');
Expand Down Expand Up @@ -1333,6 +1335,10 @@ static void deparseVarList(StringInfo str, List *l)
else
Assert(false);
}
else if (IsA(lfirst(lc), TypeCast))
{
deparseTypeCast(str, castNode(TypeCast, lfirst(lc)), DEPARSE_NODE_CONTEXT_SET_STATEMENT);
}
else
{
Assert(false);
Expand Down Expand Up @@ -1778,7 +1784,7 @@ static void deparseFuncExprWindowless(StringInfo str, Node* node)
deparseSQLValueFunction(str, castNode(SQLValueFunction, node));
break;
case T_TypeCast:
deparseTypeCast(str, castNode(TypeCast, node));
deparseTypeCast(str, castNode(TypeCast, node), DEPARSE_NODE_CONTEXT_NONE);
break;
case T_CoalesceExpr:
deparseCoalesceExpr(str, castNode(CoalesceExpr, node));
Expand Down Expand Up @@ -3533,7 +3539,7 @@ static void deparseRowExpr(StringInfo str, RowExpr *row_expr)
appendStringInfoChar(str, ')');
}

static void deparseTypeCast(StringInfo str, TypeCast *type_cast)
static void deparseTypeCast(StringInfo str, TypeCast *type_cast, DeparseNodeContext context)
{
bool need_parens = false;

Expand Down Expand Up @@ -3581,6 +3587,13 @@ static void deparseTypeCast(StringInfo str, TypeCast *type_cast)
return;
}
}
else if (strcmp(typename, "interval") == 0 && context == DEPARSE_NODE_CONTEXT_SET_STATEMENT && IsA(&a_const->val, String))
{
appendStringInfoString(str, "interval ");
deparseAConst(str, a_const);
deparseIntervalTypmods(str, type_cast->typeName);
return;
}
}

// Ensure negative values have wrapping parentheses
Expand Down Expand Up @@ -3705,70 +3718,9 @@ static void deparseTypeName(StringInfo str, TypeName *type_name)
}
else if (strcmp(name, "interval") == 0 && list_length(type_name->typmods) >= 1)
{
Assert(IsA(linitial(type_name->typmods), A_Const));
Assert(IsA(&castNode(A_Const, linitial(type_name->typmods))->val, Integer));

int fields = intVal(&castNode(A_Const, linitial(type_name->typmods))->val);

appendStringInfoString(str, "interval");
deparseIntervalTypmods(str, type_name);

// This logic is based on intervaltypmodout in timestamp.c
switch (fields)
{
case INTERVAL_MASK(YEAR):
appendStringInfoString(str, " year");
break;
case INTERVAL_MASK(MONTH):
appendStringInfoString(str, " month");
break;
case INTERVAL_MASK(DAY):
appendStringInfoString(str, " day");
break;
case INTERVAL_MASK(HOUR):
appendStringInfoString(str, " hour");
break;
case INTERVAL_MASK(MINUTE):
appendStringInfoString(str, " minute");
break;
case INTERVAL_MASK(SECOND):
appendStringInfoString(str, " second");
break;
case INTERVAL_MASK(YEAR) | INTERVAL_MASK(MONTH):
appendStringInfoString(str, " year to month");
break;
case INTERVAL_MASK(DAY) | INTERVAL_MASK(HOUR):
appendStringInfoString(str, " day to hour");
break;
case INTERVAL_MASK(DAY) | INTERVAL_MASK(HOUR) | INTERVAL_MASK(MINUTE):
appendStringInfoString(str, " day to minute");
break;
case INTERVAL_MASK(DAY) | INTERVAL_MASK(HOUR) | INTERVAL_MASK(MINUTE) | INTERVAL_MASK(SECOND):
appendStringInfoString(str, " day to second");
break;
case INTERVAL_MASK(HOUR) | INTERVAL_MASK(MINUTE):
appendStringInfoString(str, " hour to minute");
break;
case INTERVAL_MASK(HOUR) | INTERVAL_MASK(MINUTE) | INTERVAL_MASK(SECOND):
appendStringInfoString(str, " hour to second");
break;
case INTERVAL_MASK(MINUTE) | INTERVAL_MASK(SECOND):
appendStringInfoString(str, " minute to second");
break;
case INTERVAL_FULL_RANGE:
// Nothing
break;
default:
Assert(false);
break;
}

if (list_length(type_name->typmods) == 2)
{
int precision = intVal(&castNode(A_Const, lsecond(type_name->typmods))->val);
if (precision != INTERVAL_FULL_PRECISION)
appendStringInfo(str, "(%d)", precision);
}

skip_typmods = true;
}
else
Expand Down Expand Up @@ -3814,6 +3766,79 @@ static void deparseTypeName(StringInfo str, TypeName *type_name)
appendStringInfoString(str, "%type");
}

// Handle typemods for Interval types separately
// so that they can be applied appropriately for different contexts.
// For example, when using `SET` a query like `INTERVAL 'x' hour TO minute`
// the `INTERVAL` keyword is specified first.
// In all other contexts, intervals use the `'x'::interval` style.
static void deparseIntervalTypmods(StringInfo str, TypeName *type_name)
{
const char *name = strVal(lsecond(type_name->names));
Assert(strcmp(name, "interval") == 0);
Assert(list_length(type_name->typmods) >= 1);
Assert(IsA(linitial(type_name->typmods), A_Const));
Assert(IsA(&castNode(A_Const, linitial(type_name->typmods))->val, Integer));

int fields = intVal(&castNode(A_Const, linitial(type_name->typmods))->val);

// This logic is based on intervaltypmodout in timestamp.c
switch (fields)
{
case INTERVAL_MASK(YEAR):
appendStringInfoString(str, " year");
break;
case INTERVAL_MASK(MONTH):
appendStringInfoString(str, " month");
break;
case INTERVAL_MASK(DAY):
appendStringInfoString(str, " day");
break;
case INTERVAL_MASK(HOUR):
appendStringInfoString(str, " hour");
break;
case INTERVAL_MASK(MINUTE):
appendStringInfoString(str, " minute");
break;
case INTERVAL_MASK(SECOND):
appendStringInfoString(str, " second");
break;
case INTERVAL_MASK(YEAR) | INTERVAL_MASK(MONTH):
appendStringInfoString(str, " year to month");
break;
case INTERVAL_MASK(DAY) | INTERVAL_MASK(HOUR):
appendStringInfoString(str, " day to hour");
break;
case INTERVAL_MASK(DAY) | INTERVAL_MASK(HOUR) | INTERVAL_MASK(MINUTE):
appendStringInfoString(str, " day to minute");
break;
case INTERVAL_MASK(DAY) | INTERVAL_MASK(HOUR) | INTERVAL_MASK(MINUTE) | INTERVAL_MASK(SECOND):
appendStringInfoString(str, " day to second");
break;
case INTERVAL_MASK(HOUR) | INTERVAL_MASK(MINUTE):
appendStringInfoString(str, " hour to minute");
break;
case INTERVAL_MASK(HOUR) | INTERVAL_MASK(MINUTE) | INTERVAL_MASK(SECOND):
appendStringInfoString(str, " hour to second");
break;
case INTERVAL_MASK(MINUTE) | INTERVAL_MASK(SECOND):
appendStringInfoString(str, " minute to second");
break;
case INTERVAL_FULL_RANGE:
// Nothing
break;
default:
Assert(false);
break;
}

if (list_length(type_name->typmods) == 2)
{
int precision = intVal(&castNode(A_Const, lsecond(type_name->typmods))->val);
if (precision != INTERVAL_FULL_PRECISION)
appendStringInfo(str, "(%d)", precision);
}
}

static void deparseNullTest(StringInfo str, NullTest *null_test)
{
// argisrow is always false in raw parser output
Expand Down Expand Up @@ -6898,6 +6923,22 @@ static void deparseTransactionStmt(StringInfo str, TransactionStmt *transaction_
removeTrailingSpace(str);
}

// Determine if we hit SET TIME ZONE INTERVAL, that has special syntax not
// supported for other SET statements
static bool isSetTimeZoneInterval(VariableSetStmt* stmt)
{
if (!(strcmp(stmt->name, "timezone") == 0 &&
list_length(stmt->args) == 1 &&
IsA(linitial(stmt->args), TypeCast)))
return false;

TypeName* typeName = castNode(TypeCast, linitial(stmt->args))->typeName;

return (list_length(typeName->names) == 2 &&
strcmp(strVal(linitial(typeName->names)), "pg_catalog") == 0 &&
strcmp(strVal(llast(typeName->names)), "interval") == 0);
}

static void deparseVariableSetStmt(StringInfo str, VariableSetStmt* variable_set_stmt)
{
ListCell *lc;
Expand All @@ -6908,9 +6949,17 @@ static void deparseVariableSetStmt(StringInfo str, VariableSetStmt* variable_set
appendStringInfoString(str, "SET ");
if (variable_set_stmt->is_local)
appendStringInfoString(str, "LOCAL ");
deparseVarName(str, variable_set_stmt->name);
appendStringInfoString(str, " TO ");
deparseVarList(str, variable_set_stmt->args);
if (isSetTimeZoneInterval(variable_set_stmt))
{
appendStringInfoString(str, "TIME ZONE ");
deparseVarList(str, variable_set_stmt->args);
}
else
{
deparseVarName(str, variable_set_stmt->name);
appendStringInfoString(str, " TO ");
deparseVarList(str, variable_set_stmt->args);
}
break;
case VAR_SET_DEFAULT: /* SET var TO DEFAULT */
appendStringInfoString(str, "SET ");
Expand Down
6 changes: 4 additions & 2 deletions test/deparse.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@ void remove_node_locations(char *parse_tree_json)
int run_test(const char *query, bool compare_query_text) {
PgQueryProtobufParseResult parse_result = pg_query_parse_protobuf(query);
if (parse_result.error) {
pg_query_free_protobuf_parse_result(parse_result);
if (!compare_query_text) // Silently fail for regression tests which can contain intentional syntax errors
if (!compare_query_text) { // Silently fail for regression tests which can contain intentional syntax errors
pg_query_free_protobuf_parse_result(parse_result);
return EXIT_SUCCESS;
}
printf("\nERROR for \"%s\"\n %s\n", query, parse_result.error->message);
pg_query_free_protobuf_parse_result(parse_result);
return EXIT_FAILURE;
}

Expand Down
2 changes: 2 additions & 0 deletions test/deparse_tests.c
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ const char* tests[] = {
"SET search_path TO my_schema, public",
"SET LOCAL search_path TO my_schema, public",
"SET \"user\" TO 4",
"SET TIME ZONE interval '+00:00' hour to minute",
"SET timezone TO -7",
"VACUUM",
"VACUUM t",
"VACUUM (FULL) t",
Expand Down

0 comments on commit 43b116b

Please sign in to comment.