Skip to content

Commit

Permalink
fix(join)!: make join default to multiset join (#774)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzlk authored Jun 29, 2023
1 parent 8f67c26 commit 6f3c536
Show file tree
Hide file tree
Showing 10 changed files with 88 additions and 71 deletions.
2 changes: 1 addition & 1 deletion benches/benches/reachability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ fn benchmark_hydroflow_surface(c: &mut Criterion) {
hydroflow_syntax! {
origin = source_iter(vec![1]);
stream_of_edges = source_iter(edges);
reached_vertices = union();
reached_vertices = union() -> unique();
origin -> reached_vertices;

my_join_tee = join() -> flat_map(|(src, ((), dst))| [src, dst]) -> tee();
Expand Down
8 changes: 2 additions & 6 deletions docs/docs/hydroflow/quickstart/example_5_reachability.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,10 @@ We route the `origin` vertex into it as one input right away:

<CodeBlock language="rust">{getLines(exampleCode, 8, 12)}</CodeBlock>

Note the square-bracket syntax for differentiating the multiple inputs to `union()`
is the same as that of `join()` (except that union can have an unbounded number of inputs,
whereas `join()` is defined to only have two.)

Now, `join()` is defined to only have one output. In our program, we want to copy the joined
output to two places: to the original `for_each` from above to print output, and *also*
back to the `union` operator we called `reached_vertices`. We feed the `join()` output
through a `flat_map()` as before, and then we feed the result into a [`tee()`](../syntax/surface_ops_gen.md#tee) operator,
back to the `union` operator we called `reached_vertices`. Because the output is fed back into `reached_vertices`, that operator must now also contain a `unique()`; without it, the data would cycle endlessly in the graph.
We feed the `join()` output through a `flat_map()` as before, and then we feed the result into a [`tee()`](../syntax/surface_ops_gen.md#tee) operator,
which is the mirror image of `union()`: instead of merging many inputs to one output,
it copies one input to many different outputs. Each input element is _cloned_, in Rust terms, and
given to each of the outputs. The syntax for the outputs of `tee()` mirrors that of union: we *append*
Expand Down
10 changes: 5 additions & 5 deletions hydroflow/examples/example_5_reachability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ pub fn main() {
// inputs: the origin vertex (vertex 0) and stream of input edges
origin = source_iter(vec![0]);
stream_of_edges = source_stream(edges_recv);
origin -> [0]reached_vertices;
reached_vertices = union();
origin -> reached_vertices;
reached_vertices = union() -> unique();

// the join
reached_vertices -> map(|v| (v, ())) -> [0]my_join_tee;
stream_of_edges -> [1]my_join_tee;
my_join_tee = join() -> flat_map(|(src, ((), dst))| [src, dst]) -> tee();
my_join_tee = join() -> flat_map(|(src, ((), dst))| [src, dst]) -> unique() -> tee();

// the loop and the output
my_join_tee[0] -> [1]reached_vertices;
my_join_tee[1] -> unique() -> for_each(|x| println!("Reached: {}", x));
my_join_tee[0] -> reached_vertices;
my_join_tee[1] -> for_each(|x| println!("Reached: {}", x));
};

println!(
Expand Down
14 changes: 7 additions & 7 deletions hydroflow/examples/example_6_unreachability.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,26 @@ pub fn main() {
let mut flow = hydroflow_syntax! {
origin = source_iter(vec![0]);
stream_of_edges = source_stream(pairs_recv) -> tee();
reached_vertices = union()->tee();
origin -> [0]reached_vertices;
reached_vertices = union() -> unique() -> tee();
origin -> reached_vertices;

// the join for reachable vertices
my_join = join() -> flat_map(|(src, ((), dst))| [src, dst]);
reached_vertices[0] -> map(|v| (v, ())) -> [0]my_join;
reached_vertices -> map(|v| (v, ())) -> [0]my_join;
stream_of_edges[1] -> [1]my_join;

// the loop
my_join -> [1]reached_vertices;
my_join -> reached_vertices;

// the difference all_vertices - reached_vertices
all_vertices = stream_of_edges[0]
-> flat_map(|(src, dst)| [src, dst]) -> tee();
unreached_vertices = difference();
all_vertices[0] -> [pos]unreached_vertices;
reached_vertices[1] -> [neg]unreached_vertices;
all_vertices -> [pos]unreached_vertices;
reached_vertices -> [neg]unreached_vertices;

// the output
all_vertices[1] -> unique() -> for_each(|v| println!("Received vertex: {}", v));
all_vertices -> for_each(|v| println!("Received vertex: {}", v));
unreached_vertices -> for_each(|v| println!("unreached_vertices vertex: {}", v));
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ digraph {
n4v1 [label="(n4v1) source_stream(diagnosed_recv)", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n9v1 [label="(n9v1) map(|(pid, t)| (pid, (t, t + TRANSMISSIBLE_DURATION)))", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n3v1 [label="(n3v1) union()", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n5v1 [label="(n5v1) join()", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n5v1 [label="(n5v1) join::<HalfSetJoinState>()", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n6v1 [label="(n6v1) filter(|(_pid_a, ((_pid_b, t_contact), (t_from, t_to)))| {\l (t_from..=t_to).contains(&t_contact)\l})\l", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n7v1 [label="(n7v1) map(|(_pid_a, (pid_b_t_contact, _t_from_to))| pid_b_t_contact)", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n8v1 [label="(n8v1) tee()", fontname=Monaco, shape=house, style = filled, color = "#ffff00"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ subgraph sg_1v1 ["sg_1v1 stratum 0"]
4v1[\"(4v1) <code>source_stream(diagnosed_recv)</code>"/]:::pullClass
9v1[\"(9v1) <code>map(|(pid, t)| (pid, (t, t + TRANSMISSIBLE_DURATION)))</code>"/]:::pullClass
3v1[\"(3v1) <code>union()</code>"/]:::pullClass
5v1[\"(5v1) <code>join()</code>"/]:::pullClass
5v1[\"(5v1) <code>join::&lt;HalfSetJoinState&gt;()</code>"/]:::pullClass
6v1[\"<div style=text-align:center>(6v1)</div> <code>filter(|(_pid_a, ((_pid_b, t_contact), (t_from, t_to)))| {<br> (t_from..=t_to).contains(&amp;t_contact)<br>})</code>"/]:::pullClass
7v1[\"(7v1) <code>map(|(_pid_a, (pid_b_t_contact, _t_from_to))| pid_b_t_contact)</code>"/]:::pullClass
8v1[/"(8v1) <code>tee()</code>"\]:::pushClass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,37 @@ subgraph sg_1v1 ["sg_1v1 stratum 0"]
1v1[\"(1v1) <code>source_iter(vec![0])</code>"/]:::pullClass
2v1[\"(2v1) <code>source_stream(edges_recv)</code>"/]:::pullClass
3v1[\"(3v1) <code>union()</code>"/]:::pullClass
4v1[\"(4v1) <code>map(|v| (v, ()))</code>"/]:::pullClass
5v1[\"(5v1) <code>join()</code>"/]:::pullClass
6v1[\"(6v1) <code>flat_map(|(src, ((), dst))| [src, dst])</code>"/]:::pullClass
7v1[/"(7v1) <code>tee()</code>"\]:::pushClass
8v1[/"(8v1) <code>unique()</code>"\]:::pushClass
9v1[/"(9v1) <code>for_each(|x| println!(&quot;Reached: {}&quot;, x))</code>"\]:::pushClass
10v1["(10v1) <code>handoff</code>"]:::otherClass
10v1--1--->3v1
1v1--0--->3v1
2v1--1--->5v1
4v1[\"(4v1) <code>unique()</code>"/]:::pullClass
5v1[\"(5v1) <code>map(|v| (v, ()))</code>"/]:::pullClass
6v1[\"(6v1) <code>join()</code>"/]:::pullClass
7v1[\"(7v1) <code>flat_map(|(src, ((), dst))| [src, dst])</code>"/]:::pullClass
8v1[\"(8v1) <code>unique()</code>"/]:::pullClass
9v1[/"(9v1) <code>tee()</code>"\]:::pushClass
10v1[/"(10v1) <code>for_each(|x| println!(&quot;Reached: {}&quot;, x))</code>"\]:::pushClass
11v1["(11v1) <code>handoff</code>"]:::otherClass
11v1--->3v1
1v1--->3v1
2v1--1--->6v1
3v1--->4v1
4v1--0--->5v1
5v1--->6v1
4v1--->5v1
5v1--0--->6v1
6v1--->7v1
7v1--0--->10v1
7v1--1--->8v1
7v1--->8v1
8v1--->9v1
9v1--0--->11v1
9v1--1--->10v1
subgraph sg_1v1_var_my_join_tee ["var <tt>my_join_tee</tt>"]
5v1
6v1
7v1
8v1
9v1
end
subgraph sg_1v1_var_origin ["var <tt>origin</tt>"]
1v1
end
subgraph sg_1v1_var_reached_vertices ["var <tt>reached_vertices</tt>"]
3v1
4v1
end
subgraph sg_1v1_var_stream_of_edges ["var <tt>stream_of_edges</tt>"]
2v1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,75 +9,79 @@ classDef pushClass fill:#ff8,stroke:#000,text-align:left,white-space:pre
linkStyle default stroke:#aaa,stroke-width:4px,color:red,font-size:1.5em;
subgraph sg_1v1 ["sg_1v1 stratum 0"]
1v1[\"(1v1) <code>source_iter(vec![0])</code>"/]:::pullClass
8v1[\"(8v1) <code>map(|v| (v, ()))</code>"/]:::pullClass
6v1[\"(6v1) <code>join()</code>"/]:::pullClass
7v1[\"(7v1) <code>flat_map(|(src, ((), dst))| [src, dst])</code>"/]:::pullClass
9v1[\"(9v1) <code>map(|v| (v, ()))</code>"/]:::pullClass
7v1[\"(7v1) <code>join()</code>"/]:::pullClass
8v1[\"(8v1) <code>flat_map(|(src, ((), dst))| [src, dst])</code>"/]:::pullClass
4v1[\"(4v1) <code>union()</code>"/]:::pullClass
5v1[/"(5v1) <code>tee()</code>"\]:::pushClass
5v1[\"(5v1) <code>unique()</code>"/]:::pullClass
6v1[/"(6v1) <code>tee()</code>"\]:::pushClass
15v1["(15v1) <code>handoff</code>"]:::otherClass
15v1--->8v1
1v1--0--->4v1
8v1--0--->6v1
6v1--->7v1
7v1--1--->4v1
15v1--->9v1
1v1--->4v1
9v1--0--->7v1
7v1--->8v1
8v1--->4v1
4v1--->5v1
5v1--0--->15v1
5v1--->6v1
6v1--->15v1
subgraph sg_1v1_var_my_join ["var <tt>my_join</tt>"]
6v1
7v1
8v1
end
subgraph sg_1v1_var_origin ["var <tt>origin</tt>"]
1v1
end
subgraph sg_1v1_var_reached_vertices ["var <tt>reached_vertices</tt>"]
4v1
5v1
6v1
end
end
subgraph sg_2v1 ["sg_2v1 stratum 0"]
2v1[\"(2v1) <code>source_stream(pairs_recv)</code>"/]:::pullClass
3v1[/"(3v1) <code>tee()</code>"\]:::pushClass
9v1[/"(9v1) <code>flat_map(|(src, dst)| [src, dst])</code>"\]:::pushClass
10v1[/"(10v1) <code>tee()</code>"\]:::pushClass
12v1[/"(12v1) <code>unique()</code>"\]:::pushClass
10v1[/"(10v1) <code>flat_map(|(src, dst)| [src, dst])</code>"\]:::pushClass
11v1[/"(11v1) <code>tee()</code>"\]:::pushClass
13v1[/"(13v1) <code>for_each(|v| println!(&quot;Received vertex: {}&quot;, v))</code>"\]:::pushClass
2v1--->3v1
3v1--0--->9v1
9v1--->10v1
10v1--1--->12v1
12v1--->13v1
3v1--0--->10v1
10v1--->11v1
11v1--->13v1
subgraph sg_2v1_var_all_vertices ["var <tt>all_vertices</tt>"]
9v1
10v1
11v1
end
subgraph sg_2v1_var_stream_of_edges ["var <tt>stream_of_edges</tt>"]
2v1
3v1
end
end
subgraph sg_3v1 ["sg_3v1 stratum 1"]
11v1[\"(11v1) <code>difference()</code>"/]:::pullClass
12v1[\"(12v1) <code>difference()</code>"/]:::pullClass
14v1[/"(14v1) <code>for_each(|v| println!(&quot;unreached_vertices vertex: {}&quot;, v))</code>"\]:::pushClass
11v1--->14v1
12v1--->14v1
subgraph sg_3v1_var_unreached_vertices ["var <tt>unreached_vertices</tt>"]
11v1
12v1
end
end
3v1--1--->16v1
5v1--1--->18v1
10v1--0--->17v1
6v1--->18v1
11v1--->17v1
16v1["(16v1) <code>handoff</code>"]:::otherClass
16v1--1--->6v1
16v1--1--->7v1
17v1["(17v1) <code>handoff</code>"]:::otherClass
17v1--pos--->11v1
17v1--pos--->12v1
18v1["(18v1) <code>handoff</code>"]:::otherClass
18v1==neg===o11v1
18v1==neg===o12v1

Received vertex: 5
Received vertex: 10
Received vertex: 0
Received vertex: 3
Received vertex: 3
Received vertex: 6
Received vertex: 6
Received vertex: 5
Received vertex: 11
Received vertex: 12
unreached_vertices vertex: 11
Expand Down
3 changes: 2 additions & 1 deletion hydroflow/tests/surface_codegen.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashSet;

use hydroflow::compiled::pull::HalfSetJoinState;
use hydroflow::scheduled::graph::Hydroflow;
use hydroflow::util::collect_ready;
use hydroflow::{assert_graphvis_snapshots, hydroflow_syntax};
Expand Down Expand Up @@ -739,7 +740,7 @@ pub fn test_covid_tracing() {
source_stream(diagnosed_recv) -> [0]exposed;

new_exposed = (
join() ->
join::<HalfSetJoinState>() ->
filter(|(_pid_a, ((_pid_b, t_contact), (t_from, t_to)))| {
(t_from..=t_to).contains(&t_contact)
}) ->
Expand Down
26 changes: 19 additions & 7 deletions hydroflow_lang/src/graph/ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ use crate::graph::{OpInstGenerics, OperatorInstance};
/// ```hydroflow
/// // should yield `(hello, (world, cleveland))` once per matching right-hand item (the default join is a multiset join)
/// source_iter(vec![("hello", "world"), ("stay", "gold")]) -> [0]my_join;
/// source_iter(vec![("hello", "cleveland")]) -> [1]my_join;
/// my_join = join()
/// -> assert([("hello", ("world", "cleveland"))]);
/// source_iter(vec![("hello", "cleveland"), ("hello", "cleveland")]) -> [1]my_join;
/// my_join = join() -> assert([("hello", ("world", "cleveland")), ("hello", ("world", "cleveland"))]);
/// ```
///
/// `join` can also be provided with one or two generic lifetime persistence arguments, either
Expand Down Expand Up @@ -45,10 +44,23 @@ use crate::graph::{OpInstGenerics, OperatorInstance};
/// ```
///
/// Join also accepts one type argument that controls how the join state is built up. This (currently) allows switching between a SetUnion and NonSetUnion implementation.
/// The default is `HalfMultisetJoinState`.
/// For example:
/// ```hydroflow,ignore
/// join::<HalfSetJoinState>();
/// join::<HalfMultisetJoinState>();
/// ```hydroflow
/// lhs = source_iter([("a", 0), ("a", 0)]) -> tee();
/// rhs = source_iter([("a", 0)]) -> tee();
///
/// lhs -> [0]default_join;
/// rhs -> [1]default_join;
/// default_join = join() -> assert([("a", (0, 0)), ("a", (0, 0))]);
///
/// lhs -> [0]multiset_join;
/// rhs -> [1]multiset_join;
/// multiset_join = join::<hydroflow::compiled::pull::HalfMultisetJoinState>() -> assert([("a", (0, 0)), ("a", (0, 0))]);
///
/// lhs -> [0]set_join;
/// rhs -> [1]set_join;
/// set_join = join::<hydroflow::compiled::pull::HalfSetJoinState>() -> assert([("a", (0, 0))]);
/// ```
///
/// ### Examples
Expand Down Expand Up @@ -128,7 +140,7 @@ pub const JOIN: OperatorConstraints = OperatorConstraints {
.get(0)
.map(ToTokens::to_token_stream)
.unwrap_or(quote_spanned!(op_span=>
#root::compiled::pull::HalfSetJoinState
#root::compiled::pull::HalfMultisetJoinState
));

// TODO: This is really bad.
Expand Down

0 comments on commit 6f3c536

Please sign in to comment.