Skip to content

Commit

Permalink
Support custom struct field names with new scalar function named_stru…
Browse files Browse the repository at this point in the history
…ct (#9743)

* Support custom struct field names with new scalar function named_struct

* add tests and corretly handle mixed arrray and scalar values

* fix slt

* fmt

* port test to slt

---------

Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
gstvg and alamb authored Mar 30, 2024
1 parent a5f7714 commit aa879bf
Show file tree
Hide file tree
Showing 6 changed files with 343 additions and 27 deletions.
3 changes: 3 additions & 0 deletions datafusion/functions/src/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
mod arrow_cast;
mod arrowtypeof;
mod getfield;
mod named_struct;
mod nullif;
mod nvl;
mod nvl2;
Expand All @@ -32,6 +33,7 @@ make_udf_function!(nvl::NVLFunc, NVL, nvl);
make_udf_function!(nvl2::NVL2Func, NVL2, nvl2);
make_udf_function!(arrowtypeof::ArrowTypeOfFunc, ARROWTYPEOF, arrow_typeof);
make_udf_function!(r#struct::StructFunc, STRUCT, r#struct);
make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct);
make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field);

// Export the functions out of this package, both as expr_fn as well as a list of functions
Expand All @@ -42,5 +44,6 @@ export_functions!(
(nvl2, arg_1 arg_2 arg_3, "Returns value2 if value1 is not NULL; otherwise, it returns value3."),
(arrow_typeof, arg_1, "Returns the Arrow type of the input expression."),
(r#struct, args, "Returns a struct with the given arguments"),
(named_struct, args, "Returns a struct with the given names and arguments pairs"),
(get_field, arg_1 arg_2, "Returns the value of the field with the given name from the struct")
);
148 changes: 148 additions & 0 deletions datafusion/functions/src/core/named_struct.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::StructArray;
use arrow::datatypes::{DataType, Field, Fields};
use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, Expr, ExprSchemable};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;

/// put values in a struct array.
fn named_struct_expr(args: &[ColumnarValue]) -> Result<ColumnarValue> {
// do not accept 0 arguments.
if args.is_empty() {
return exec_err!(
"named_struct requires at least one pair of arguments, got 0 instead"
);
}

if args.len() % 2 != 0 {
return exec_err!(
"named_struct requires an even number of arguments, got {} instead",
args.len()
);
}

let (names, values): (Vec<_>, Vec<_>) = args
.chunks_exact(2)
.enumerate()
.map(|(i, chunk)| {

let name_column = &chunk[0];

let name = match name_column {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(name_scalar))) => name_scalar,
_ => return exec_err!("named_struct even arguments must be string literals, got {name_column:?} instead at position {}", i * 2)
};

Ok((name, chunk[1].clone()))
})
.collect::<Result<Vec<_>>>()?
.into_iter()
.unzip();

let arrays = ColumnarValue::values_to_arrays(&values)?;

let fields = names
.into_iter()
.zip(arrays)
.map(|(name, value)| {
(
Arc::new(Field::new(name, value.data_type().clone(), true)),
value,
)
})
.collect::<Vec<_>>();

Ok(ColumnarValue::Array(Arc::new(StructArray::from(fields))))
}

#[derive(Debug)]
pub(super) struct NamedStructFunc {
signature: Signature,
}

impl NamedStructFunc {
pub fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
}
}
}

impl ScalarUDFImpl for NamedStructFunc {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"named_struct"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
internal_err!(
"named_struct: return_type called instead of return_type_from_exprs"
)
}

