feat: add join_multiset() #804

Merged
merged 3 commits into from
Jun 30, 2023
Changes from 2 commits
2 changes: 1 addition & 1 deletion benches/benches/reachability.rs
@@ -336,7 +336,7 @@ fn benchmark_hydroflow_surface(c: &mut Criterion) {
hydroflow_syntax! {
origin = source_iter(vec![1]);
stream_of_edges = source_iter(edges);
reached_vertices = union() -> unique();
reached_vertices = union();
origin -> reached_vertices;

my_join_tee = join() -> flat_map(|(src, ((), dst))| [src, dst]) -> tee();
8 changes: 6 additions & 2 deletions docs/docs/hydroflow/quickstart/example_5_reachability.mdx
@@ -67,10 +67,14 @@ We route the `origin` vertex into it as one input right away:

<CodeBlock language="rust">{getLines(exampleCode, 8, 12)}</CodeBlock>

Note that the square-bracket syntax for distinguishing the multiple inputs to `union()`
is the same as that of `join()`, except that `union()` can have an unbounded number of inputs,
whereas `join()` is defined to have exactly two.

Now, `join()` is defined to only have one output. In our program, we want to copy the joined
output to two places: to the original `for_each` from above to print output, and *also*
back to the `union` operator we called `reached_vertices`. This is also the reason why `reached_vertices` must now also contain a `unique()` because without it the data would cycle endlessly in the graph.
We feed the `join()` output through a `flat_map()` as before, and then we feed the result into a [`tee()`](../syntax/surface_ops_gen.md#tee) operator,
back to the `union` operator we called `reached_vertices`. We feed the `join()` output
through a `flat_map()` as before, and then we feed the result into a [`tee()`](../syntax/surface_ops_gen.md#tee) operator,
which is the mirror image of `union()`: instead of merging many inputs to one output,
it copies one input to many different outputs. Each input element is _cloned_, in Rust terms, and
given to each of the outputs. The syntax for the outputs of `tee()` mirrors that of union: we *append*
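Reviewer's sketch of the `union()` / `tee()` relationship described in the doc text above: `union()` merges many inputs into one stream, and `tee()` clones each item to every output. This is plain standard-library Rust, not Hydroflow; `union_all` and `tee_all` are hypothetical helper names invented for illustration, and batches are modeled as `Vec`s rather than streams.

```rust
/// Hypothetical stand-in for `union()`: merges any number of inputs into one output.
fn union_all<T>(inputs: Vec<Vec<T>>) -> Vec<T> {
    inputs.into_iter().flatten().collect()
}

/// Hypothetical stand-in for `tee()`: clones every item of one input to `n` outputs.
fn tee_all<T: Clone>(input: &[T], n: usize) -> Vec<Vec<T>> {
    (0..n).map(|_| input.to_vec()).collect()
}

fn main() {
    // Three inputs merged into one, like `[0]`, `[1]`, `[2]` ports on a union.
    let merged = union_all(vec![vec![0], vec![1, 2], vec![3]]);
    assert_eq!(merged, vec![0, 1, 2, 3]);

    // One input copied to two outputs, like `my_join_tee[0]` and `my_join_tee[1]`.
    let outputs = tee_all(&merged, 2);
    assert_eq!(outputs, vec![vec![0, 1, 2, 3], vec![0, 1, 2, 3]]);
}
```

The `Clone` bound mirrors the doc's remark that each element is cloned, in Rust terms, for every output.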
10 changes: 5 additions & 5 deletions hydroflow/examples/example_5_reachability.rs
@@ -8,17 +8,17 @@ pub fn main() {
// inputs: the origin vertex (vertex 0) and stream of input edges
origin = source_iter(vec![0]);
stream_of_edges = source_stream(edges_recv);
origin -> reached_vertices;
reached_vertices = union() -> unique();
origin -> [0]reached_vertices;
reached_vertices = union();

// the join
reached_vertices -> map(|v| (v, ())) -> [0]my_join_tee;
stream_of_edges -> [1]my_join_tee;
my_join_tee = join() -> flat_map(|(src, ((), dst))| [src, dst]) -> unique() -> tee();
my_join_tee = join() -> flat_map(|(src, ((), dst))| [src, dst]) -> tee();

// the loop and the output
my_join_tee[0] -> reached_vertices;
my_join_tee[1] -> for_each(|x| println!("Reached: {}", x));
my_join_tee[0] -> [1]reached_vertices;
my_join_tee[1] -> unique() -> for_each(|x| println!("Reached: {}", x));
};

println!(
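Reviewer's sketch of why the reachability loop in this example can terminate without a `unique()` inside the cycle: as far as this diff shows, `join()` now defaults to set-based state, so an already-seen vertex does not produce fresh output forever. The same fixpoint idea in standard-library Rust follows. The `reachable` function and the edge list are hypothetical illustrations, not the Hydroflow implementation.

```rust
use std::collections::{HashSet, VecDeque};

/// Computes every vertex reachable from `origin`, including `origin` itself.
/// The `reached` set plays the role of the deduplicating join state: a vertex
/// is pushed back into the loop only the first time it is seen.
fn reachable(origin: u32, edges: &[(u32, u32)]) -> HashSet<u32> {
    let mut reached = HashSet::from([origin]);
    let mut frontier = VecDeque::from([origin]);
    while let Some(v) = frontier.pop_front() {
        for &(src, dst) in edges {
            // `insert` returns false for duplicates, which is what stops cycles.
            if src == v && reached.insert(dst) {
                frontier.push_back(dst);
            }
        }
    }
    reached
}

fn main() {
    // Hypothetical graph with a cycle 0 -> 1 -> 2 -> 0 and a separate edge 3 -> 4.
    let edges = [(0, 1), (1, 2), (2, 0), (3, 4)];
    assert_eq!(reachable(0, &edges), HashSet::from([0, 1, 2]));
}
```

Without the `reached.insert(dst)` check, the cycle `2 -> 0` would feed vertex 0 back into the frontier indefinitely, which is the endless cycling the earlier doc text warned about.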
14 changes: 7 additions & 7 deletions hydroflow/examples/example_6_unreachability.rs
@@ -7,26 +7,26 @@ pub fn main() {
let mut flow = hydroflow_syntax! {
origin = source_iter(vec![0]);
stream_of_edges = source_stream(pairs_recv) -> tee();
reached_vertices = union() -> unique() -> tee();
origin -> reached_vertices;
reached_vertices = union() -> tee();
origin -> [0]reached_vertices;

// the join for reachable vertices
my_join = join() -> flat_map(|(src, ((), dst))| [src, dst]);
reached_vertices -> map(|v| (v, ())) -> [0]my_join;
reached_vertices[0] -> map(|v| (v, ())) -> [0]my_join;
stream_of_edges[1] -> [1]my_join;

// the loop
my_join -> reached_vertices;
my_join -> [1]reached_vertices;

// the difference all_vertices - reached_vertices
all_vertices = stream_of_edges[0]
-> flat_map(|(src, dst)| [src, dst]) -> tee();
unreached_vertices = difference();
all_vertices -> [pos]unreached_vertices;
reached_vertices -> [neg]unreached_vertices;
all_vertices[0] -> [pos]unreached_vertices;
reached_vertices[1] -> [neg]unreached_vertices;

// the output
all_vertices -> for_each(|v| println!("Received vertex: {}", v));
all_vertices[1] -> unique() -> for_each(|v| println!("Received vertex: {}", v));
unreached_vertices -> for_each(|v| println!("unreached_vertices vertex: {}", v));
};

@@ -12,7 +12,7 @@ digraph {
n4v1 [label="(n4v1) source_stream(diagnosed_recv)", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n9v1 [label="(n9v1) map(|(pid, t)| (pid, (t, t + TRANSMISSIBLE_DURATION)))", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n3v1 [label="(n3v1) union()", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n5v1 [label="(n5v1) join::<HalfSetJoinState>()", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n5v1 [label="(n5v1) join()", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n6v1 [label="(n6v1) filter(|(_pid_a, ((_pid_b, t_contact), (t_from, t_to)))| {\l (t_from..=t_to).contains(&t_contact)\l})\l", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n7v1 [label="(n7v1) map(|(_pid_a, (pid_b_t_contact, _t_from_to))| pid_b_t_contact)", fontname=Monaco, shape=invhouse, style = filled, color = "#0022ff", fontcolor = "#ffffff"]
n8v1 [label="(n8v1) tee()", fontname=Monaco, shape=house, style = filled, color = "#ffff00"]
@@ -13,7 +13,7 @@ subgraph sg_1v1 ["sg_1v1 stratum 0"]
4v1[\"(4v1) <code>source_stream(diagnosed_recv)</code>"/]:::pullClass
9v1[\"(9v1) <code>map(|(pid, t)| (pid, (t, t + TRANSMISSIBLE_DURATION)))</code>"/]:::pullClass
3v1[\"(3v1) <code>union()</code>"/]:::pullClass
5v1[\"(5v1) <code>join::&lt;HalfSetJoinState&gt;()</code>"/]:::pullClass
5v1[\"(5v1) <code>join()</code>"/]:::pullClass
6v1[\"<div style=text-align:center>(6v1)</div> <code>filter(|(_pid_a, ((_pid_b, t_contact), (t_from, t_to)))| {<br> (t_from..=t_to).contains(&amp;t_contact)<br>})</code>"/]:::pullClass
7v1[\"(7v1) <code>map(|(_pid_a, (pid_b_t_contact, _t_from_to))| pid_b_t_contact)</code>"/]:::pullClass
8v1[/"(8v1) <code>tee()</code>"\]:::pushClass
@@ -11,37 +11,33 @@ subgraph sg_1v1 ["sg_1v1 stratum 0"]
1v1[\"(1v1) <code>source_iter(vec![0])</code>"/]:::pullClass
2v1[\"(2v1) <code>source_stream(edges_recv)</code>"/]:::pullClass
3v1[\"(3v1) <code>union()</code>"/]:::pullClass
4v1[\"(4v1) <code>unique()</code>"/]:::pullClass
5v1[\"(5v1) <code>map(|v| (v, ()))</code>"/]:::pullClass
6v1[\"(6v1) <code>join()</code>"/]:::pullClass
7v1[\"(7v1) <code>flat_map(|(src, ((), dst))| [src, dst])</code>"/]:::pullClass
8v1[\"(8v1) <code>unique()</code>"/]:::pullClass
9v1[/"(9v1) <code>tee()</code>"\]:::pushClass
10v1[/"(10v1) <code>for_each(|x| println!(&quot;Reached: {}&quot;, x))</code>"\]:::pushClass
11v1["(11v1) <code>handoff</code>"]:::otherClass
11v1--->3v1
1v1--->3v1
2v1--1--->6v1
4v1[\"(4v1) <code>map(|v| (v, ()))</code>"/]:::pullClass
5v1[\"(5v1) <code>join()</code>"/]:::pullClass
6v1[\"(6v1) <code>flat_map(|(src, ((), dst))| [src, dst])</code>"/]:::pullClass
7v1[/"(7v1) <code>tee()</code>"\]:::pushClass
8v1[/"(8v1) <code>unique()</code>"\]:::pushClass
9v1[/"(9v1) <code>for_each(|x| println!(&quot;Reached: {}&quot;, x))</code>"\]:::pushClass
10v1["(10v1) <code>handoff</code>"]:::otherClass
10v1--1--->3v1
1v1--0--->3v1
2v1--1--->5v1
3v1--->4v1
4v1--->5v1
5v1--0--->6v1
4v1--0--->5v1
5v1--->6v1
6v1--->7v1
7v1--->8v1
7v1--0--->10v1
7v1--1--->8v1
8v1--->9v1
9v1--0--->11v1
9v1--1--->10v1
subgraph sg_1v1_var_my_join_tee ["var <tt>my_join_tee</tt>"]
5v1
6v1
7v1
8v1
9v1
end
subgraph sg_1v1_var_origin ["var <tt>origin</tt>"]
1v1
end
subgraph sg_1v1_var_reached_vertices ["var <tt>reached_vertices</tt>"]
3v1
4v1
end
subgraph sg_1v1_var_stream_of_edges ["var <tt>stream_of_edges</tt>"]
2v1
@@ -9,79 +9,75 @@ classDef pushClass fill:#ff8,stroke:#000,text-align:left,white-space:pre
linkStyle default stroke:#aaa,stroke-width:4px,color:red,font-size:1.5em;
subgraph sg_1v1 ["sg_1v1 stratum 0"]
1v1[\"(1v1) <code>source_iter(vec![0])</code>"/]:::pullClass
9v1[\"(9v1) <code>map(|v| (v, ()))</code>"/]:::pullClass
7v1[\"(7v1) <code>join()</code>"/]:::pullClass
8v1[\"(8v1) <code>flat_map(|(src, ((), dst))| [src, dst])</code>"/]:::pullClass
8v1[\"(8v1) <code>map(|v| (v, ()))</code>"/]:::pullClass
6v1[\"(6v1) <code>join()</code>"/]:::pullClass
7v1[\"(7v1) <code>flat_map(|(src, ((), dst))| [src, dst])</code>"/]:::pullClass
4v1[\"(4v1) <code>union()</code>"/]:::pullClass
5v1[\"(5v1) <code>unique()</code>"/]:::pullClass
6v1[/"(6v1) <code>tee()</code>"\]:::pushClass
5v1[/"(5v1) <code>tee()</code>"\]:::pushClass
15v1["(15v1) <code>handoff</code>"]:::otherClass
15v1--->9v1
1v1--->4v1
9v1--0--->7v1
7v1--->8v1
8v1--->4v1
15v1--->8v1
1v1--0--->4v1
8v1--0--->6v1
6v1--->7v1
7v1--1--->4v1
4v1--->5v1
5v1--->6v1
6v1--->15v1
5v1--0--->15v1
subgraph sg_1v1_var_my_join ["var <tt>my_join</tt>"]
6v1
7v1
8v1
end
subgraph sg_1v1_var_origin ["var <tt>origin</tt>"]
1v1
end
subgraph sg_1v1_var_reached_vertices ["var <tt>reached_vertices</tt>"]
4v1
5v1
6v1
end
end
subgraph sg_2v1 ["sg_2v1 stratum 0"]
2v1[\"(2v1) <code>source_stream(pairs_recv)</code>"/]:::pullClass
3v1[/"(3v1) <code>tee()</code>"\]:::pushClass
10v1[/"(10v1) <code>flat_map(|(src, dst)| [src, dst])</code>"\]:::pushClass
11v1[/"(11v1) <code>tee()</code>"\]:::pushClass
9v1[/"(9v1) <code>flat_map(|(src, dst)| [src, dst])</code>"\]:::pushClass
10v1[/"(10v1) <code>tee()</code>"\]:::pushClass
12v1[/"(12v1) <code>unique()</code>"\]:::pushClass
13v1[/"(13v1) <code>for_each(|v| println!(&quot;Received vertex: {}&quot;, v))</code>"\]:::pushClass
2v1--->3v1
3v1--0--->10v1
10v1--->11v1
11v1--->13v1
3v1--0--->9v1
9v1--->10v1
10v1--1--->12v1
12v1--->13v1
subgraph sg_2v1_var_all_vertices ["var <tt>all_vertices</tt>"]
9v1
10v1
11v1
end
subgraph sg_2v1_var_stream_of_edges ["var <tt>stream_of_edges</tt>"]
2v1
3v1
end
end
subgraph sg_3v1 ["sg_3v1 stratum 1"]
12v1[\"(12v1) <code>difference()</code>"/]:::pullClass
11v1[\"(11v1) <code>difference()</code>"/]:::pullClass
14v1[/"(14v1) <code>for_each(|v| println!(&quot;unreached_vertices vertex: {}&quot;, v))</code>"\]:::pushClass
12v1--->14v1
11v1--->14v1
subgraph sg_3v1_var_unreached_vertices ["var <tt>unreached_vertices</tt>"]
12v1
11v1
end
end
3v1--1--->16v1
6v1--->18v1
11v1--->17v1
5v1--1--->18v1
10v1--0--->17v1
16v1["(16v1) <code>handoff</code>"]:::otherClass
16v1--1--->7v1
16v1--1--->6v1
17v1["(17v1) <code>handoff</code>"]:::otherClass
17v1--pos--->12v1
17v1--pos--->11v1
18v1["(18v1) <code>handoff</code>"]:::otherClass
18v1==neg===o12v1
18v1==neg===o11v1

Received vertex: 5
Received vertex: 10
Received vertex: 0
Received vertex: 3
Received vertex: 3
Received vertex: 6
Received vertex: 6
Received vertex: 5
Received vertex: 11
Received vertex: 12
unreached_vertices vertex: 11
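Reviewer's sketch of what `difference()` computes in the unreachability example: the set of all vertices mentioned in any edge, minus the reached set. The snapshot above places `difference()` in stratum 1, which is consistent with the `neg` input needing to be complete before anything is emitted. The `unreached` function and its data are hypothetical, written in standard-library Rust rather than Hydroflow.

```rust
use std::collections::HashSet;

/// Returns the vertices that appear in `edges` but not in `reached`,
/// i.e. `all_vertices - reached_vertices`.
fn unreached(edges: &[(u32, u32)], reached: &HashSet<u32>) -> HashSet<u32> {
    // Same shape as `flat_map(|(src, dst)| [src, dst])` in the example.
    let all_vertices: HashSet<u32> = edges.iter().flat_map(|&(src, dst)| [src, dst]).collect();
    all_vertices.difference(reached).copied().collect()
}

fn main() {
    // Hypothetical data: vertices 0 and 1 are reached, 3 and 4 are not.
    let reached = HashSet::from([0, 1]);
    let result = unreached(&[(0, 1), (3, 4)], &reached);
    assert_eq!(result, HashSet::from([3, 4]));
}
```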
3 changes: 1 addition & 2 deletions hydroflow/tests/surface_codegen.rs
@@ -1,6 +1,5 @@
use std::collections::HashSet;

use hydroflow::compiled::pull::HalfSetJoinState;
use hydroflow::scheduled::graph::Hydroflow;
use hydroflow::util::collect_ready;
use hydroflow::{assert_graphvis_snapshots, hydroflow_syntax};
@@ -740,7 +739,7 @@ pub fn test_covid_tracing() {
source_stream(diagnosed_recv) -> [0]exposed;

new_exposed = (
join::<HalfSetJoinState>() ->
join() ->
filter(|(_pid_a, ((_pid_b, t_contact), (t_from, t_to)))| {
(t_from..=t_to).contains(&t_contact)
}) ->
30 changes: 5 additions & 25 deletions hydroflow_lang/src/graph/ops/join.rs
@@ -13,10 +13,10 @@ use crate::graph::{OpInstGenerics, OperatorInstance};
/// Forms the equijoin of the tuples in the input streams by their first (key) attribute. Note that the result nests the 2nd input field (values) into a tuple in the 2nd output field.
///
/// ```hydroflow
/// // should print `(hello, (world, cleveland))`
/// source_iter(vec![("hello", "world"), ("stay", "gold")]) -> [0]my_join;
/// source_iter(vec![("hello", "cleveland"), ("hello", "cleveland")]) -> [1]my_join;
/// my_join = join() -> assert([("hello", ("world", "cleveland")), ("hello", ("world", "cleveland"))]);
/// source_iter(vec![("hello", "world"), ("stay", "gold"), ("hello", "world")]) -> [0]my_join;
/// source_iter(vec![("hello", "cleveland")]) -> [1]my_join;
/// my_join = join()
/// -> assert([("hello", ("world", "cleveland"))]);
/// ```
///
/// `join` can also be provided with one or two generic lifetime persistence arguments, either
@@ -43,26 +43,6 @@ use crate::graph::{OpInstGenerics, OperatorInstance};
/// // etc.
/// ```
///
/// Join also accepts one type argument that controls how the join state is built up. This (currently) allows switching between a SetUnion and NonSetUnion implementation.
/// The default is HalfMultisetJoinState
/// For example:
/// ```hydroflow
/// lhs = source_iter([("a", 0), ("a", 0)]) -> tee();
/// rhs = source_iter([("a", 0)]) -> tee();
///
/// lhs -> [0]default_join;
/// rhs -> [1]default_join;
/// default_join = join() -> assert([("a", (0, 0)), ("a", (0, 0))]);
///
/// lhs -> [0]multiset_join;
/// rhs -> [1]multiset_join;
/// multiset_join = join::<hydroflow::compiled::pull::HalfMultisetJoinState>() -> assert([("a", (0, 0)), ("a", (0, 0))]);
///
/// lhs -> [0]set_join;
/// rhs -> [1]set_join;
/// set_join = join::<hydroflow::compiled::pull::HalfSetJoinState>() -> assert([("a", (0, 0))]);
/// ```
///
/// ### Examples
///
/// ```rustbook
@@ -140,7 +120,7 @@ pub const JOIN: OperatorConstraints = OperatorConstraints {
.get(0)
.map(ToTokens::to_token_stream)
.unwrap_or(quote_spanned!(op_span=>
#root::compiled::pull::HalfMultisetJoinState
#root::compiled::pull::HalfSetJoinState
));

// TODO: This is really bad.
69 changes: 69 additions & 0 deletions hydroflow_lang/src/graph/ops/join_multiset.rs
@@ -0,0 +1,69 @@
use syn::{parse_quote, parse_quote_spanned};

use super::{
FlowProperties, FlowPropertyVal, OperatorCategory, OperatorConstraints, WriteContextArgs,
RANGE_0, RANGE_1,
};
use crate::graph::{OpInstGenerics, OperatorInstance};

/// > 2 input streams of type <(K, V1)> and <(K, V2)>, 1 output stream of type <(K, (V1, V2))>
///
/// This operator is equivalent to `join` except that the LHS and RHS are collected into multisets rather than sets before joining.
Contributor Author:
Hmm, order is preserved, I think, so "multisets" might not be the right word.

Contributor:
Given that there are many (multiset) join algorithms, that seems like a side effect of the implementation that we may not want to provide as a guarantee, so I wouldn't sweat it.

///
/// For example:
/// ```hydroflow
/// lhs = source_iter([("a", 0), ("a", 0)]) -> tee();
/// rhs = source_iter([("a", "hydro")]) -> tee();
///
/// lhs -> [0]multiset_join;
/// rhs -> [1]multiset_join;
/// multiset_join = join_multiset() -> assert([("a", (0, "hydro")), ("a", (0, "hydro"))]);
///
/// lhs -> [0]set_join;
/// rhs -> [1]set_join;
/// set_join = join() -> assert([("a", (0, "hydro"))]);
/// ```
pub const JOIN_MULTISET: OperatorConstraints = OperatorConstraints {
name: "join_multiset",
categories: &[OperatorCategory::MultiIn],
hard_range_inn: &(2..=2),
soft_range_inn: &(2..=2),
hard_range_out: RANGE_1,
soft_range_out: RANGE_1,
num_args: 0,
persistence_args: &(0..=2),
type_args: RANGE_0,
is_external_input: false,
ports_inn: Some(|| super::PortListSpec::Fixed(parse_quote! { 0, 1 })),
ports_out: None,
properties: FlowProperties {
deterministic: FlowPropertyVal::Preserve,
monotonic: FlowPropertyVal::Preserve,
inconsistency_tainted: false,
},
input_delaytype_fn: |_| None,
write_fn: |wc @ &WriteContextArgs {
root,
op_span,
op_inst: op_inst @ OperatorInstance { .. },
..
},
diagnostics| {
let join_type = parse_quote_spanned! {op_span=> // Spanned at the operator (`op_span`).
#root::compiled::pull::HalfMultisetJoinState
};

let wc = WriteContextArgs {
op_inst: &OperatorInstance {
generics: OpInstGenerics {
type_args: vec![join_type],
..wc.op_inst.generics.clone()
},
..op_inst.clone()
},
..wc.clone()
};

(super::join::JOIN.write_fn)(&wc, diagnostics)
},
};
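Reviewer's sketch of the semantic difference this PR encodes: `join()` collects each side into a set before joining, while `join_multiset()` keeps duplicates. The code below is standard-library Rust, not the Hydroflow implementation; `join_multiset_sketch`, `join_set_sketch`, and `dedup` are hypothetical names. The output order shown comes from this sketch's nested loop only and, per the review discussion, should not be read as a guarantee of the real operators.

```rust
use std::collections::HashSet;
use std::hash::Hash;

/// Multiset-style equijoin: every matching (lhs, rhs) pair is emitted,
/// so duplicate inputs yield duplicate outputs.
fn join_multiset_sketch<K: Eq + Clone, A: Clone, B: Clone>(
    lhs: &[(K, A)],
    rhs: &[(K, B)],
) -> Vec<(K, (A, B))> {
    let mut out = Vec::new();
    for (k, a) in lhs {
        for (k2, b) in rhs {
            if k == k2 {
                out.push((k.clone(), (a.clone(), b.clone())));
            }
        }
    }
    out
}

/// Drops repeated items while keeping first-seen order.
fn dedup<T: Eq + Hash + Clone>(items: &[T]) -> Vec<T> {
    let mut seen = HashSet::new();
    let mut out = Vec::new();
    for item in items {
        if seen.insert(item.clone()) {
            out.push(item.clone());
        }
    }
    out
}

/// Set-style equijoin: each side is deduplicated before joining.
fn join_set_sketch<K, A, B>(lhs: &[(K, A)], rhs: &[(K, B)]) -> Vec<(K, (A, B))>
where
    K: Eq + Hash + Clone,
    A: Eq + Hash + Clone,
    B: Eq + Hash + Clone,
{
    join_multiset_sketch(&dedup(lhs), &dedup(rhs))
}

fn main() {
    // Same inputs as the updated `join()` doc example in this PR.
    let lhs = [("hello", "world"), ("stay", "gold"), ("hello", "world")];
    let rhs = [("hello", "cleveland")];

    // Set semantics: the duplicated ("hello", "world") contributes once.
    assert_eq!(join_set_sketch(&lhs, &rhs), vec![("hello", ("world", "cleveland"))]);

    // Multiset semantics: it contributes twice.
    assert_eq!(join_multiset_sketch(&lhs, &rhs).len(), 2);
}
```

Splitting the two behaviors into separately named operators, as the PR does, keeps the common case (`join()`) duplicate-free without requiring a type argument such as `join::<HalfSetJoinState>()`.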