Skip to content

Commit

Permalink
feat: add join_multiset() (#804)
Browse files Browse the repository at this point in the history
* feat: add join_multiset()

also remove documentation about HalfJoinMultiset, the way to access
that now is to use join_multiset()

* address comments

* fix assert
  • Loading branch information
zzlk authored Jun 30, 2023
1 parent efac9ba commit 0105246
Show file tree
Hide file tree
Showing 12 changed files with 139 additions and 94 deletions.
2 changes: 1 addition & 1 deletion benches/benches/reachability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ fn benchmark_hydroflow_surface(c: &mut Criterion) {
hydroflow_syntax! {
origin = source_iter(vec![1]);
stream_of_edges = source_iter(edges);
reached_vertices = union() -> unique();
reached_vertices = union();
origin -> reached_vertices;

my_join_tee = join() -> flat_map(|(src, ((), dst))| [src, dst]) -> tee();
Expand Down
8 changes: 6 additions & 2 deletions docs/docs/hydroflow/quickstart/example_5_reachability.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,14 @@ We route the `origin` vertex into it as one input right away:

<CodeBlock language="rust">{getLines(exampleCode, 8, 12)}</CodeBlock>

Note that the square-bracket syntax for differentiating the multiple inputs to `union()`
is the same as that of `join()`, except that `union()` can have an unbounded number of inputs,
whereas `join()` is defined to have exactly two.

Now, `join()` is defined to only have one output. In our program, we want to copy the joined
output to two places: to the original `for_each` from above to print output, and *also*
back to the `union` operator we called `reached_vertices`. This is also the reason why `reached_vertices` must now also contain a `unique()` because without it the data would cycle endlessly in the graph.
We feed the `join()` output through a `flat_map()` as before, and then we feed the result into a [`tee()`](../syntax/surface_ops_gen.md#tee) operator,
back to the `union` operator we called `reached_vertices`. We feed the `join()` output
through a `flat_map()` as before, and then we feed the result into a [`tee()`](../syntax/surface_ops_gen.md#tee) operator,
which is the mirror image of `union()`: instead of merging many inputs to one output,
it copies one input to many different outputs. Each input element is _cloned_, in Rust terms, and
given to each of the outputs. The syntax for the outputs of `tee()` mirrors that of union: we *append*
Expand Down
10 changes: 5 additions & 5 deletions hydroflow/examples/example_5_reachability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ pub fn main() {
// inputs: the origin vertex (vertex 0) and stream of input edges
origin = source_iter(vec![0]);
stream_of_edges = source_stream(edges_recv);
origin -> reached_vertices;
reached_vertices = union() -> unique();
origin -> [0]reached_vertices;
reached_vertices = union();

// the join
reached_vertices -> map(|v| (v, ())) -> [0]my_join_tee;
stream_of_edges -> [1]my_join_tee;
my_join_tee = join() -> flat_map(|(src, ((), dst))| [src, dst]) -> unique() -> tee();
my_join_tee = join() -> flat_map(|(src, ((), dst))| [src, dst]) -> tee();

// the loop and the output
my_join_tee[0] -> reached_vertices;
my_join_tee[1] -> for_each(|x| println!("Reached: {}", x));
my_join_tee[0] -> [1]reached_vertices;
my_join_tee[1] -> unique() -> for_each(|x| println!("Reached: {}", x));
};

println!(
Expand Down
14 changes: 7 additions & 7 deletions hydroflow/examples/example_6_unreachability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,26 @@ pub fn main() {
let mut flow = hydroflow_syntax! {
origin = source_iter(vec![0]);
stream_of_edges = source_stream(pairs_recv) -> tee();
reached_vertices = union() -> unique() -> tee();
origin -> reached_vertices;
reached_vertices = union()->tee();
origin -> [0]reached_vertices;

// the join for reachable vertices
my_join = join() -> flat_map(|(src, ((), dst))| [src, dst]);
reached_vertices -> map(|v| (v, ())) -> [0]my_join;
reached_vertices[0] -> map(|v| (v, ())) -> [0]my_join;
stream_of_edges[1] -> [1]my_join;

// the loop
my_join -> reached_vertices;
my_join -> [1]reached_vertices;

// the difference all_vertices - reached_vertices
all_vertices = stream_of_edges[0]
-> flat_map(|(src, dst)| [src, dst]) -> tee();
unreached_vertices = difference();
all_vertices -> [pos]unreached_vertices;
reached_vertices -> [neg]unreached_vertices;
all_vertices[0] -> [pos]unreached_vertices;
reached_vertices[1] -> [neg]unreached_vertices;

// the output
all_vertices -> for_each(|v| println!("Received vertex: {}", v));
all_vertices[1] -> unique() -> for_each(|v| println!("Received vertex: {}", v));
unreached_vertices -> for_each(|v| println!("unreached_vertices vertex: {}", v));
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ digraph {
n4v1 [label="(n4v1) source_stream(diagnosed_recv)", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n9v1 [label="(n9v1) map(|(pid, t)| (pid, (t, t + TRANSMISSIBLE_DURATION)))", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n3v1 [label="(n3v1) union()", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n5v1 [label="(n5v1) join::<HalfSetJoinState>()", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n5v1 [label="(n5v1) join()", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n6v1 [label="(n6v1) filter(|(_pid_a, ((_pid_b, t_contact), (t_from, t_to)))| {\l (t_from..=t_to).contains(&t_contact)\l})\l", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n7v1 [label="(n7v1) map(|(_pid_a, (pid_b_t_contact, _t_from_to))| pid_b_t_contact)", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n8v1 [label="(n8v1) tee()", fontname=Monaco, shape=house, style = filled, color = "#ffff00"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ subgraph sg_1v1 ["sg_1v1 stratum 0"]
4v1[\"(4v1) <code>source_stream(diagnosed_recv)</code>"/]:::pullClass
9v1[\"(9v1) <code>map(|(pid, t)| (pid, (t, t + TRANSMISSIBLE_DURATION)))</code>"/]:::pullClass
3v1[\"(3v1) <code>union()</code>"/]:::pullClass
5v1[\"(5v1) <code>join::&lt;HalfSetJoinState&gt;()</code>"/]:::pullClass
5v1[\"(5v1) <code>join()</code>"/]:::pullClass
6v1[\"<div style=text-align:center>(6v1)</div> <code>filter(|(_pid_a, ((_pid_b, t_contact), (t_from, t_to)))| {<br> (t_from..=t_to).contains(&amp;t_contact)<br>})</code>"/]:::pullClass
7v1[\"(7v1) <code>map(|(_pid_a, (pid_b_t_contact, _t_from_to))| pid_b_t_contact)</code>"/]:::pullClass
8v1[/"(8v1) <code>tee()</code>"\]:::pushClass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,33 @@ subgraph sg_1v1 ["sg_1v1 stratum 0"]
1v1[\"(1v1) <code>source_iter(vec![0])</code>"/]:::pullClass
2v1[\"(2v1) <code>source_stream(edges_recv)</code>"/]:::pullClass
3v1[\"(3v1) <code>union()</code>"/]:::pullClass
4v1[\"(4v1) <code>unique()</code>"/]:::pullClass
5v1[\"(5v1) <code>map(|v| (v, ()))</code>"/]:::pullClass
6v1[\"(6v1) <code>join()</code>"/]:::pullClass
7v1[\"(7v1) <code>flat_map(|(src, ((), dst))| [src, dst])</code>"/]:::pullClass
8v1[\"(8v1) <code>unique()</code>"/]:::pullClass
9v1[/"(9v1) <code>tee()</code>"\]:::pushClass
10v1[/"(10v1) <code>for_each(|x| println!(&quot;Reached: {}&quot;, x))</code>"\]:::pushClass
11v1["(11v1) <code>handoff</code>"]:::otherClass
11v1--->3v1
1v1--->3v1
2v1--1--->6v1
4v1[\"(4v1) <code>map(|v| (v, ()))</code>"/]:::pullClass
5v1[\"(5v1) <code>join()</code>"/]:::pullClass
6v1[\"(6v1) <code>flat_map(|(src, ((), dst))| [src, dst])</code>"/]:::pullClass
7v1[/"(7v1) <code>tee()</code>"\]:::pushClass
8v1[/"(8v1) <code>unique()</code>"\]:::pushClass
9v1[/"(9v1) <code>for_each(|x| println!(&quot;Reached: {}&quot;, x))</code>"\]:::pushClass
10v1["(10v1) <code>handoff</code>"]:::otherClass
10v1--1--->3v1
1v1--0--->3v1
2v1--1--->5v1
3v1--->4v1
4v1--->5v1
5v1--0--->6v1
4v1--0--->5v1
5v1--->6v1
6v1--->7v1
7v1--->8v1
7v1--0--->10v1
7v1--1--->8v1
8v1--->9v1
9v1--0--->11v1
9v1--1--->10v1
subgraph sg_1v1_var_my_join_tee ["var <tt>my_join_tee</tt>"]
5v1
6v1
7v1
8v1
9v1
end
subgraph sg_1v1_var_origin ["var <tt>origin</tt>"]
1v1
end
subgraph sg_1v1_var_reached_vertices ["var <tt>reached_vertices</tt>"]
3v1
4v1
end
subgraph sg_1v1_var_stream_of_edges ["var <tt>stream_of_edges</tt>"]
2v1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,79 +9,75 @@ classDef pushClass fill:#ff8,stroke:#000,text-align:left,white-space:pre
linkStyle default stroke:#aaa,stroke-width:4px,color:red,font-size:1.5em;
subgraph sg_1v1 ["sg_1v1 stratum 0"]
1v1[\"(1v1) <code>source_iter(vec![0])</code>"/]:::pullClass
9v1[\"(9v1) <code>map(|v| (v, ()))</code>"/]:::pullClass
7v1[\"(7v1) <code>join()</code>"/]:::pullClass
8v1[\"(8v1) <code>flat_map(|(src, ((), dst))| [src, dst])</code>"/]:::pullClass
8v1[\"(8v1) <code>map(|v| (v, ()))</code>"/]:::pullClass
6v1[\"(6v1) <code>join()</code>"/]:::pullClass
7v1[\"(7v1) <code>flat_map(|(src, ((), dst))| [src, dst])</code>"/]:::pullClass
4v1[\"(4v1) <code>union()</code>"/]:::pullClass
5v1[\"(5v1) <code>unique()</code>"/]:::pullClass
6v1[/"(6v1) <code>tee()</code>"\]:::pushClass
5v1[/"(5v1) <code>tee()</code>"\]:::pushClass
15v1["(15v1) <code>handoff</code>"]:::otherClass
15v1--->9v1
1v1--->4v1
9v1--0--->7v1
7v1--->8v1
8v1--->4v1
15v1--->8v1
1v1--0--->4v1
8v1--0--->6v1
6v1--->7v1
7v1--1--->4v1
4v1--->5v1
5v1--->6v1
6v1--->15v1
5v1--0--->15v1
subgraph sg_1v1_var_my_join ["var <tt>my_join</tt>"]
6v1
7v1
8v1
end
subgraph sg_1v1_var_origin ["var <tt>origin</tt>"]
1v1
end
subgraph sg_1v1_var_reached_vertices ["var <tt>reached_vertices</tt>"]
4v1
5v1
6v1
end
end
subgraph sg_2v1 ["sg_2v1 stratum 0"]
2v1[\"(2v1) <code>source_stream(pairs_recv)</code>"/]:::pullClass
3v1[/"(3v1) <code>tee()</code>"\]:::pushClass
10v1[/"(10v1) <code>flat_map(|(src, dst)| [src, dst])</code>"\]:::pushClass
11v1[/"(11v1) <code>tee()</code>"\]:::pushClass
9v1[/"(9v1) <code>flat_map(|(src, dst)| [src, dst])</code>"\]:::pushClass
10v1[/"(10v1) <code>tee()</code>"\]:::pushClass
12v1[/"(12v1) <code>unique()</code>"\]:::pushClass
13v1[/"(13v1) <code>for_each(|v| println!(&quot;Received vertex: {}&quot;, v))</code>"\]:::pushClass
2v1--->3v1
3v1--0--->10v1
10v1--->11v1
11v1--->13v1
3v1--0--->9v1
9v1--->10v1
10v1--1--->12v1
12v1--->13v1
subgraph sg_2v1_var_all_vertices ["var <tt>all_vertices</tt>"]
9v1
10v1
11v1
end
subgraph sg_2v1_var_stream_of_edges ["var <tt>stream_of_edges</tt>"]
2v1
3v1
end
end
subgraph sg_3v1 ["sg_3v1 stratum 1"]
12v1[\"(12v1) <code>difference()</code>"/]:::pullClass
11v1[\"(11v1) <code>difference()</code>"/]:::pullClass
14v1[/"(14v1) <code>for_each(|v| println!(&quot;unreached_vertices vertex: {}&quot;, v))</code>"\]:::pushClass
12v1--->14v1
11v1--->14v1
subgraph sg_3v1_var_unreached_vertices ["var <tt>unreached_vertices</tt>"]
12v1
11v1
end
end
3v1--1--->16v1
6v1--->18v1
11v1--->17v1
5v1--1--->18v1
10v1--0--->17v1
16v1["(16v1) <code>handoff</code>"]:::otherClass
16v1--1--->7v1
16v1--1--->6v1
17v1["(17v1) <code>handoff</code>"]:::otherClass
17v1--pos--->12v1
17v1--pos--->11v1
18v1["(18v1) <code>handoff</code>"]:::otherClass
18v1==neg===o12v1
18v1==neg===o11v1

Received vertex: 5
Received vertex: 10
Received vertex: 0
Received vertex: 3
Received vertex: 3
Received vertex: 6
Received vertex: 6
Received vertex: 5
Received vertex: 11
Received vertex: 12
unreached_vertices vertex: 11
Expand Down
3 changes: 1 addition & 2 deletions hydroflow/tests/surface_codegen.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::collections::HashSet;

use hydroflow::compiled::pull::HalfSetJoinState;
use hydroflow::scheduled::graph::Hydroflow;
use hydroflow::util::collect_ready;
use hydroflow::{assert_graphvis_snapshots, hydroflow_syntax};
Expand Down Expand Up @@ -740,7 +739,7 @@ pub fn test_covid_tracing() {
source_stream(diagnosed_recv) -> [0]exposed;

new_exposed = (
join::<HalfSetJoinState>() ->
join() ->
filter(|(_pid_a, ((_pid_b, t_contact), (t_from, t_to)))| {
(t_from..=t_to).contains(&t_contact)
}) ->
Expand Down
30 changes: 5 additions & 25 deletions hydroflow_lang/src/graph/ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ use crate::graph::{OpInstGenerics, OperatorInstance};
/// Forms the equijoin of the tuples in the input streams by their first (key) attribute. Note that the result nests the 2nd input field (values) into a tuple in the 2nd output field.
///
/// ```hydroflow
/// // should print `(hello, (world, cleveland))`
/// source_iter(vec![("hello", "world"), ("stay", "gold")]) -> [0]my_join;
/// source_iter(vec![("hello", "cleveland"), ("hello", "cleveland")]) -> [1]my_join;
/// my_join = join() -> assert([("hello", ("world", "cleveland")), ("hello", ("world", "cleveland"))]);
/// source_iter(vec![("hello", "world"), ("stay", "gold"), ("hello", "world")]) -> [0]my_join;
/// source_iter(vec![("hello", "cleveland")]) -> [1]my_join;
/// my_join = join()
/// -> assert([("hello", ("world", "cleveland"))]);
/// ```
///
/// `join` can also be provided with one or two generic lifetime persistence arguments, either
Expand All @@ -43,26 +43,6 @@ use crate::graph::{OpInstGenerics, OperatorInstance};
/// // etc.
/// ```
///
/// Join also accepts one type argument that controls how the join state is built up. This (currently) allows switching between a SetUnion and NonSetUnion implementation.
/// The default is HalfMultisetJoinState
/// For example:
/// ```hydroflow
/// lhs = source_iter([("a", 0), ("a", 0)]) -> tee();
/// rhs = source_iter([("a", 0)]) -> tee();
///
/// lhs -> [0]default_join;
/// rhs -> [1]default_join;
/// default_join = join() -> assert([("a", (0, 0)), ("a", (0, 0))]);
///
/// lhs -> [0]multiset_join;
/// rhs -> [1]multiset_join;
/// multiset_join = join::<hydroflow::compiled::pull::HalfMultisetJoinState>() -> assert([("a", (0, 0)), ("a", (0, 0))]);
///
/// lhs -> [0]set_join;
/// rhs -> [1]set_join;
/// set_join = join::<hydroflow::compiled::pull::HalfSetJoinState>() -> assert([("a", (0, 0))]);
/// ```
///
/// ### Examples
///
/// ```rustbook
Expand Down Expand Up @@ -140,7 +120,7 @@ pub const JOIN: OperatorConstraints = OperatorConstraints {
.get(0)
.map(ToTokens::to_token_stream)
.unwrap_or(quote_spanned!(op_span=>
#root::compiled::pull::HalfMultisetJoinState
#root::compiled::pull::HalfSetJoinState
));

// TODO: This is really bad.
Expand Down
69 changes: 69 additions & 0 deletions hydroflow_lang/src/graph/ops/join_multiset.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use syn::{parse_quote, parse_quote_spanned};

use super::{
FlowProperties, FlowPropertyVal, OperatorCategory, OperatorConstraints, WriteContextArgs,
RANGE_0, RANGE_1,
};
use crate::graph::{OpInstGenerics, OperatorInstance};

/// > 2 input streams of type <(K, V1)> and <(K, V2)>, 1 output stream of type <(K, (V1, V2))>
///
/// This operator is equivalent to `join` except that the LHS and RHS are collected into multisets rather than sets before joining.
///
/// For example:
/// ```hydroflow
/// lhs = source_iter([("a", 0), ("a", 0)]) -> tee();
/// rhs = source_iter([("a", "hydro")]) -> tee();
///
/// lhs -> [0]multiset_join;
/// rhs -> [1]multiset_join;
/// multiset_join = join_multiset() -> assert([("a", (0, "hydro")), ("a", (0, "hydro"))]);
///
/// lhs -> [0]set_join;
/// rhs -> [1]set_join;
/// set_join = join() -> assert([("a", (0, "hydro"))]);
/// ```
pub const JOIN_MULTISET: OperatorConstraints = OperatorConstraints {
name: "join_multiset",
categories: &[OperatorCategory::MultiIn],
hard_range_inn: &(2..=2),
soft_range_inn: &(2..=2),
hard_range_out: RANGE_1,
soft_range_out: RANGE_1,
num_args: 0,
persistence_args: &(0..=2),
type_args: RANGE_0,
is_external_input: false,
ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { 0, 1 })),
ports_out: None,
properties: FlowProperties {
deterministic: FlowPropertyVal::Preserve,
monotonic: FlowPropertyVal::Preserve,
inconsistency_tainted: false,
},
input_delaytype_fn: |_| None,
write_fn: |wc @ &WriteContextArgs {
root,
op_span,
op_inst: op_inst @ OperatorInstance { .. },
..
},
diagnostics| {
let join_type = parse_quote_spanned! {op_span=> // Uses `lat_type.span()`!
#root::compiled::pull::HalfMultisetJoinState
};

let wc = WriteContextArgs {
op_inst: &OperatorInstance {
generics: OpInstGenerics {
type_args: vec![join_type],
..wc.op_inst.generics.clone()
},
..op_inst.clone()
},
..wc.clone()
};

(super::join::JOIN.write_fn)(&wc, diagnostics)
},
};
Loading

0 comments on commit 0105246

Please sign in to comment.