From d1b26532066cc5e75bfe0af55bb9c2f8d050520e Mon Sep 17 00:00:00 2001
From: jiangzhx
Date: Thu, 27 Jul 2023 13:14:49 +0800
Subject: [PATCH] add retention example

Move the retention_count test out of src/retention.rs into a standalone
example and register it as a Cargo [[example]] target, so it can be run
with: cargo run --example retention
---
 Cargo.toml            |  4 +++
 examples/retention.rs | 70 +++++++++++++++++++++++++++++++++++++++
 src/retention.rs      | 76 ------------------------------------------
 3 files changed, 74 insertions(+), 76 deletions(-)
 create mode 100644 examples/retention.rs

diff --git a/Cargo.toml b/Cargo.toml
index f8b0654..9b49785 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -38,3 +38,7 @@
 harness = false
 name = "sqllogictests"
 path = "tests/sqllogictests/src/main.rs"
+
+[[example]]
+name = "retention"
+path = "examples/retention.rs"
\ No newline at end of file
diff --git a/examples/retention.rs b/examples/retention.rs
new file mode 100644
index 0000000..d33b2a7
--- /dev/null
+++ b/examples/retention.rs
@@ -0,0 +1,70 @@
+use std::sync::Arc;
+
+use datafusion::arrow::array::{Int32Array, StringArray};
+use datafusion::arrow::datatypes::Field;
+use datafusion::arrow::util::pretty::print_batches;
+use datafusion::arrow::{datatypes::DataType, record_batch::RecordBatch};
+use datafusion::error::Result;
+use datafusion::prelude::*;
+use datafusion_uba::create_retention_count;
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let ctx = create_context()?;
+
+    ctx.table("event").await?;
+
+    ctx.register_udaf(create_retention_count());
+
+    let results = ctx
+        .sql(
+            "select distinct_id,retention_count(\
+            case when event='add' and ds=20230101 then true else false end,\
+            case when event='buy' and ds between 20230101 and 20230102 then true else false end,\
+            ds-20230101 \
+            ) as stats from event group by distinct_id",
+        )
+        .await?
+        .collect()
+        .await?;
+
+    print_batches(&results)?;
+    Ok(())
+}
+
+fn create_context() -> Result<SessionContext> {
+    use datafusion::arrow::datatypes::Schema;
+    use datafusion::datasource::MemTable;
+    // define a schema.
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("distinct_id", DataType::Int32, false),
+        Field::new("event", DataType::Utf8, false),
+        Field::new("ds", DataType::Int32, false),
+    ]));
+
+    let batch1 = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 2, 3])),
+            Arc::new(StringArray::from(vec!["add", "add", "add"])),
+            Arc::new(Int32Array::from(vec![20230101, 20230101, 20230101])),
+        ],
+    )?;
+
+    let batch2 = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![1, 2, 3])),
+            Arc::new(StringArray::from(vec!["buy", "buy", "buy"])),
+            Arc::new(Int32Array::from(vec![20230101, 20230101, 20230101])),
+        ],
+    )?;
+
+    // declare a new context. In Spark API, this corresponds to a new Spark SQL session
+    let ctx = SessionContext::new();
+
+    // declare a table in memory. In Spark API, this corresponds to createDataFrame(...).
+    let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
+    ctx.register_table("event", Arc::new(provider))?;
+    Ok(ctx)
+}
diff --git a/src/retention.rs b/src/retention.rs
index ab75446..7307470 100644
--- a/src/retention.rs
+++ b/src/retention.rs
@@ -140,79 +140,3 @@ impl Accumulator for RetentionCount {
         })
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use std::sync::Arc;
-
-    use crate::create_retention_count;
-    use datafusion::arrow::array::{Int32Array, StringArray};
-    use datafusion::arrow::datatypes::Field;
-    use datafusion::arrow::util::pretty::print_batches;
-    use datafusion::arrow::{datatypes::DataType, record_batch::RecordBatch};
-    use datafusion::error::Result;
-    use datafusion::prelude::*;
-
-    #[tokio::test]
-    async fn simple_retention_count() -> Result<()> {
-        let ctx = create_context()?;
-
-        ctx.table("event").await?;
-
-        ctx.register_udaf(create_retention_count());
-
-        let results = ctx
-            .sql(
-                "select distinct_id,retention_count(\
-                case when event='add' and ds=20230101 then true else false end,\
-                case when event='buy' and ds between 20230101 and 20230102 then true else false end,\
-                ds-20230101 \
-                ) as stats from event group by distinct_id",
-            )
-            .await?
-            .collect()
-            .await?;
-
-        print_batches(&results)?;
-
-        Ok(())
-    }
-
-    // create local session context with an in-memory table
-    fn create_context() -> Result<SessionContext> {
-        use datafusion::arrow::datatypes::Schema;
-        use datafusion::datasource::MemTable;
-        // define a schema.
-        let schema = Arc::new(Schema::new(vec![
-            Field::new("distinct_id", DataType::Int32, false),
-            Field::new("event", DataType::Utf8, false),
-            Field::new("ds", DataType::Int32, false),
-        ]));
-
-        let batch1 = RecordBatch::try_new(
-            schema.clone(),
-            vec![
-                Arc::new(Int32Array::from(vec![1, 2, 3])),
-                Arc::new(StringArray::from(vec!["add", "add", "add"])),
-                Arc::new(Int32Array::from(vec![20230101, 20230101, 20230101])),
-            ],
-        )?;
-
-        let batch2 = RecordBatch::try_new(
-            schema.clone(),
-            vec![
-                Arc::new(Int32Array::from(vec![1, 2, 3])),
-                Arc::new(StringArray::from(vec!["buy", "buy", "buy"])),
-                Arc::new(Int32Array::from(vec![20230101, 20230101, 20230101])),
-            ],
-        )?;
-
-        // declare a new context. In spark API, this corresponds to a new spark SQLsession
-        let ctx = SessionContext::new();
-
-        // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
-        let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
-        ctx.register_table("event", Arc::new(provider))?;
-        Ok(ctx)
-    }
-}
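
The example drives retention_count through SQL. Since create_retention_count()
is passed to register_udaf, it presumably returns a DataFusion AggregateUDF,
whose call() method would let the same aggregation be written with the
DataFrame API. A minimal sketch of that variant, not part of this patch
(the helper name retention_df is hypothetical):

use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_uba::create_retention_count;

// Same retention query as in examples/retention.rs, built with the
// DataFrame API instead of SQL. Assumes create_retention_count() returns
// a DataFusion AggregateUDF, as its use with register_udaf implies.
async fn retention_df(ctx: &SessionContext) -> Result<()> {
    let retention_count = create_retention_count();
    let df = ctx.table("event").await?.aggregate(
        vec![col("distinct_id")],
        vec![retention_count
            .call(vec![
                // first event: 'add' on the anchor day 20230101
                col("event").eq(lit("add")).and(col("ds").eq(lit(20230101))),
                // return event: 'buy' within the two-day window
                col("event")
                    .eq(lit("buy"))
                    .and(col("ds").between(lit(20230101), lit(20230102))),
                // day offset relative to the anchor day
                col("ds") - lit(20230101),
            ])
            .alias("stats")],
    )?;
    df.show().await
}

Either path groups by distinct_id and should produce the same stats column.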