Skip to content

Commit

Permalink
feat(hydroflow_lang): Added state_by operator. (#1469)
Browse files Browse the repository at this point in the history
For #1467
  • Loading branch information
rohitkulshreshtha authored Sep 26, 2024
1 parent 87a6834 commit d83cb83
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 128 deletions.
1 change: 1 addition & 0 deletions hydroflow_lang/src/graph/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ declare_ops![
source_stream::SOURCE_STREAM,
source_stream_serde::SOURCE_STREAM_SERDE,
state::STATE,
state_by::STATE_BY,
tee::TEE,
unique::UNIQUE,
unzip::UNZIP,
Expand Down
136 changes: 8 additions & 128 deletions hydroflow_lang/src/graph/ops/state.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use quote::{quote_spanned, ToTokens};

use syn::parse_quote_spanned;
use super::{
OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput,
Persistence, WriteContextArgs, RANGE_1,
OperatorCategory, OperatorConstraints,
WriteContextArgs, RANGE_1,
};
use crate::diagnostic::{Diagnostic, Level};

// TODO(mingwei): Improve example when things are more stable.
/// A lattice-based state operator, used for accumulating lattice state
Expand Down Expand Up @@ -36,132 +34,14 @@ pub const STATE: OperatorConstraints = OperatorConstraints {
ports_inn: None,
ports_out: None,
input_delaytype_fn: |_| None,
write_fn: |&WriteContextArgs {
root,
context,
hydroflow,
op_span,
ident,
inputs,
outputs,
is_pull,
singleton_output_ident,
op_name,
op_inst:
OperatorInstance {
generics:
OpInstGenerics {
type_args,
persistence_args,
..
},
..
},
..
},
write_fn: |wc @ &WriteContextArgs { op_span, .. },
diagnostics| {
let lattice_type = type_args
.first()
.map(ToTokens::to_token_stream)
.unwrap_or(quote_spanned!(op_span=> _));

let persistence = match persistence_args[..] {
[] => Persistence::Tick,
[Persistence::Mutable] => {
diagnostics.push(Diagnostic::spanned(
op_span,
Level::Error,
format!("{} does not support `'mut`.", op_name),
));
Persistence::Tick
}
[a] => a,
_ => unreachable!(),
};

let state_ident = singleton_output_ident;
let mut write_prologue = quote_spanned! {op_span=>
let #state_ident = #hydroflow.add_state(::std::cell::RefCell::new(
<#lattice_type as ::std::default::Default>::default()
));
let wc = WriteContextArgs {
arguments: &parse_quote_spanned!(op_span => ::std::convert::identity),
..wc.clone()
};
if Persistence::Tick == persistence {
write_prologue.extend(quote_spanned! {op_span=>
#hydroflow.set_state_tick_hook(#state_ident, |rcell| { rcell.take(); }); // Resets state to `Default::default()`.
});
}

// TODO(mingwei): deduplicate codegen
let write_iterator = if is_pull {
let input = &inputs[0];
quote_spanned! {op_span=>
let #ident = {
fn check_input<'a, Item, Iter, Lat>(
iter: Iter,
state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
context: &'a #root::scheduled::context::Context,
) -> impl 'a + ::std::iter::Iterator<Item = Item>
where
Item: ::std::clone::Clone,
Iter: 'a + ::std::iter::Iterator<Item = Item>,
Lat: 'static + #root::lattices::Merge<Item>,
{
iter.filter(move |item| {
let state = context.state_ref(state_handle);
let mut state = state.borrow_mut();
#root::lattices::Merge::merge(&mut *state, ::std::clone::Clone::clone(item))
})
}
check_input::<_, _, #lattice_type>(#input, #state_ident, #context)
};
}
} else if let Some(output) = outputs.first() {
quote_spanned! {op_span=>
let #ident = {
fn check_output<'a, Item, Push, Lat>(
push: Push,
state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
context: &'a #root::scheduled::context::Context,
) -> impl 'a + #root::pusherator::Pusherator<Item = Item>
where
Item: 'a + ::std::clone::Clone,
Push: #root::pusherator::Pusherator<Item = Item>,
Lat: 'static + #root::lattices::Merge<Item>,
{
#root::pusherator::filter::Filter::new(move |item| {
let state = context.state_ref(state_handle);
let mut state = state.borrow_mut();
#root::lattices::Merge::merge(&mut *state, ::std::clone::Clone::clone(item))
}, push)
}
check_output::<_, _, #lattice_type>(#output, #state_ident, #context)
};
}
} else {
quote_spanned! {op_span=>
let #ident = {
fn check_output<'a, Item, Lat>(
state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
context: &'a #root::scheduled::context::Context,
) -> impl 'a + #root::pusherator::Pusherator<Item = Item>
where
Item: 'a,
Lat: 'static + #root::lattices::Merge<Item>,
{
#root::pusherator::for_each::ForEach::new(move |item| {
let state = context.state_ref(state_handle);
let mut state = state.borrow_mut();
#root::lattices::Merge::merge(&mut *state, item);
})
}
check_output::<_, #lattice_type>(#state_ident, #context)
};
}
};
Ok(OperatorWriteOutput {
write_prologue,
write_iterator,
..Default::default()
})
(super::state_by::STATE_BY.write_fn)(&wc, diagnostics)
},
};
176 changes: 176 additions & 0 deletions hydroflow_lang/src/graph/ops/state_by.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
use quote::{quote_spanned, ToTokens};

