Skip to content

Commit

Permalink
add retention example work with parquet file
Browse files Browse the repository at this point in the history
  • Loading branch information
jiangzhx committed Aug 2, 2023
1 parent 984a8cc commit 7f19d24
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
51 changes: 51 additions & 0 deletions examples/retention_parquet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use std::sync::Arc;

use datafusion::arrow::util::pretty::print_batches;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::*;
use datafusion_uba::retention::{create_retention_count, create_retention_sum};
use datafusion_uba::test_util;

/// Runs the retention example end to end against a parquet file.
///
/// Steps:
/// 1. Register `event.parquet` (located under `test_util::parquet_test_data()`)
///    as the table `event`.
/// 2. Register the `retention_count` and `retention_sum` aggregate functions.
/// 3. Compute one `stats` value per `distinct_id` with `retention_count`.
/// 4. Materialize that result as the in-memory table `retention_count_result`.
/// 5. Aggregate all `stats` values with `retention_sum` and print the result.
///
/// # Errors
///
/// Returns the underlying DataFusion error if registering the parquet file,
/// planning or executing either query, building the in-memory table, or
/// printing the batches fails.
#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Propagate registration failures (e.g. missing test data) to the caller
    // instead of panicking, consistent with every other fallible call below.
    let parquet_path = format!("{}/event.parquet", test_util::parquet_test_data());
    ctx.register_parquet("event", parquet_path.as_str(), Default::default())
        .await?;

    ctx.register_udaf(create_retention_count());
    ctx.register_udaf(create_retention_sum());

    // First stage: per-user retention stats.
    // NOTE(review): the last two arguments look like a day-window length and a
    // per-row day offset computed by integer subtraction on yyyymmdd values —
    // confirm against the `retention_count` implementation.
    let df = ctx
        .sql(
            "select distinct_id,retention_count(\
            case when xwhat='$startup' then true else false end,\
            case when xwhat='$pageview' then true else false end,\
            20230107-20230101,\
            ds-20230101 \
            ) as stats \
            from event group by distinct_id order by distinct_id",
        )
        .await?;

    // `collect` consumes the DataFrame, so run it on a clone: the original is
    // still needed below for its schema.
    let results = df.clone().collect().await?;

    // Expose the first-stage batches as a table so the second stage can be
    // expressed in SQL as well.
    let provider = MemTable::try_new(df.schema().clone().into(), vec![results])?;
    ctx.register_table("retention_count_result", Arc::new(provider))?;

    // Second stage: fold the per-user stats into a single aggregate.
    let results = ctx
        .sql("select retention_sum(stats) from retention_count_result")
        .await?
        .collect()
        .await?;

    print_batches(&results)?;

    Ok(())
}
1 change: 0 additions & 1 deletion src/retention/retention_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ impl Accumulator for RetentionSum {
fn evaluate(&self) -> Result<ScalarValue> {
let arr_ref = &ScalarValue::iter_to_array(self.total_active.clone()).unwrap();
let mut final_result: Vec<Vec<ScalarValue>> = Vec::new();

for index in 0..arr_ref.len() {
if let ScalarValue::List(Some(per_user), _) =
ScalarValue::try_from_array(arr_ref, index)?
Expand Down

0 comments on commit 7f19d24

Please sign in to comment.