Skip to content

Commit

Permalink
add example
Browse files — browse the repository at this point in the history
  • Loading branch information
jiangzhx committed Jul 27, 2023
1 parent 7231163 commit d1b2653
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 76 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,7 @@ harness = false
name = "sqllogictests"
path = "tests/sqllogictests/src/main.rs"


[[example]]
name = "retention"
path = "examples/retention.rs"
70 changes: 70 additions & 0 deletions examples/retention.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use std::sync::Arc;

use datafusion::arrow::array::{Int32Array, StringArray};
use datafusion::arrow::datatypes::Field;
use datafusion::arrow::util::pretty::print_batches;
use datafusion::arrow::{datatypes::DataType, record_batch::RecordBatch};
use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_uba::create_retention_count;

/// Example: run the `retention_count` UDAF against an in-memory `event` table.
///
/// Computes, per `distinct_id`, retention between an "add" event on
/// 2023-01-01 and "buy" events during 2023-01-01..=2023-01-02, then prints
/// the resulting record batches.
#[tokio::main]
async fn main() -> Result<()> {
    let ctx = create_context()?;

    // Register the custom aggregate before referencing it in SQL.
    // (The original also called `ctx.table("event")` here and discarded the
    // result — a dead lookup, removed.)
    ctx.register_udaf(create_retention_count());

    let results = ctx
        .sql(
            "select distinct_id,retention_count(\
        case when event='add' and ds=20230101 then true else false end,\
        case when event='buy' and ds between 20230101 and 20230102 then true else false end,\
        ds-20230101 \
        ) as stats from event group by distinct_id",
        )
        .await?
        .collect()
        .await?;

    print_batches(&results)?;
    Ok(())
}

/// Builds a `SessionContext` holding a two-partition in-memory `event` table:
/// one partition of "add" rows and one of "buy" rows, all on ds=20230101 for
/// distinct_ids 1..=3.
fn create_context() -> Result<SessionContext> {
    use datafusion::arrow::datatypes::Schema;
    use datafusion::datasource::MemTable;

    // Schema shared by both record batches.
    let schema = Arc::new(Schema::new(vec![
        Field::new("distinct_id", DataType::Int32, false),
        Field::new("event", DataType::Utf8, false),
        Field::new("ds", DataType::Int32, false),
    ]));

    // Both batches have identical shape; only the event name differs.
    let make_batch = |event: &str| {
        RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec![event, event, event])),
                Arc::new(Int32Array::from(vec![20230101, 20230101, 20230101])),
            ],
        )
    };
    let add_batch = make_batch("add")?;
    let buy_batch = make_batch("buy")?;

    // New context (Spark analogue: a new SQL session) with the table
    // registered from the batches (Spark analogue: createDataFrame(...)).
    let ctx = SessionContext::new();
    let provider = MemTable::try_new(schema, vec![vec![add_batch], vec![buy_batch]])?;
    ctx.register_table("event", Arc::new(provider))?;
    Ok(ctx)
}
76 changes: 0 additions & 76 deletions src/retention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,79 +140,3 @@ impl Accumulator for RetentionCount {
})
}
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::create_retention_count;
    use datafusion::arrow::array::{Int32Array, StringArray};
    use datafusion::arrow::datatypes::Field;
    use datafusion::arrow::util::pretty::print_batches;
    use datafusion::arrow::{datatypes::DataType, record_batch::RecordBatch};
    use datafusion::error::Result;
    use datafusion::prelude::*;

    /// Smoke test for the `retention_count` UDAF: groups the in-memory
    /// `event` table by `distinct_id` and prints the aggregated batches.
    #[tokio::test]
    async fn simple_retention_count() -> Result<()> {
        let ctx = create_context()?;

        ctx.table("event").await?;

        ctx.register_udaf(create_retention_count());

        // `\` line-continuations strip the newline and leading whitespace,
        // so the query text matches the inline form exactly.
        let query = "select distinct_id,retention_count(\
        case when event='add' and ds=20230101 then true else false end,\
        case when event='buy' and ds between 20230101 and 20230102 then true else false end,\
        ds-20230101 \
        ) as stats from event group by distinct_id";

        let results = ctx.sql(query).await?.collect().await?;

        print_batches(&results)?;

        Ok(())
    }

    /// Local session context with a two-partition in-memory `event` table
    /// (one partition of "add" rows, one of "buy" rows).
    fn create_context() -> Result<SessionContext> {
        use datafusion::arrow::datatypes::Schema;
        use datafusion::datasource::MemTable;

        let schema = Arc::new(Schema::new(vec![
            Field::new("distinct_id", DataType::Int32, false),
            Field::new("event", DataType::Utf8, false),
            Field::new("ds", DataType::Int32, false),
        ]));

        // Three users, same day; only the event name differs per batch.
        let batch_for = |event: &str| {
            RecordBatch::try_new(
                schema.clone(),
                vec![
                    Arc::new(Int32Array::from(vec![1, 2, 3])),
                    Arc::new(StringArray::from(vec![event, event, event])),
                    Arc::new(Int32Array::from(vec![20230101, 20230101, 20230101])),
                ],
            )
        };
        let adds = batch_for("add")?;
        let buys = batch_for("buy")?;

        // New context (Spark analogue: a new SQL session); register the table
        // (Spark analogue: createDataFrame(...)).
        let ctx = SessionContext::new();
        let provider = MemTable::try_new(schema, vec![vec![adds], vec![buys]])?;
        ctx.register_table("event", Arc::new(provider))?;
        Ok(ctx)
    }
}

0 comments on commit d1b2653

Please sign in to comment.