fn return_type_from_exprs(
&self,
args: &[datafusion_expr::Expr],
schema: &dyn datafusion_common::ExprSchema,
_arg_types: &[DataType],
) -> Result<DataType> {
// do not accept 0 arguments.
if args.is_empty() {
return exec_err!(
"named_struct requires at least one pair of arguments, got 0 instead"
);
}

if args.len() % 2 != 0 {
return exec_err!(
"named_struct requires an even number of arguments, got {} instead",
args.len()
);
}

let return_fields = args
.chunks_exact(2)
.enumerate()
.map(|(i, chunk)| {
let name = &chunk[0];
let value = &chunk[1];

if let Expr::Literal(ScalarValue::Utf8(Some(name))) = name {
Ok(Field::new(name, value.get_type(schema)?, true))
} else {
exec_err!("named_struct even arguments must be string literals, got {name} instead at position {}", i * 2)
}
})
.collect::<Result<Vec<Field>>>()?;
Ok(DataType::Struct(Fields::from(return_fields)))
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
named_struct_expr(args)
}
}
45 changes: 36 additions & 9 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ use datafusion_expr::expr::InList;
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::{
col, expr, lit, AggregateFunction, Between, BinaryExpr, BuiltinScalarFunction, Cast,
Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Operator, TryCast,
Expr, ExprSchemable, GetFieldAccess, GetIndexedField, Like, Literal, Operator,
TryCast,
};

use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
Expand Down Expand Up @@ -604,18 +605,44 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
let args = values
.into_iter()
.map(|value| {
self.sql_expr_to_logical_expr(value, input_schema, planner_context)
.enumerate()
.map(|(i, value)| {
let args = if let SQLExpr::Named { expr, name } = value {
[
name.value.lit(),
self.sql_expr_to_logical_expr(
*expr,
input_schema,
planner_context,
)?,
]
} else {
[
format!("c{i}").lit(),
self.sql_expr_to_logical_expr(
value,
input_schema,
planner_context,
)?,
]
};

Ok(args)
})
.collect::<Result<Vec<_>>>()?;
let struct_func = self
.collect::<Result<Vec<_>>>()?
.into_iter()
.flatten()
.collect();

let named_struct_func = self
.context_provider
.get_function_meta("struct")
.get_function_meta("named_struct")
.ok_or_else(|| {
internal_datafusion_err!("Unable to find expected 'struct' function")
})?;
internal_datafusion_err!("Unable to find expected 'named_struct' function")
})?;

Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
struct_func,
named_struct_func,
args,
)))
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sqllogictest/test_files/explain.slt
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,8 @@ query TT
explain select struct(1, 2.3, 'abc');
----
logical_plan
Projection: Struct({c0:1,c1:2.3,c2:abc}) AS struct(Int64(1),Float64(2.3),Utf8("abc"))
Projection: Struct({c0:1,c1:2.3,c2:abc}) AS named_struct(Utf8("c0"),Int64(1),Utf8("c1"),Float64(2.3),Utf8("c2"),Utf8("abc"))
--EmptyRelation
physical_plan
ProjectionExec: expr=[{c0:1,c1:2.3,c2:abc} as struct(Int64(1),Float64(2.3),Utf8("abc"))]
ProjectionExec: expr=[{c0:1,c1:2.3,c2:abc} as named_struct(Utf8("c0"),Int64(1),Utf8("c1"),Float64(2.3),Utf8("c2"),Utf8("abc"))]
--PlaceholderRowExec
112 changes: 106 additions & 6 deletions datafusion/sqllogictest/test_files/struct.slt
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ statement ok
CREATE TABLE values(
a INT,
b FLOAT,
c VARCHAR
c VARCHAR,
n VARCHAR,
) AS VALUES
(1, 1.1, 'a'),
(2, 2.2, 'b'),
(3, 3.3, 'c')
(1, 1.1, 'a', NULL),
(2, 2.2, 'b', NULL),
(3, 3.3, 'c', NULL)
;

# struct[i]
Expand All @@ -50,6 +51,18 @@ select struct(1, 3.14, 'e');
----
{c0: 1, c1: 3.14, c2: e}

# struct scalar function with named values
query ?
select struct(1 as "name0", 3.14 as name1, 'e', true as 'name3');
----
{name0: 1, name1: 3.14, c2: e, name3: true}

# struct scalar function with mixed named and unnamed values
query ?
select struct(1, 3.14 as name1, 'e', true);
----
{c0: 1, name1: 3.14, c2: e, c3: true}

