diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
index e414e594f9080..01d92aa122b7e 100644
--- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs
+++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs
@@ -37,14 +37,14 @@ use std::fmt::Debug;
 use std::hash::{Hash, Hasher};
 use std::{any::Any, sync::Arc};
 
+/// Access a sub field of a nested type, such as `Struct` or `List`
 #[derive(Clone, Hash, Debug)]
 pub enum GetFieldAccessExpr {
-    /// returns the field `struct[field]`. For example `struct["name"]`
+    /// Named field, for example `struct["name"]`
     NamedStructField { name: ScalarValue },
-    /// single list index
-    // list[i]
+    /// Single list index, for example: `list[i]`
     ListIndex { key: Arc<dyn PhysicalExpr> },
-    /// list range `list[i:j]`
+    /// List range, for example `list[i:j]`
     ListRange {
         start: Arc<dyn PhysicalExpr>,
         stop: Arc<dyn PhysicalExpr>,
@@ -82,12 +82,36 @@ pub struct GetIndexedFieldExpr {
 }
 
 impl GetIndexedFieldExpr {
-    /// Create new get field expression
+    /// Create new [`GetIndexedFieldExpr`]
     pub fn new(arg: Arc<dyn PhysicalExpr>, field: GetFieldAccessExpr) -> Self {
         Self { arg, field }
     }
 
-    /// Get the input field
+    /// Create a new [`GetIndexedFieldExpr`] for accessing the named field
+    pub fn new_field(arg: Arc<dyn PhysicalExpr>, name: impl Into<String>) -> Self {
+        Self::new(
+            arg,
+            GetFieldAccessExpr::NamedStructField {
+                name: ScalarValue::Utf8(Some(name.into())),
+            },
+        )
+    }
+
+    /// Create a new [`GetIndexedFieldExpr`] for accessing the specified index
+    pub fn new_index(arg: Arc<dyn PhysicalExpr>, key: Arc<dyn PhysicalExpr>) -> Self {
+        Self::new(arg, GetFieldAccessExpr::ListIndex { key })
+    }
+
+    /// Create a new [`GetIndexedFieldExpr`] for accessing the range
+    pub fn new_range(
+        arg: Arc<dyn PhysicalExpr>,
+        start: Arc<dyn PhysicalExpr>,
+        stop: Arc<dyn PhysicalExpr>,
+    ) -> Self {
+        Self::new(arg, GetFieldAccessExpr::ListRange { start, stop })
+    }
+
+    /// Get the description of what field should be accessed
     pub fn field(&self) -> &GetFieldAccessExpr {
         &self.field
     }
@@ -286,12 +310,7 @@ mod tests {
         let expr = col("str", &schema).unwrap();
         // only one row should be processed
         let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?;
-        let expr = Arc::new(GetIndexedFieldExpr::new(
-            expr,
-            GetFieldAccessExpr::NamedStructField {
-                name: ScalarValue::Utf8(Some(String::from("a"))),
-            },
-        ));
+        let expr = Arc::new(GetIndexedFieldExpr::new_field(expr, "a"));
         let result = expr.evaluate(&batch)?.into_array(1);
         let result =
             as_boolean_array(&result).expect("failed to downcast to BooleanArray");
@@ -348,10 +367,7 @@ mod tests {
             Arc::new(schema),
             vec![Arc::new(list_col), Arc::new(key_col)],
         )?;
-        let expr = Arc::new(GetIndexedFieldExpr::new(
-            expr,
-            GetFieldAccessExpr::ListIndex { key },
-        ));
+        let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key));
         let result = expr.evaluate(&batch)?.into_array(1);
         let result = as_string_array(&result).expect("failed to downcast to ListArray");
         let expected = StringArray::from(expected_list);
@@ -387,10 +403,7 @@ mod tests {
             Arc::new(schema),
             vec![Arc::new(list_col), Arc::new(start_col), Arc::new(stop_col)],
         )?;
-        let expr = Arc::new(GetIndexedFieldExpr::new(
-            expr,
-            GetFieldAccessExpr::ListRange { start, stop },
-        ));
+        let expr = Arc::new(GetIndexedFieldExpr::new_range(expr, start, stop));
         let result = expr.evaluate(&batch)?.into_array(1);
         let result = as_list_array(&result).expect("failed to downcast to ListArray");
         let (expected, _, _) =
@@ -411,10 +424,7 @@ mod tests {
             Arc::new(schema),
             vec![Arc::new(list_builder.finish()), key_array],
         )?;
-        let expr = Arc::new(GetIndexedFieldExpr::new(
-            expr,
-            GetFieldAccessExpr::ListIndex { key },
-        ));
+        let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key));
         let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
         assert!(result.is_null(0));
         Ok(())
@@ -435,10 +445,7 @@ mod tests {
             Arc::new(schema),
             vec![Arc::new(list_builder.finish()), Arc::new(key_array)],
         )?;
-        let expr = Arc::new(GetIndexedFieldExpr::new(
-            expr,
-            GetFieldAccessExpr::ListIndex { key },
-        ));
+        let expr = Arc::new(GetIndexedFieldExpr::new_index(expr, key));
         let result = expr.evaluate(&batch)?.into_array(1);
         assert!(result.is_null(0));
         Ok(())