From 531e9ca3bf7aa1c89381dce2915c3a29c72a6adc Mon Sep 17 00:00:00 2001 From: Michele Vigilante Date: Wed, 14 Aug 2024 14:58:25 +0200 Subject: [PATCH 1/6] feat: implement json_get_array UDF --- src/json_get_array.rs | 130 ++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 4 ++ tests/main.rs | 26 +++++++++ 3 files changed, 160 insertions(+) create mode 100644 src/json_get_array.rs diff --git a/src/json_get_array.rs b/src/json_get_array.rs new file mode 100644 index 0000000..0a37b6a --- /dev/null +++ b/src/json_get_array.rs @@ -0,0 +1,130 @@ +use std::any::Any; +use std::sync::Arc; + +use arrow::array::{GenericListArray, ListBuilder, StringBuilder}; +use arrow_schema::{DataType, Field}; +use datafusion_common::arrow::array::ArrayRef; +use datafusion_common::{Result as DataFusionResult, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use jiter::Peek; + +use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath}; +use crate::common_macros::make_udf_function; + +struct StrArrayColumn { + rows: GenericListArray, +} + +impl FromIterator>> for StrArrayColumn { + fn from_iter>>>(iter: T) -> Self { + let string_builder = StringBuilder::new(); + let mut list_builder = ListBuilder::new(string_builder); + + for row in iter { + if let Some(row) = row { + for elem in row { + list_builder.values().append_value(elem); + } + + list_builder.append(true); + } else { + list_builder.append(false); + } + } + + Self { + rows: list_builder.finish(), + } + } +} + +make_udf_function!( + JsonGetArray, + json_get_array, + json_data path, + r#"Get an arrow array value from a JSON string by its "path""# +); + +#[derive(Debug)] +pub(super) struct JsonGetArray { + signature: Signature, + aliases: [String; 1], +} + +impl Default for JsonGetArray { + fn default() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + aliases: ["json_get_array".to_string()], + } + } +} + +impl ScalarUDFImpl for JsonGetArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.aliases[0].as_str() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult { + check_args(arg_types, self.name()).map(|()| DataType::List(Field::new("item", DataType::Utf8, true).into())) + } + + fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { + invoke::>( + args, + jiter_json_get_array, + |c| Ok(Arc::new(c.rows) as ArrayRef), + |i| { + let string_builder = StringBuilder::new(); + let mut list_builder = ListBuilder::new(string_builder); + + if let Some(row) = i { + for elem in row { + list_builder.values().append_value(elem); + } + } + + ScalarValue::List(list_builder.finish().into()) + }, + ) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +fn jiter_json_get_array(json_data: Option<&str>, path: &[JsonPath]) -> Result, GetError> { + if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { + match peek { + Peek::Array => { + let mut peek_opt = jiter.known_array()?; + let mut array_values = Vec::new(); + + while let Some(peek) = peek_opt { + let start = jiter.current_index(); + jiter.known_skip(peek)?; + let object_slice = jiter.slice_to_current(start); + let object_string = std::str::from_utf8(object_slice)?; + + array_values.push(object_string.to_owned()); + + peek_opt = jiter.array_step()?; + } + + Ok(array_values) + } + _ => get_err!(), + } + } else { + get_err!() + } +} diff --git a/src/lib.rs b/src/lib.rs index c576794..8f55540 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ mod common_union; mod json_as_text; mod json_contains; mod json_get; +mod json_get_array; mod json_get_bool; mod json_get_float; mod json_get_int; @@ -22,6 +23,7 @@ pub mod functions { pub use crate::json_as_text::json_as_text; pub use crate::json_contains::json_contains; pub use crate::json_get::json_get; + pub use crate::json_get_array::json_get_array; pub use crate::json_get_bool::json_get_bool; pub use crate::json_get_float::json_get_float; pub use crate::json_get_int::json_get_int; @@ -34,6 +36,7 @@ pub mod udfs { pub use crate::json_as_text::json_as_text_udf; pub use crate::json_contains::json_contains_udf; pub use crate::json_get::json_get_udf; + pub use crate::json_get_array::json_get_array_udf; pub use crate::json_get_bool::json_get_bool_udf; pub use crate::json_get_float::json_get_float_udf; pub use crate::json_get_int::json_get_int_udf; @@ -54,6 +57,7 @@ pub mod udfs { pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { let functions: Vec> = vec![ json_get::json_get_udf(), + json_get_array::json_get_array_udf(), json_get_bool::json_get_bool_udf(), json_get_float::json_get_float_udf(), json_get_int::json_get_int_udf(), diff --git a/tests/main.rs b/tests/main.rs index 1bbd85c..9d52f03 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -476,6 +476,31 @@ async fn test_json_length_vec() { assert_batches_eq!(expected, &batches); } +#[tokio::test] +async fn test_json_get_arrow_array() { + let sql = r#"select name, json_get_array(json_data, 'foo') from test"#; + let batches = run_query(sql).await.unwrap(); + + let expected = [ + "+------------------+--------------------------------------------+", + "| name | json_get_array(test.json_data,Utf8(\"foo\")) |", + "+------------------+--------------------------------------------+", + "| object_foo | |", + "| object_foo_array | [1] |", + "| object_foo_obj | |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+--------------------------------------------+", + ]; + + assert_batches_eq!(expected, &batches); + + let batches = run_query_large(sql).await.unwrap(); + assert_batches_eq!(expected, &batches); +} + #[tokio::test] async fn test_no_args() { let err = run_query(r#"select json_len()"#).await.unwrap_err(); @@ -1131,6 +1156,7 @@ async fn test_long_arrow_cast() { assert_batches_eq!(expected, &batches); } +#[tokio::test] async fn test_arrow_cast_numeric() { let sql = r#"select ('{"foo": 420}'->'foo')::numeric = 420"#; let batches = run_query(sql).await.unwrap(); From 9807bc4ea95145fc10b03a0fe36110db6ca80ea0 Mon Sep 17 00:00:00 2001 From: Michele Vigilante Date: Fri, 16 Aug 2024 09:28:53 +0200 Subject: [PATCH 2/6] add column structs --- src/json_get_array.rs | 70 +++++++++++++++++++++++-------------------- 1 file changed, 37 insertions(+), 33 deletions(-) diff --git a/src/json_get_array.rs b/src/json_get_array.rs index 0a37b6a..9bbb5d5 100644 --- a/src/json_get_array.rs +++ b/src/json_get_array.rs @@ -11,33 +11,6 @@ use jiter::Peek; use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath}; use crate::common_macros::make_udf_function; -struct StrArrayColumn { - rows: GenericListArray, -} - -impl FromIterator>> for StrArrayColumn { - fn from_iter>>>(iter: T) -> Self { - let string_builder = StringBuilder::new(); - let mut list_builder = ListBuilder::new(string_builder); - - for row in iter { - if let Some(row) = row { - for elem in row { - list_builder.values().append_value(elem); - } - - list_builder.append(true); - } else { - list_builder.append(false); - } - } - - Self { - rows: list_builder.finish(), - } - } -} - make_udf_function!( JsonGetArray, json_get_array, @@ -78,7 +51,7 @@ impl ScalarUDFImpl for JsonGetArray { } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::>( + invoke::( args, jiter_json_get_array, |c| Ok(Arc::new(c.rows) as ArrayRef), @@ -87,7 +60,7 @@ impl ScalarUDFImpl for JsonGetArray { let mut list_builder = ListBuilder::new(string_builder); if let Some(row) = i { - for elem in row { + for elem in row.elements { list_builder.values().append_value(elem); } } @@ -102,12 +75,43 @@ impl ScalarUDFImpl for JsonGetArray { } } -fn jiter_json_get_array(json_data: Option<&str>, path: &[JsonPath]) -> Result, GetError> { +struct JsonArray { + rows: GenericListArray, +} + +struct JsonArrayField { + elements: Vec, +} + +impl FromIterator> for JsonArray { + fn from_iter>>(iter: T) -> Self { + let string_builder = StringBuilder::new(); + let mut list_builder = ListBuilder::new(string_builder); + + for row in iter { + if let Some(row) = row { + for elem in row.elements { + list_builder.values().append_value(elem); + } + + list_builder.append(true); + } else { + list_builder.append(false); + } + } + + Self { + rows: list_builder.finish(), + } + } +} + +fn jiter_json_get_array(json_data: Option<&str>, path: &[JsonPath]) -> Result { if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { match peek { Peek::Array => { let mut peek_opt = jiter.known_array()?; - let mut array_values = Vec::new(); + let mut elements = Vec::new(); while let Some(peek) = peek_opt { let start = jiter.current_index(); @@ -115,12 +119,12 @@ fn jiter_json_get_array(json_data: Option<&str>, path: &[JsonPath]) -> Result get_err!(), } From 94e9dddf0090274a089e25c17fe980d7c451e37c Mon Sep 17 00:00:00 2001 From: Michele Vigilante Date: Fri, 16 Aug 2024 09:41:38 +0200 Subject: [PATCH 3/6] add tests, run clippy fix --- src/rewrite.rs | 2 +- tests/main.rs | 148 ++++++++++++++++++++++----------------------- tests/utils/mod.rs | 7 ++- 3 files changed, 78 insertions(+), 79 deletions(-) diff --git a/src/rewrite.rs b/src/rewrite.rs index 60403c8..814bd62 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -78,7 +78,7 @@ fn unnest_json_calls(func: &ScalarFunction) -> Option> { fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> { match expr { Expr::ScalarFunction(func) => Some(func), - Expr::Alias(alias) => extract_scalar_function(&*alias.expr), + Expr::Alias(alias) => extract_scalar_function(&alias.expr), _ => None, } } diff --git a/tests/main.rs b/tests/main.rs index 9d52f03..e07fe33 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -62,23 +62,23 @@ async fn test_json_get_union() { .unwrap(); let expected = [ - "+------------------+--------------------------------------+", - "| name | json_get(test.json_data,Utf8(\"foo\")) |", - "+------------------+--------------------------------------+", - "| object_foo | {str=abc} |", - "| object_foo_array | {array=[1]} |", - "| object_foo_obj | {object={}} |", - "| object_foo_null | {null=} |", - "| object_bar | {null=} |", - "| list_foo | {null=} |", - "| invalid_json | {null=} |", - "+------------------+--------------------------------------+", + "+------------------+--------------------------------------------------------------+", + "| name | json_get(test.json_data,Utf8(\"foo\")) |", + "+------------------+--------------------------------------------------------------+", + "| object_foo | {str=abc} |", + "| object_foo_array | {array=[1, true, {\"nested_foo\": \"baz\", \"nested_bar\": null}]} |", + "| object_foo_obj | {object={}} |", + "| object_foo_null | {null=} |", + "| object_bar | {null=} |", + "| list_foo | {null=} |", + "| invalid_json | {null=} |", + "+------------------+--------------------------------------------------------------+", ]; assert_batches_eq!(expected, &batches); } #[tokio::test] -async fn test_json_get_array() { +async fn test_json_get_array_index() { let sql = "select json_get('[1, 2, 3]', 2)"; let batches = run_query(sql).await.unwrap(); let (value_type, value_repr) = display_val(batches).await; @@ -196,11 +196,11 @@ async fn test_json_get_no_path() { let batches = run_query(r#"select json_get('"foo"')::string"#).await.unwrap(); assert_eq!(display_val(batches).await, (DataType::Utf8, "foo".to_string())); - let batches = run_query(r#"select json_get('123')::int"#).await.unwrap(); + let batches = run_query(r"select json_get('123')::int").await.unwrap(); assert_eq!(display_val(batches).await, (DataType::Int64, "123".to_string())); - let batches = run_query(r#"select json_get('true')::int"#).await.unwrap(); - assert_eq!(display_val(batches).await, (DataType::Int64, "".to_string())); + let batches = run_query(r"select json_get('true')::int").await.unwrap(); + assert_eq!(display_val(batches).await, (DataType::Int64, String::new())); } #[tokio::test] @@ -314,17 +314,17 @@ async fn test_json_get_json() { .unwrap(); let expected = [ - "+------------------+-------------------------------------------+", - "| name | json_get_json(test.json_data,Utf8(\"foo\")) |", - "+------------------+-------------------------------------------+", - "| object_foo | \"abc\" |", - "| object_foo_array | [1] |", - "| object_foo_obj | {} |", - "| object_foo_null | null |", - "| object_bar | |", - "| list_foo | |", - "| invalid_json | |", - "+------------------+-------------------------------------------+", + "+------------------+------------------------------------------------------+", + "| name | json_get_json(test.json_data,Utf8(\"foo\")) |", + "+------------------+------------------------------------------------------+", + "| object_foo | \"abc\" |", + "| object_foo_array | [1, true, {\"nested_foo\": \"baz\", \"nested_bar\": null}] |", + "| object_foo_obj | {} |", + "| object_foo_null | null |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+------------------------------------------------------+", ]; assert_batches_eq!(expected, &batches); } @@ -349,7 +349,7 @@ async fn test_json_length_object() { let batches = run_query(sql).await.unwrap(); assert_eq!(display_val(batches).await, (DataType::UInt64, "3".to_string())); - let sql = r#"select json_length('{}')"#; + let sql = r"select json_length('{}')"; let batches = run_query(sql).await.unwrap(); assert_eq!(display_val(batches).await, (DataType::UInt64, "0".to_string())); } @@ -358,7 +358,7 @@ async fn test_json_length_object() { async fn test_json_length_string() { let sql = r#"select json_length('"foobar"')"#; let batches = run_query(sql).await.unwrap(); - assert_eq!(display_val(batches).await, (DataType::UInt64, "".to_string())); + assert_eq!(display_val(batches).await, (DataType::UInt64, String::new())); } #[tokio::test] @@ -369,7 +369,7 @@ async fn test_json_length_object_nested() { let sql = r#"select json_length('{"a": 1, "b": 2, "c": []}', 'b')"#; let batches = run_query(sql).await.unwrap(); - assert_eq!(display_val(batches).await, (DataType::UInt64, "".to_string())); + assert_eq!(display_val(batches).await, (DataType::UInt64, String::new())); } #[tokio::test] @@ -454,7 +454,7 @@ async fn test_json_contains_large_both_params() { #[tokio::test] async fn test_json_length_vec() { - let sql = r#"select name, json_len(json_data) as len from test"#; + let sql = r"select name, json_len(json_data) as len from test"; let batches = run_query(sql).await.unwrap(); let expected = [ @@ -477,22 +477,18 @@ async fn test_json_length_vec() { } #[tokio::test] -async fn test_json_get_arrow_array() { - let sql = r#"select name, json_get_array(json_data, 'foo') from test"#; +async fn test_json_get_array() { + let sql = r"select name, unnest(json_get_array(json_data, 'foo')) from test"; let batches = run_query(sql).await.unwrap(); let expected = [ - "+------------------+--------------------------------------------+", - "| name | json_get_array(test.json_data,Utf8(\"foo\")) |", - "+------------------+--------------------------------------------+", - "| object_foo | |", - "| object_foo_array | [1] |", - "| object_foo_obj | |", - "| object_foo_null | |", - "| object_bar | |", - "| list_foo | |", - "| invalid_json | |", - "+------------------+--------------------------------------------+", + "+------------------+----------------------------------------------------+", + "| name | unnest(json_get_array(test.json_data,Utf8(\"foo\"))) |", + "+------------------+----------------------------------------------------+", + "| object_foo_array | 1 |", + "| object_foo_array | true |", + "| object_foo_array | {\"nested_foo\": \"baz\", \"nested_bar\": null} |", + "+------------------+----------------------------------------------------+", ]; assert_batches_eq!(expected, &batches); @@ -503,7 +499,7 @@ async fn test_json_get_arrow_array() { #[tokio::test] async fn test_no_args() { - let err = run_query(r#"select json_len()"#).await.unwrap_err(); + let err = run_query(r"select json_len()").await.unwrap_err(); assert!(err .to_string() .contains("No function matches the given name and argument types 'json_length()'.")); @@ -586,10 +582,10 @@ async fn test_json_get_nested_collapsed() { #[tokio::test] async fn test_json_get_cte() { // avoid auto-un-nesting with a CTE - let sql = r#" + let sql = r" with t as (select name, json_get(json_data, 'foo') j from test) select name, json_get(j, 0) v from t - "#; + "; let expected = [ "+------------------+---------+", "| name | v |", @@ -611,11 +607,11 @@ async fn test_json_get_cte() { #[tokio::test] async fn test_plan_json_get_cte() { // avoid auto-unnesting with a CTE - let sql = r#" + let sql = r" explain with t as (select name, json_get(json_data, 'foo') j from test) select name, json_get(j, 0) v from t - "#; + "; let expected = [ "Projection: t.name, json_get(t.j, Int64(0)) AS v", " SubqueryAlias: t", @@ -758,24 +754,24 @@ async fn test_arrow() { let batches = run_query("select name, json_data->'foo' from test").await.unwrap(); let expected = [ - "+------------------+--------------------------+", - "| name | json_data -> Utf8(\"foo\") |", - "+------------------+--------------------------+", - "| object_foo | {str=abc} |", - "| object_foo_array | {array=[1]} |", - "| object_foo_obj | {object={}} |", - "| object_foo_null | {null=} |", - "| object_bar | {null=} |", - "| list_foo | {null=} |", - "| invalid_json | {null=} |", - "+------------------+--------------------------+", + "+------------------+--------------------------------------------------------------+", + "| name | json_data -> Utf8(\"foo\") |", + "+------------------+--------------------------------------------------------------+", + "| object_foo | {str=abc} |", + "| object_foo_array | {array=[1, true, {\"nested_foo\": \"baz\", \"nested_bar\": null}]} |", + "| object_foo_obj | {object={}} |", + "| object_foo_null | {null=} |", + "| object_bar | {null=} |", + "| list_foo | {null=} |", + "| invalid_json | {null=} |", + "+------------------+--------------------------------------------------------------+", ]; assert_batches_eq!(expected, &batches); } #[tokio::test] async fn test_plan_arrow() { - let lines = logical_plan(r#"explain select json_data->'foo' from test"#).await; + let lines = logical_plan(r"explain select json_data->'foo' from test").await; let expected = [ "Projection: json_get(test.json_data, Utf8(\"foo\")) AS json_data -> Utf8(\"foo\")", @@ -790,24 +786,24 @@ async fn test_long_arrow() { let batches = run_query("select name, json_data->>'foo' from test").await.unwrap(); let expected = [ - "+------------------+---------------------------+", - "| name | json_data ->> Utf8(\"foo\") |", - "+------------------+---------------------------+", - "| object_foo | abc |", - "| object_foo_array | [1] |", - "| object_foo_obj | {} |", - "| object_foo_null | |", - "| object_bar | |", - "| list_foo | |", - "| invalid_json | |", - "+------------------+---------------------------+", + "+------------------+------------------------------------------------------+", + "| name | json_data ->> Utf8(\"foo\") |", + "+------------------+------------------------------------------------------+", + "| object_foo | abc |", + "| object_foo_array | [1, true, {\"nested_foo\": \"baz\", \"nested_bar\": null}] |", + "| object_foo_obj | {} |", + "| object_foo_null | |", + "| object_bar | |", + "| list_foo | |", + "| invalid_json | |", + "+------------------+------------------------------------------------------+", ]; assert_batches_eq!(expected, &batches); } #[tokio::test] async fn test_plan_long_arrow() { - let lines = logical_plan(r#"explain select json_data->>'foo' from test"#).await; + let lines = logical_plan(r"explain select json_data->>'foo' from test").await; let expected = [ "Projection: json_as_text(test.json_data, Utf8(\"foo\")) AS json_data ->> Utf8(\"foo\")", @@ -858,7 +854,7 @@ async fn test_arrow_cast_int() { #[tokio::test] async fn test_plan_arrow_cast_int() { - let lines = logical_plan(r#"explain select (json_data->'foo')::int from test"#).await; + let lines = logical_plan(r"explain select (json_data->'foo')::int from test").await; let expected = [ "Projection: json_get_int(test.json_data, Utf8(\"foo\")) AS json_data -> Utf8(\"foo\")", @@ -890,7 +886,7 @@ async fn test_arrow_double_nested() { #[tokio::test] async fn test_plan_arrow_double_nested() { - let lines = logical_plan(r#"explain select json_data->'foo'->0 from test"#).await; + let lines = logical_plan(r"explain select json_data->'foo'->0 from test").await; let expected = [ "Projection: json_get(test.json_data, Utf8(\"foo\"), Int64(0)) AS json_data -> Utf8(\"foo\") -> Int64(0)", @@ -924,7 +920,7 @@ async fn test_arrow_double_nested_cast() { #[tokio::test] async fn test_plan_arrow_double_nested_cast() { - let lines = logical_plan(r#"explain select (json_data->'foo'->0)::int from test"#).await; + let lines = logical_plan(r"explain select (json_data->'foo'->0)::int from test").await; let expected = [ "Projection: json_get_int(test.json_data, Utf8(\"foo\"), Int64(0)) AS json_data -> Utf8(\"foo\") -> Int64(0)", @@ -972,7 +968,7 @@ async fn test_arrow_nested_double_columns() { async fn test_lexical_precedence_wrong() { let sql = r#"select '{"a": "b"}'->>'a'='b' as v"#; let err = run_query(sql).await.unwrap_err(); - assert_eq!(err.to_string(), "Error during planning: Unexpected argument type to 'json_as_text' at position 2, expected string or int, got Boolean.") + assert_eq!(err.to_string(), "Error during planning: Unexpected argument type to 'json_as_text' at position 2, expected string or int, got Boolean."); } #[tokio::test] diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index a5279f9..f160064 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -19,7 +19,10 @@ async fn create_test_table(large_utf8: bool) -> Result { let test_data = [ ("object_foo", r#" {"foo": "abc"} "#), - ("object_foo_array", r#" {"foo": [1]} "#), + ( + "object_foo_array", + r#" {"foo": [1, true, {"nested_foo": "baz", "nested_bar": null}]} "#, + ), ("object_foo_obj", r#" {"foo": {}} "#), ("object_foo_null", r#" {"foo": null} "#), ("object_bar", r#" {"bar": true} "#), @@ -149,5 +152,5 @@ pub async fn logical_plan(sql: &str) -> Vec { let batches = run_query(sql).await.unwrap(); let plan_col = batches[0].column(1).as_any().downcast_ref::().unwrap(); let logical_plan = plan_col.value(0); - logical_plan.split('\n').map(|s| s.to_string()).collect() + logical_plan.split('\n').map(std::string::ToString::to_string).collect() } From 02ab13403fdd7e15575b5f15d863bf04497d5250 Mon Sep 17 00:00:00 2001 From: Michele Vigilante Date: Fri, 16 Aug 2024 09:49:17 +0200 Subject: [PATCH 4/6] remove async from create_test_table --- tests/utils/mod.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/utils/mod.rs b/tests/utils/mod.rs index f160064..4f98056 100644 --- a/tests/utils/mod.rs +++ b/tests/utils/mod.rs @@ -12,7 +12,7 @@ use datafusion_common::ParamValues; use datafusion_execution::config::SessionConfig; use datafusion_functions_json::register_all; -async fn create_test_table(large_utf8: bool) -> Result { +fn create_test_table(large_utf8: bool) -> Result { let config = SessionConfig::new().set_str("datafusion.sql_parser.dialect", "postgres"); let mut ctx = SessionContext::new_with_config(config); register_all(&mut ctx)?; @@ -117,12 +117,12 @@ async fn create_test_table(large_utf8: bool) -> Result { } pub async fn run_query(sql: &str) -> Result> { - let ctx = create_test_table(false).await?; + let ctx = create_test_table(false)?; ctx.sql(sql).await?.collect().await } pub async fn run_query_large(sql: &str) -> Result> { - let ctx = create_test_table(true).await?; + let ctx = create_test_table(true)?; ctx.sql(sql).await?.collect().await } @@ -131,7 +131,7 @@ pub async fn run_query_params( large_utf8: bool, query_values: impl Into, ) -> Result> { - let ctx = create_test_table(large_utf8).await?; + let ctx = create_test_table(large_utf8)?; ctx.sql(sql).await?.with_param_values(query_values)?.collect().await } From d1144b91cdf3c3dad80b7320ca0f4ba42e559cd2 Mon Sep 17 00:00:00 2001 From: Michele Vigilante Date: Fri, 16 Aug 2024 12:41:05 +0200 Subject: [PATCH 5/6] implement using JsonArrayField --- src/common_union.rs | 118 ++++++++++++++++++++++++++++++++++-------- src/json_get.rs | 2 +- src/json_get_array.rs | 61 ++++------------------ 3 files changed, 107 insertions(+), 74 deletions(-) diff --git a/src/common_union.rs b/src/common_union.rs index ae4433b..780a651 100644 --- a/src/common_union.rs +++ b/src/common_union.rs @@ -1,6 +1,9 @@ use std::sync::{Arc, OnceLock}; -use arrow::array::{Array, ArrayRef, BooleanArray, Float64Array, Int64Array, NullArray, StringArray, UnionArray}; +use arrow::array::{ + Array, ArrayRef, BooleanArray, Float64Array, Int64Array, ListArray, ListBuilder, NullArray, StringArray, + StringBuilder, UnionArray, +}; use arrow::buffer::Buffer; use arrow_schema::{DataType, Field, UnionFields, UnionMode}; use datafusion_common::ScalarValue; @@ -46,7 +49,7 @@ pub(crate) struct JsonUnion { ints: Vec>, floats: Vec>, strings: Vec>, - arrays: Vec>, + arrays: Vec>>, objects: Vec>, type_ids: Vec, index: usize, @@ -93,24 +96,6 @@ impl JsonUnion { } } -/// So we can do `collect::()` -impl FromIterator> for JsonUnion { - fn from_iter>>(iter: I) -> Self { - let inner = iter.into_iter(); - let (lower, upper) = inner.size_hint(); - let mut union = Self::new(upper.unwrap_or(lower)); - - for opt_field in inner { - if let Some(union_field) = opt_field { - union.push(union_field); - } else { - union.push_none(); - } - } - union - } -} - impl TryFrom for UnionArray { type Error = arrow::error::ArrowError; @@ -121,13 +106,42 @@ impl TryFrom for UnionArray { Arc::new(Int64Array::from(value.ints)), Arc::new(Float64Array::from(value.floats)), Arc::new(StringArray::from(value.strings)), - Arc::new(StringArray::from(value.arrays)), + Arc::new(StringArray::from( + value + .arrays + .into_iter() + .map(|r| r.map(|e| e.join(","))) + .collect::>(), + )), Arc::new(StringArray::from(value.objects)), ]; UnionArray::try_new(union_fields(), Buffer::from_vec(value.type_ids).into(), None, children) } } +impl TryFrom for ListArray { + type Error = arrow::error::ArrowError; + + fn try_from(value: JsonUnion) -> Result { + let string_builder = StringBuilder::new(); + let mut list_builder = ListBuilder::new(string_builder); + + for row in value.arrays { + if let Some(row) = row { + for elem in row { + list_builder.values().append_value(elem); + } + + list_builder.append(true); + } else { + list_builder.append(false); + } + } + + Ok(list_builder.finish()) + } +} + #[derive(Debug)] pub(crate) enum JsonUnionField { JsonNull, @@ -135,7 +149,7 @@ pub(crate) enum JsonUnionField { Int(i64), Float(f64), Str(String), - Array(String), + Array(Vec), Object(String), } @@ -193,7 +207,65 @@ impl From for ScalarValue { JsonUnionField::Bool(b) => Self::Boolean(Some(b)), JsonUnionField::Int(i) => Self::Int64(Some(i)), JsonUnionField::Float(f) => Self::Float64(Some(f)), - JsonUnionField::Str(s) | JsonUnionField::Array(s) | JsonUnionField::Object(s) => Self::Utf8(Some(s)), + JsonUnionField::Array(a) => Self::List(Self::new_list_nullable( + &a.into_iter().map(|e| Self::Utf8(Some(e))).collect::>(), + &DataType::Utf8, + )), + JsonUnionField::Str(s) | JsonUnionField::Object(s) => Self::Utf8(Some(s)), + } + } +} + +/// So we can do `collect::()` +impl FromIterator> for JsonUnion { + fn from_iter>>(iter: I) -> Self { + let inner = iter.into_iter(); + let (lower, upper) = inner.size_hint(); + let mut union = Self::new(upper.unwrap_or(lower)); + + for opt_field in inner { + if let Some(union_field) = opt_field { + union.push(union_field); + } else { + union.push_none(); + } } + union + } +} + +#[derive(Debug)] +pub(crate) struct JsonArrayField(pub(crate) Vec); + +impl From for ScalarValue { + fn from(JsonArrayField(elems): JsonArrayField) -> Self { + Self::List(Self::new_list_nullable( + &elems.into_iter().map(|e| Self::Utf8(Some(e))).collect::>(), + &DataType::Utf8, + )) + } +} + +impl From for JsonUnionField { + fn from(JsonArrayField(elems): JsonArrayField) -> Self { + JsonUnionField::Array(elems) + } +} + +impl FromIterator> for JsonUnion { + fn from_iter>>(iter: T) -> Self { + let inner = iter.into_iter(); + let (lower, upper) = inner.size_hint(); + let mut union = Self::new(upper.unwrap_or(lower)); + + for opt_field in inner { + if let Some(array_field) = opt_field { + union.push(array_field.into()); + } else { + union.push_none(); + } + } + + union } } diff --git a/src/json_get.rs b/src/json_get.rs index b1e4810..95db10b 100644 --- a/src/json_get.rs +++ b/src/json_get.rs @@ -93,7 +93,7 @@ fn build_union(jiter: &mut Jiter, peek: Peek) -> Result { let start = jiter.current_index(); diff --git a/src/json_get_array.rs b/src/json_get_array.rs index 9bbb5d5..063d228 100644 --- a/src/json_get_array.rs +++ b/src/json_get_array.rs @@ -1,7 +1,7 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{GenericListArray, ListBuilder, StringBuilder}; +use arrow::array::ListArray; use arrow_schema::{DataType, Field}; use datafusion_common::arrow::array::ArrayRef; use datafusion_common::{Result as DataFusionResult, ScalarValue}; @@ -10,6 +10,7 @@ use jiter::Peek; use crate::common::{check_args, get_err, invoke, jiter_json_find, GetError, JsonPath}; use crate::common_macros::make_udf_function; +use crate::common_union::{JsonArrayField, JsonUnion}; make_udf_function!( JsonGetArray, @@ -51,23 +52,14 @@ impl ScalarUDFImpl for JsonGetArray { } fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { - invoke::( - args, - jiter_json_get_array, - |c| Ok(Arc::new(c.rows) as ArrayRef), - |i| { - let string_builder = StringBuilder::new(); - let mut list_builder = ListBuilder::new(string_builder); - - if let Some(row) = i { - for elem in row.elements { - list_builder.values().append_value(elem); - } - } - - ScalarValue::List(list_builder.finish().into()) - }, - ) + let to_array = |c: JsonUnion| { + let array: ListArray = c.try_into()?; + Ok(Arc::new(array) as ArrayRef) + }; + + invoke::(args, jiter_json_get_array, to_array, |i| { + i.map_or_else(|| ScalarValue::Null, Into::into) + }) } fn aliases(&self) -> &[String] { @@ -75,37 +67,6 @@ impl ScalarUDFImpl for JsonGetArray { } } -struct JsonArray { - rows: GenericListArray, -} - -struct JsonArrayField { - elements: Vec, -} - -impl FromIterator> for JsonArray { - fn from_iter>>(iter: T) -> Self { - let string_builder = StringBuilder::new(); - let mut list_builder = ListBuilder::new(string_builder); - - for row in iter { - if let Some(row) = row { - for elem in row.elements { - list_builder.values().append_value(elem); - } - - list_builder.append(true); - } else { - list_builder.append(false); - } - } - - Self { - rows: list_builder.finish(), - } - } -} - fn jiter_json_get_array(json_data: Option<&str>, path: &[JsonPath]) -> Result { if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { match peek { @@ -124,7 +85,7 @@ fn jiter_json_get_array(json_data: Option<&str>, path: &[JsonPath]) -> Result get_err!(), } From a74c57d51b9536b587ebc5b2d0035d118e7ceaf2 Mon Sep 17 00:00:00 2001 From: Michele Vigilante Date: Fri, 16 Aug 2024 12:59:31 +0200 Subject: [PATCH 6/6] revert behavior change --- src/common_union.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/common_union.rs b/src/common_union.rs index 780a651..a7278e4 100644 --- a/src/common_union.rs +++ b/src/common_union.rs @@ -207,10 +207,7 @@ impl From for ScalarValue { JsonUnionField::Bool(b) => Self::Boolean(Some(b)), JsonUnionField::Int(i) => Self::Int64(Some(i)), JsonUnionField::Float(f) => Self::Float64(Some(f)), - JsonUnionField::Array(a) => Self::List(Self::new_list_nullable( - &a.into_iter().map(|e| Self::Utf8(Some(e))).collect::>(), - &DataType::Utf8, - )), + JsonUnionField::Array(a) => Self::Utf8(Some(a.join(","))), JsonUnionField::Str(s) | JsonUnionField::Object(s) => Self::Utf8(Some(s)), } }