# struct scalar function with columns #1
query ?
select struct(a, b, c) from values;
Expand All @@ -72,11 +85,98 @@ query TT
explain select struct(a, b, c) from values;
----
logical_plan
Projection: struct(values.a, values.b, values.c)
Projection: named_struct(Utf8("c0"), values.a, Utf8("c1"), values.b, Utf8("c2"), values.c)
--TableScan: values projection=[a, b, c]
physical_plan
ProjectionExec: expr=[struct(a@0, b@1, c@2) as struct(values.a,values.b,values.c)]
ProjectionExec: expr=[named_struct(c0, a@0, c1, b@1, c2, c@2) as named_struct(Utf8("c0"),values.a,Utf8("c1"),values.b,Utf8("c2"),values.c)]
--MemoryExec: partitions=1, partition_sizes=[1]

# error on 0 arguments
query error DataFusion error: Error during planning: No function matches the given name and argument types 'named_struct\(\)'. You might need to add explicit type casts.
select named_struct();

# error on odd number of arguments #1
query error DataFusion error: Execution error: named_struct requires an even number of arguments, got 1 instead
select named_struct('a');

# error on odd number of arguments #2
query error DataFusion error: Execution error: named_struct requires an even number of arguments, got 1 instead
select named_struct(1);

# error on odd number of arguments #3
query error DataFusion error: Execution error: named_struct requires an even number of arguments, got 1 instead
select named_struct(values.a) from values;

# error on odd number of arguments #4
query error DataFusion error: Execution error: named_struct requires an even number of arguments, got 3 instead
select named_struct('a', 1, 'b');

# error on even argument not a string literal #1
query error DataFusion error: Execution error: named_struct even arguments must be string literals, got Int64\(1\) instead at position 0
select named_struct(1, 'a');

# error on even argument not a string literal #2
query error DataFusion error: Execution error: named_struct even arguments must be string literals, got Int64\(0\) instead at position 2
select named_struct('corret', 1, 0, 'wrong');

# error on even argument not a string literal #3
query error DataFusion error: Execution error: named_struct even arguments must be string literals, got values\.a instead at position 0
select named_struct(values.a, 'a') from values;

# error on even argument not a string literal #4
query error DataFusion error: Execution error: named_struct even arguments must be string literals, got values\.c instead at position 0
select named_struct(values.c, 'c') from values;

# named_struct with mixed scalar and array values #1
query ?
select named_struct('scalar', 27, 'array', values.a, 'null', NULL) from values;
----
{scalar: 27, array: 1, null: }
{scalar: 27, array: 2, null: }
{scalar: 27, array: 3, null: }

# named_struct with mixed scalar and array values #2
query ?
select named_struct('array', values.a, 'scalar', 27, 'null', NULL) from values;
----
{array: 1, scalar: 27, null: }
{array: 2, scalar: 27, null: }
{array: 3, scalar: 27, null: }

# named_struct with mixed scalar and array values #3
query ?
select named_struct('null', NULL, 'array', values.a, 'scalar', 27) from values;
----
{null: , array: 1, scalar: 27}
{null: , array: 2, scalar: 27}
{null: , array: 3, scalar: 27}

# named_struct with mixed scalar and array values #4
query ?
select named_struct('null_array', values.n, 'array', values.a, 'scalar', 27, 'null', NULL) from values;
----
{null_array: , array: 1, scalar: 27, null: }
{null_array: , array: 2, scalar: 27, null: }
{null_array: , array: 3, scalar: 27, null: }

# named_struct arrays only
query ?
select named_struct('field_a', a, 'field_b', b) from values;
----
{field_a: 1, field_b: 1.1}
{field_a: 2, field_b: 2.2}
{field_a: 3, field_b: 3.3}

# named_struct scalars only
query ?
select named_struct('field_a', 1, 'field_b', 2);
----
{field_a: 1, field_b: 2}

statement ok
drop table values;

query T
select arrow_typeof(named_struct('first', 1, 'second', 2, 'third', 3));
----
Struct([Field { name: "first", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "second", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "third", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }])
Loading

0 comments on commit aa879bf

Please sign in to comment.