Skip to content

Commit

Permalink
Refactor ShapeBuilder away from a single shared global value (#491)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpschorr authored Jan 13, 2025
1 parent 1b19d11 commit 76764b7
Show file tree
Hide file tree
Showing 14 changed files with 651 additions and 460 deletions.
37 changes: 18 additions & 19 deletions extension/partiql-extension-ddl/src/ddl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,38 +228,37 @@ mod tests {
use indexmap::IndexSet;
use partiql_types::{
struct_fields, type_array, type_bag, type_float64, type_int8, type_string, type_struct,
PartiqlShapeBuilder, StructConstraint,
PartiqlShapeBuilder, ShapeBuilderExtensions, StructConstraint,
};

#[test]
fn ddl_test() {
let mut bld = PartiqlShapeBuilder::default();
let nested_attrs = struct_fields![
(
"a",
PartiqlShapeBuilder::init_or_get().any_of(vec![
PartiqlShapeBuilder::init_or_get().new_static(Static::DecimalP(5, 4)),
PartiqlShapeBuilder::init_or_get().new_static(Static::Int8),
])
[
bld.new_static(Static::DecimalP(5, 4)),
bld.new_static(Static::Int8),
]
.into_any_of(&mut bld)
),
("b", type_array![type_string![]]),
("c", type_float64!()),
("b", type_string![bld].into_array(&mut bld)),
("c", type_float64!(bld)),
];
let details = type_struct![IndexSet::from([nested_attrs])];
let details = type_struct![bld, IndexSet::from([nested_attrs])];

let fields = struct_fields![
("employee_id", type_int8![]),
("full_name", type_string![]),
(
"salary",
PartiqlShapeBuilder::init_or_get().new_static(Static::DecimalP(8, 2))
),
("employee_id", type_int8![bld]),
("full_name", type_string![bld]),
("salary", bld.new_static(Static::DecimalP(8, 2))),
("details", details),
("dependents", type_array![type_string![]])
("dependents", type_array![bld, type_string![bld]])
];
let ty = type_bag![
bld,
type_struct![bld, IndexSet::from([fields, StructConstraint::Open(false)])]
];
let ty = type_bag![type_struct![IndexSet::from([
fields,
StructConstraint::Open(false)
])]];

let expected_compact = r#""employee_id" TINYINT,"full_name" VARCHAR,"salary" DECIMAL(8, 2),"details" STRUCT<"a": UNION<DECIMAL(5, 4),TINYINT>,"b": type_array<VARCHAR>,"c": DOUBLE>,"dependents" type_array<VARCHAR>"#;
let expected_pretty = r#""employee_id" TINYINT,
Expand Down
30 changes: 17 additions & 13 deletions extension/partiql-extension-ddl/tests/ddl-tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,30 @@ use partiql_types::{
struct_fields, type_bag, type_int, type_string, type_struct, PartiqlShapeBuilder,
StructConstraint, StructField,
};
use partiql_types::{BagType, Static, StructType};
use partiql_types::{Static, StructType};

#[test]
fn basic_ddl_test() {
let details_fields = struct_fields![("age", type_int!())];
let details = type_struct![IndexSet::from([details_fields])];
let mut bld = PartiqlShapeBuilder::default();
let details_fields = struct_fields![("age", type_int!(bld))];
let details = type_struct![bld, IndexSet::from([details_fields])];
let fields = [
StructField::new("id", type_int!()),
StructField::new("name", type_string!()),
StructField::new(
"address",
PartiqlShapeBuilder::init_or_get().new_non_nullable_static(Static::String),
),
StructField::new("id", type_int!(bld)),
StructField::new("name", type_string!(bld)),
StructField::new("address", bld.new_non_nullable_static(Static::String)),
StructField::new_optional("details", details.clone()),
]
.into();
let shape = type_bag![type_struct![IndexSet::from([
StructConstraint::Fields(fields),
StructConstraint::Open(false)
])]];
let shape = type_bag![
bld,
type_struct![
bld,
IndexSet::from([
StructConstraint::Fields(fields),
StructConstraint::Open(false)
])
]
];

let ddl_compact = PartiqlBasicDdlEncoder::new(DdlFormat::Compact);
let actual = ddl_compact.ddl(&shape).expect("ddl_output");
Expand Down
59 changes: 56 additions & 3 deletions partiql-ast/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ where
}

pub fn node<T>(&mut self, node: T) -> AstNode<T> {
let id = self.id_gen.id();
let id = id.read().expect("NodId read lock");
AstNode { id: *id, node }
let id = self.id_gen.next_id();
AstNode { id, node }
}
}

Expand All @@ -36,3 +35,57 @@ pub type AstNodeBuilderWithAutoId = AstNodeBuilder<AutoNodeIdGenerator>;

/// A [`AstNodeBuilder`] whose 'fresh' [`NodeId`]s are always `0`; Useful for testing
pub type AstNodeBuilderWithNullId = AstNodeBuilder<NullIdGenerator>;

#[cfg(test)]
mod tests {
use super::AstNodeBuilderWithAutoId;
use crate::ast;

use crate::visit::{Traverse, Visit, Visitor};
use partiql_common::node::NodeId;
use partiql_common::pretty::ToPretty;

#[test]
fn unique_ids() {
let mut bld = AstNodeBuilderWithAutoId::default();

let mut i64_to_expr = |n| Box::new(ast::Expr::Lit(bld.node(ast::Lit::Int64Lit(n))));

let lhs = i64_to_expr(5);
let v1 = i64_to_expr(42);
let v2 = i64_to_expr(13);
let list = bld.node(ast::List {
values: vec![v1, v2],
});
let rhs = Box::new(ast::Expr::List(list));
let op = bld.node(ast::In { lhs, rhs });

let pretty_printed = op.to_pretty_string(80).expect("pretty print");
println!("{pretty_printed}");

dbg!(&op);

#[derive(Default)]
pub struct IdVisitor {
ids: Vec<NodeId>,
}

impl Visitor<'_> for IdVisitor {
fn enter_ast_node(&mut self, id: NodeId) -> Traverse {
self.ids.push(id);
Traverse::Continue
}
}

let mut idv = IdVisitor::default();
op.visit(&mut idv);
let IdVisitor { ids } = idv;
dbg!(&ids);

for i in 0..ids.len() {
for j in i + 1..ids.len() {
assert_ne!(ids[i], ids[j]);
}
}
}
}
49 changes: 27 additions & 22 deletions partiql-common/src/node.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use indexmap::IndexMap;
use std::sync::{Arc, RwLock};
use std::hash::Hash;

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand All @@ -10,39 +10,30 @@ pub type NodeMap<T> = IndexMap<NodeId, T>;
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct NodeId(pub u32);

#[derive(Debug)]
/// Auto-incrementing [`NodeIdGenerator`]
pub struct AutoNodeIdGenerator {
next_id: Arc<RwLock<NodeId>>,
next_id: NodeId,
}

impl Default for AutoNodeIdGenerator {
fn default() -> Self {
AutoNodeIdGenerator {
next_id: Arc::new(RwLock::from(NodeId(1))),
}
AutoNodeIdGenerator { next_id: NodeId(1) }
}
}

/// A provider of 'fresh' [`NodeId`]s.
pub trait NodeIdGenerator {
fn id(&self) -> Arc<RwLock<NodeId>>;

/// Provides a 'fresh' [`NodeId`].
fn next_id(&self) -> NodeId;
fn next_id(&mut self) -> NodeId;
}

impl NodeIdGenerator for AutoNodeIdGenerator {
fn id(&self) -> Arc<RwLock<NodeId>> {
let id = self.next_id();
let mut w = self.next_id.write().expect("NodId write lock");
*w = id;
Arc::clone(&self.next_id)
}

#[inline]
fn next_id(&self) -> NodeId {
let id = &self.next_id.read().expect("NodId read lock");
NodeId(id.0 + 1)
fn next_id(&mut self) -> NodeId {
let mut next = NodeId(&self.next_id.0 + 1);
std::mem::swap(&mut self.next_id, &mut next);
next
}
}

Expand All @@ -51,11 +42,25 @@ impl NodeIdGenerator for AutoNodeIdGenerator {
pub struct NullIdGenerator {}

impl NodeIdGenerator for NullIdGenerator {
fn id(&self) -> Arc<RwLock<NodeId>> {
Arc::new(RwLock::from(self.next_id()))
fn next_id(&mut self) -> NodeId {
NodeId(0)
}
}

fn next_id(&self) -> NodeId {
NodeId(0)
#[cfg(test)]
mod tests {
use crate::node::{AutoNodeIdGenerator, NodeIdGenerator};

#[test]
fn unique_ids() {
let mut gen = AutoNodeIdGenerator::default();

let ids: Vec<_> = std::iter::repeat_with(|| gen.next_id()).take(15).collect();
dbg!(&ids);
for i in 0..ids.len() {
for j in i + 1..ids.len() {
assert_ne!(ids[i], ids[j]);
}
}
}
}
4 changes: 2 additions & 2 deletions partiql-eval/src/eval/eval_expr_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::eval::expr::{BindError, EvalExpr};
use crate::eval::EvalContext;
use itertools::Itertools;

use partiql_types::{type_dynamic, PartiqlShape, Static, TYPE_DYNAMIC};
use partiql_types::{PartiqlShape, Static, TYPE_DYNAMIC};
use partiql_value::Value::{Missing, Null};
use partiql_value::{Tuple, Value};

Expand Down Expand Up @@ -469,7 +469,7 @@ impl UnaryValueExpr {
where
F: 'static + Fn(&Value) -> Value,
{
Self::create_typed::<STRICT, F>([type_dynamic!(); 1], args, f)
Self::create_typed::<STRICT, F>([PartiqlShape::Dynamic; 1], args, f)
}

#[allow(dead_code)]
Expand Down
51 changes: 25 additions & 26 deletions partiql-eval/src/eval/expr/coll.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr};
use itertools::{Itertools, Unique};

use partiql_types::{
type_bool, type_numeric, ArrayType, BagType, PartiqlShape, PartiqlShapeBuilder, Static,
type_numeric, PartiqlNoIdShapeBuilder, PartiqlShape, ShapeBuilderExtensions, Static,
};
use partiql_value::Value::{Missing, Null};
use partiql_value::{BinaryAnd, BinaryOr, Value, ValueIter};
Expand Down Expand Up @@ -51,46 +51,45 @@ impl BindEvalExpr for EvalCollFn {
value.sequence_iter().map_or(Missing, &f)
})
}
let boolean_elems = [PartiqlShapeBuilder::init_or_get().any_of([
PartiqlShapeBuilder::init_or_get()
.new_static(Static::Array(ArrayType::new(Box::new(type_bool!())))),
PartiqlShapeBuilder::init_or_get()
.new_static(Static::Bag(BagType::new(Box::new(type_bool!())))),
])];
let numeric_elems = [PartiqlShapeBuilder::init_or_get().any_of([
PartiqlShapeBuilder::init_or_get().new_static(Static::Array(ArrayType::new(Box::new(
PartiqlShapeBuilder::init_or_get().any_of(type_numeric!()),
)))),
PartiqlShapeBuilder::init_or_get().new_static(Static::Bag(BagType::new(Box::new(
PartiqlShapeBuilder::init_or_get().any_of(type_numeric!()),
)))),
])];
let any_elems = [PartiqlShapeBuilder::init_or_get().any_of([
PartiqlShapeBuilder::init_or_get().new_static(Static::Array(ArrayType::new_any())),
PartiqlShapeBuilder::init_or_get().new_static(Static::Bag(BagType::new_any())),
])];

// use DummyShapeBuilder, as we don't care about shape Ids for evaluation dispatch
let mut bld = PartiqlNoIdShapeBuilder::default();

let boolean_elems = [
bld.new_array_of_static(Static::Bool),
bld.new_bag_of_static(Static::Bool),
]
.into_any_of(&mut bld);

let numeric_elems = [
type_numeric!(&mut bld).into_array(&mut bld),
type_numeric!(&mut bld).into_bag(&mut bld),
]
.into_any_of(&mut bld);

let any_elems = [bld.new_array_of_dyn(), bld.new_bag_of_dyn()].into_any_of(&mut bld);

match self {
EvalCollFn::Count(setq) => {
create::<{ STRICT }, _>(any_elems, args, move |it| it.coll_count(setq))
create::<{ STRICT }, _>([any_elems], args, move |it| it.coll_count(setq))
}
EvalCollFn::Avg(setq) => {
create::<{ STRICT }, _>(numeric_elems, args, move |it| it.coll_avg(setq))
create::<{ STRICT }, _>([numeric_elems], args, move |it| it.coll_avg(setq))
}
EvalCollFn::Max(setq) => {
create::<{ STRICT }, _>(any_elems, args, move |it| it.coll_max(setq))
create::<{ STRICT }, _>([any_elems], args, move |it| it.coll_max(setq))
}
EvalCollFn::Min(setq) => {
create::<{ STRICT }, _>(any_elems, args, move |it| it.coll_min(setq))
create::<{ STRICT }, _>([any_elems], args, move |it| it.coll_min(setq))
}
EvalCollFn::Sum(setq) => {
create::<{ STRICT }, _>(numeric_elems, args, move |it| it.coll_sum(setq))
create::<{ STRICT }, _>([numeric_elems], args, move |it| it.coll_sum(setq))
}
EvalCollFn::Any(setq) => {
create::<{ STRICT }, _>(boolean_elems, args, move |it| it.coll_any(setq))
create::<{ STRICT }, _>([boolean_elems], args, move |it| it.coll_any(setq))
}
EvalCollFn::Every(setq) => {
create::<{ STRICT }, _>(boolean_elems, args, move |it| it.coll_every(setq))
create::<{ STRICT }, _>([boolean_elems], args, move |it| it.coll_every(setq))
}
}
}
Expand Down
14 changes: 9 additions & 5 deletions partiql-eval/src/eval/expr/datetime.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::eval::expr::{BindError, BindEvalExpr, EvalExpr};

use partiql_types::type_datetime;
use partiql_types::{type_datetime, PartiqlNoIdShapeBuilder};
use partiql_value::Value::Missing;
use partiql_value::{DateTime, Value};

Expand Down Expand Up @@ -41,14 +41,18 @@ impl BindEvalExpr for EvalExtractFn {
let total = Duration::new(u64::from(second), nanosecond).as_nanos() as i128;
Decimal::from_i128_with_scale(total, NANOSECOND_SCALE).into()
}
// use DummyShapeBuilder, as we don't care about shape Ids for evaluation dispatch
let mut bld = PartiqlNoIdShapeBuilder::default();

let create = |f: fn(&DateTime) -> Value| {
UnaryValueExpr::create_typed::<{ STRICT }, _>([type_datetime!()], args, move |value| {
match value {
UnaryValueExpr::create_typed::<{ STRICT }, _>(
[type_datetime!(bld)],
args,
move |value| match value {
Value::DateTime(dt) => f(dt.as_ref()),
_ => Missing,
}
})
},
)
};

match self {
Expand Down
Loading

1 comment on commit 76764b7

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PartiQL (rust) Benchmark

Benchmark suite Current: 76764b7 Previous: 1b19d11 Ratio
arith_agg-avg 770153 ns/iter (± 2588) 766648 ns/iter (± 4238) 1.00
arith_agg-avg_distinct 859709 ns/iter (± 8950) 861535 ns/iter (± 1550) 1.00
arith_agg-count 820097 ns/iter (± 16502) 815139 ns/iter (± 12287) 1.01
arith_agg-count_distinct 851393 ns/iter (± 6059) 854865 ns/iter (± 1904) 1.00
arith_agg-min 825236 ns/iter (± 2850) 819999 ns/iter (± 2348) 1.01
arith_agg-min_distinct 855755 ns/iter (± 6185) 857922 ns/iter (± 1733) 1.00
arith_agg-max 835713 ns/iter (± 32001) 825238 ns/iter (± 3604) 1.01
arith_agg-max_distinct 867964 ns/iter (± 15459) 866332 ns/iter (± 2757) 1.00
arith_agg-sum 828048 ns/iter (± 2813) 820739 ns/iter (± 3041) 1.01
arith_agg-sum_distinct 859180 ns/iter (± 2473) 864789 ns/iter (± 2361) 0.99
arith_agg-avg-count-min-max-sum 977325 ns/iter (± 4155) 972890 ns/iter (± 3096) 1.00
arith_agg-avg-count-min-max-sum-group_by 1240333 ns/iter (± 4115) 1243160 ns/iter (± 11355) 1.00
arith_agg-avg-count-min-max-sum-group_by-group_as 1836969 ns/iter (± 23523) 1849620 ns/iter (± 10064) 0.99
arith_agg-avg_distinct-count_distinct-min_distinct-max_distinct-sum_distinct 1178308 ns/iter (± 12217) 1186602 ns/iter (± 4488) 0.99
arith_agg-avg_distinct-count_distinct-min_distinct-max_distinct-sum_distinct-group_by 1448869 ns/iter (± 19573) 1467134 ns/iter (± 7227) 0.99
arith_agg-avg_distinct-count_distinct-min_distinct-max_distinct-sum_distinct-group_by-group_as 2044763 ns/iter (± 5312) 2067090 ns/iter (± 14009) 0.99
parse-1 5434 ns/iter (± 27) 6571 ns/iter (± 13) 0.83
parse-15 47958 ns/iter (± 172) 54306 ns/iter (± 232) 0.88
parse-30 92866 ns/iter (± 333) 105393 ns/iter (± 244) 0.88
compile-1 4178 ns/iter (± 13) 4193 ns/iter (± 11) 1.00
compile-15 31003 ns/iter (± 209) 30585 ns/iter (± 134) 1.01
compile-30 63253 ns/iter (± 281) 62898 ns/iter (± 408) 1.01
plan-1 67876 ns/iter (± 389) 72782 ns/iter (± 197) 0.93
plan-15 1069256 ns/iter (± 24766) 1130914 ns/iter (± 79606) 0.95
plan-30 2193025 ns/iter (± 37115) 2264952 ns/iter (± 11751) 0.97
eval-1 12079025 ns/iter (± 163991) 11877539 ns/iter (± 51184) 1.02
eval-15 77791834 ns/iter (± 1120096) 75970932 ns/iter (± 3530087) 1.02
eval-30 147818887 ns/iter (± 2140975) 145051263 ns/iter (± 959139) 1.02
join 9949 ns/iter (± 39) 10004 ns/iter (± 453) 0.99
simple 2592 ns/iter (± 13) 2486 ns/iter (± 14) 1.04
simple-no 486 ns/iter (± 1) 459 ns/iter (± 1) 1.06
numbers 48 ns/iter (± 0) 48 ns/iter (± 0) 1
parse-simple 704 ns/iter (± 3) 884 ns/iter (± 2) 0.80
parse-ion 2316 ns/iter (± 22) 2611 ns/iter (± 7) 0.89
parse-group 7264 ns/iter (± 23) 7809 ns/iter (± 72) 0.93
parse-complex 18563 ns/iter (± 109) 20394 ns/iter (± 60) 0.91
parse-complex-fexpr 26129 ns/iter (± 126) 27578 ns/iter (± 158) 0.95

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.