use super::{
OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput,
Persistence, WriteContextArgs, RANGE_1,
};
use crate::diagnostic::{Diagnostic, Level};

/// List state operator, but with a closure to map the input to the state lattice.
///
/// The emitted outputs (both the referencable singleton and the optional pass-through stream) are
/// of the same type as the inputs to the state_by operator and are not required to be a lattice
/// type. This is useful receiving pass-through context information on the output side.
///
/// ```hydroflow
/// use std::collections::HashSet;
///
/// use lattices::set_union::{CartesianProductBimorphism, SetUnionHashSet, SetUnionSingletonSet};
///
/// my_state = source_iter(0..3)
/// -> state_by::<SetUnionHashSet<usize>>(SetUnionSingletonSet::new_from);
/// ```
pub const STATE_BY: OperatorConstraints = OperatorConstraints {
name: "state_by",
categories: &[OperatorCategory::Persistence],
hard_range_inn: RANGE_1,
soft_range_inn: RANGE_1,
hard_range_out: &(0..=1),
soft_range_out: &(0..=1),
num_args: 1,
persistence_args: &(0..=1),
type_args: &(0..=1),
is_external_input: false,
has_singleton_output: true,
ports_inn: None,
ports_out: None,
input_delaytype_fn: |_| None,
write_fn: |&WriteContextArgs {
root,
context,
hydroflow,
op_span,
ident,
inputs,
outputs,
is_pull,
singleton_output_ident,
op_name,
op_inst:
OperatorInstance {
generics:
OpInstGenerics {
type_args,
persistence_args,
..
},
..
},
arguments,
..
},
diagnostics| {
let lattice_type = type_args
.first()
.map(ToTokens::to_token_stream)
.unwrap_or(quote_spanned!(op_span=> _));

let persistence = match persistence_args[..] {
[] => Persistence::Tick,
[Persistence::Mutable] => {
diagnostics.push(Diagnostic::spanned(
op_span,
Level::Error,
format!("{} does not support `'mut`.", op_name),
));
Persistence::Tick
}
[a] => a,
_ => unreachable!(),
};

let state_ident = singleton_output_ident;
let mut write_prologue = quote_spanned! {op_span=>
let #state_ident = #hydroflow.add_state(::std::cell::RefCell::new(
<#lattice_type as ::std::default::Default>::default()
));
};
if Persistence::Tick == persistence {
write_prologue.extend(quote_spanned! {op_span=>
#hydroflow.set_state_tick_hook(#state_ident, |rcell| { rcell.take(); }); // Resets state to `Default::default()`.
});
}

let func = &arguments[0];

// TODO(mingwei): deduplicate codegen
let write_iterator = if is_pull {
let input = &inputs[0];
quote_spanned! {op_span=>
let #ident = {
fn check_input<'a, Item, MappingFn, MappedItem, Iter, Lat>(
iter: Iter,
mapfn: MappingFn,
state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
context: &'a #root::scheduled::context::Context,
) -> impl 'a + ::std::iter::Iterator<Item = Item>
where
Item: ::std::clone::Clone,
MappingFn: 'a + Fn(Item) -> MappedItem,
Iter: 'a + ::std::iter::Iterator<Item = Item>,
Lat: 'static + #root::lattices::Merge<MappedItem>,
{
iter.filter(move |item| {
let state = context.state_ref(state_handle);
let mut state = state.borrow_mut();
#root::lattices::Merge::merge(&mut *state, (mapfn)(::std::clone::Clone::clone(item)))
})
}
check_input::<_, _, _, _, #lattice_type>(#input, #func, #state_ident, #context)
};
}
} else if let Some(output) = outputs.first() {
quote_spanned! {op_span=>
let #ident = {
fn check_output<'a, Item, MappingFn, MappedItem, Push, Lat>(
push: Push,
mapfn: MappingFn,
state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
context: &'a #root::scheduled::context::Context,
) -> impl 'a + #root::pusherator::Pusherator<Item = Item>
where
Item: 'a + ::std::clone::Clone,
MappingFn: 'a + Fn(Item) -> MappedItem,
Push: 'a + #root::pusherator::Pusherator<Item = Item>,
Lat: 'static + #root::lattices::Merge<MappedItem>,
{
#root::pusherator::filter::Filter::new(move |item| {
let state = context.state_ref(state_handle);
let mut state = state.borrow_mut();
#root::lattices::Merge::merge(&mut *state, (mapfn)(::std::clone::Clone::clone(item)))
}, push)
}
check_output::<_, _, _, _, #lattice_type>(#output, #func, #state_ident, #context)
};
}
} else {
quote_spanned! {op_span=>
let #ident = {
fn check_output<'a, Item, MappingFn, MappedItem, Lat>(
state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
mapfn: MappingFn,
context: &'a #root::scheduled::context::Context,
) -> impl 'a + #root::pusherator::Pusherator<Item = Item>
where
Item: 'a,
MappedItem: 'a,
MappingFn: 'a + Fn(Item) -> MappedItem,
Lat: 'static + #root::lattices::Merge<MappedItem>,
{
#root::pusherator::for_each::ForEach::new(move |item| {
let state = context.state_ref(state_handle);
let mut state = state.borrow_mut();
#root::lattices::Merge::merge(&mut *state, (mapfn)(item));
})
}
check_output::<_, _, _, #lattice_type>(#state_ident, #func, #context)
};
}
};
Ok(OperatorWriteOutput {
write_prologue,
write_iterator,
..Default::default()
})
},
};

0 comments on commit d83cb83

Please sign in to comment.