Skip to content

Commit

Permalink
chore(execution-engine): refactor unseen canon stream creation (#636)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikevoronov authored Jul 20, 2023
1 parent 3fa8be0 commit 6fd0385
Show file tree
Hide file tree
Showing 15 changed files with 295 additions and 289 deletions.
6 changes: 1 addition & 5 deletions air/src/execution_step/boxed_value/canon_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
* limitations under the License.
*/

use super::Stream;
use super::ValueAggregate;
use crate::execution_step::Generation;
use crate::JValue;

use air_interpreter_cid::CID;
Expand All @@ -42,9 +40,7 @@ impl CanonStream {
Self { values, tetraplet }
}

pub(crate) fn from_stream(stream: &Stream, peer_pk: String) -> Self {
// it's always possible to iter over all generations of a stream
let values = stream.iter(Generation::Last).unwrap().cloned().collect::<Vec<_>>();
pub(crate) fn from_values(values: Vec<ValueAggregate>, peer_pk: String) -> Self {
let tetraplet = SecurityTetraplet::new(peer_pk, "", "", "");
Self {
values,
Expand Down
11 changes: 0 additions & 11 deletions air/src/execution_step/boxed_value/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,6 @@ pub struct Stream {
}

impl Stream {
pub(crate) fn new(values: Vec<Vec<ValueAggregate>>, previous_gens_count: usize) -> Self {
Self {
values,
previous_gens_count,
}
}

pub(crate) fn from_generations_count(previous_count: GenerationIdx, current_count: GenerationIdx) -> Self {
let last_generation_count = GenerationIdx::from(1);
// TODO: bubble up an overflow error instead of expect
Expand Down Expand Up @@ -218,10 +211,6 @@ impl Stream {
/// Drops every generation that no longer contains any values, in place.
fn remove_empty_generations(&mut self) {
    self.values.retain(|generation| !generation.is_empty());
}

/// Returns the stored count of previous generations
/// (the `previous_gens_count` field as-is).
pub(crate) fn previous_gens_count(&self) -> usize {
    self.previous_gens_count
}
}

#[derive(Clone, Copy, Debug, Eq, PartialEq)]
Expand Down
36 changes: 10 additions & 26 deletions air/src/execution_step/boxed_value/stream_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,32 +81,18 @@ impl StreamMap {
&mut self.stream
}

// TODO: change the implementation to mutate the underlying stream
// instead of creating a new one
pub(crate) fn create_unique_keys_stream(&self) -> Stream {
/// Returns an iterator to values with unique keys.
pub(crate) fn iter_unique_key(&self) -> impl Iterator<Item = &ValueAggregate> {
use std::collections::HashSet;

let mut met_keys = HashSet::new();

// unwrap is safe because slice_iter always returns Some if supplied generations are valid
let new_values = self
.stream
.slice_iter(Generation::Nth(0.into()), Generation::Last)
.unwrap()
.map(|values| {
values
.iter()
.filter(|v| {
StreamMapKey::from_kvpair(v)
.map(|key| met_keys.insert(key))
.unwrap_or(false)
})
.cloned()
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();

Stream::new(new_values, self.stream.previous_gens_count())
// it's always possible to go through all values
self.stream.iter(Generation::Last).unwrap().filter(move |value| {
StreamMapKey::from_kvpair(value)
.map(|key| met_keys.insert(key))
.unwrap_or(false)
})
}
}

Expand Down Expand Up @@ -270,8 +256,7 @@ mod test {
);
}

let unique_keys_only = stream_map.create_unique_keys_stream();
let mut iter = unique_keys_only.iter(Generation::Last).unwrap();
let mut iter = stream_map.iter_unique_key();

assert_eq!(&json!(0), iter.next().unwrap().get_result().get("value").unwrap());
assert_eq!(&json!(1), iter.next().unwrap().get_result().get("value").unwrap());
Expand All @@ -295,8 +280,7 @@ mod test {
insert_into_map(&mut stream_map, &key_values[1], Generation::nth(4), CurrentData);
insert_into_map(&mut stream_map, &key_values[3], Generation::nth(2), CurrentData);

let unique_keys_only = stream_map.create_unique_keys_stream();
let mut iter = unique_keys_only.iter(Generation::Last).unwrap();
let mut iter = stream_map.iter_unique_key();

assert_eq!(&json!(0), iter.next().unwrap().get_result().get("value").unwrap());
assert_eq!(&json!(2), iter.next().unwrap().get_result().get("value").unwrap());
Expand Down
47 changes: 28 additions & 19 deletions air/src/execution_step/execution_context/cid_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use crate::execution_step::ValueAggregate;
use crate::JValue;
use crate::UncatchableError;

use air_interpreter_cid::CidCalculationError;
use air_interpreter_cid::CID;
use air_interpreter_data::CanonCidAggregate;
use air_interpreter_data::CanonResultCidAggregate;
Expand All @@ -45,24 +44,6 @@ impl ExecutionCidState {
Self::default()
}

/// Records a service result (value, tetraplet, argument hash) in the CID
/// trackers and returns the CID of the resulting aggregate.
///
/// # Errors
/// Propagates any `CidCalculationError` raised while recording the value,
/// the tetraplet, or the final aggregate.
pub fn insert_value(
    &mut self,
    value: Rc<JValue>,
    tetraplet: RcSecurityTetraplet,
    argument_hash: Rc<str>,
) -> Result<Rc<CID<ServiceResultCidAggregate>>, CidCalculationError> {
    // Record the two constituents first; their CIDs form the aggregate.
    let value_cid = self.value_tracker.record_value(value)?;
    let tetraplet_cid = self.tetraplet_tracker.record_value(tetraplet)?;

    let aggregate = ServiceResultCidAggregate {
        value_cid,
        argument_hash,
        tetraplet_cid,
    };

    self.service_result_agg_tracker.record_value(aggregate)
}

pub(crate) fn from_cid_info(prev_cid_info: CidInfo, current_cid_info: CidInfo) -> Self {
let value_tracker = CidTracker::from_cid_stores(prev_cid_info.value_store, current_cid_info.value_store);
let tetraplet_tracker =
Expand All @@ -85,6 +66,34 @@ impl ExecutionCidState {
}
}

/// Tracks a service result (value, tetraplet, argument hash) and returns
/// the CID of the produced `ServiceResultCidAggregate`.
///
/// # Errors
/// Returns `UncatchableError` if tracking the value, the tetraplet, or the
/// aggregate fails.
pub fn track_service_result(
    &mut self,
    value: Rc<JValue>,
    tetraplet: RcSecurityTetraplet,
    argument_hash: Rc<str>,
) -> Result<Rc<CID<ServiceResultCidAggregate>>, UncatchableError> {
    // The aggregate references its parts by CID, so track them first.
    let value_cid = self.value_tracker.track_value(value)?;
    let tetraplet_cid = self.tetraplet_tracker.track_value(tetraplet)?;
    let aggregate = ServiceResultCidAggregate::new(value_cid, argument_hash, tetraplet_cid);

    let tracked = self.service_result_agg_tracker.track_value(aggregate);
    tracked.map_err(UncatchableError::from)
}

/// Tracks a canon value's result and tetraplet, then records a
/// `CanonCidAggregate` built from their CIDs and the value's provenance,
/// returning the aggregate's CID.
///
/// # Errors
/// Returns `UncatchableError` if any of the tracking steps fail.
pub(crate) fn track_canon_value(
    &mut self,
    canon_value: &ValueAggregate,
) -> Result<Rc<CID<CanonCidAggregate>>, UncatchableError> {
    let value_cid = self.value_tracker.track_value(canon_value.get_result().clone())?;
    let tetraplet_cid = self.tetraplet_tracker.track_value(canon_value.get_tetraplet())?;

    let aggregate = CanonCidAggregate::new(value_cid, tetraplet_cid, canon_value.get_provenance());
    self.canon_element_tracker
        .track_value(aggregate)
        .map_err(UncatchableError::from)
}

pub(crate) fn get_value_by_cid(&self, cid: &CID<JValue>) -> Result<Rc<JValue>, UncatchableError> {
self.value_tracker
.get(cid)
Expand Down
16 changes: 8 additions & 8 deletions air/src/execution_step/instructions/call/call_result_setter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ pub(crate) fn populate_context_from_peer_service_result<'i>(
match output {
CallOutputValue::Scalar(scalar) => {
let peer_id: Box<str> = tetraplet.peer_pk.as_str().into();
let service_result_agg_cid = exec_ctx
.cid_state
.insert_value(executed_result.result.clone(), tetraplet, argument_hash)
.map_err(UncatchableError::from)?;
let service_result_agg_cid =
exec_ctx
.cid_state
.track_service_result(executed_result.result.clone(), tetraplet, argument_hash)?;
let executed_result = ValueAggregate::from_service_result(executed_result, service_result_agg_cid.clone());

exec_ctx.scalars.set_scalar_value(scalar.name, executed_result)?;
Expand All @@ -51,10 +51,10 @@ pub(crate) fn populate_context_from_peer_service_result<'i>(
}
CallOutputValue::Stream(stream) => {
let peer_id: Box<str> = tetraplet.peer_pk.as_str().into();
let service_result_agg_cid = exec_ctx
.cid_state
.insert_value(executed_result.result.clone(), tetraplet, argument_hash)
.map_err(UncatchableError::from)?;
let service_result_agg_cid =
exec_ctx
.cid_state
.track_service_result(executed_result.result.clone(), tetraplet, argument_hash)?;

let executed_result = ValueAggregate::from_service_result(executed_result, service_result_agg_cid.clone());

Expand Down
17 changes: 9 additions & 8 deletions air/src/execution_step/instructions/call/prev_result_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,10 @@ fn handle_service_error(
let failed_value = CallServiceFailed::new(service_result.ret_code, error_message).to_value();

let peer_id: Box<str> = tetraplet.peer_pk.as_str().into();
let service_result_agg_cid = exec_ctx
.cid_state
.insert_value(failed_value.into(), tetraplet, argument_hash)
.map_err(UncatchableError::from)?;
let service_result_agg_cid =
exec_ctx
.cid_state
.track_service_result(failed_value.into(), tetraplet, argument_hash)?;

exec_ctx.record_call_cid(peer_id, &service_result_agg_cid);
trace_ctx.meet_call_end(Failed(service_result_agg_cid));
Expand All @@ -207,10 +207,11 @@ fn try_to_service_result(

let failed_value = CallServiceFailed::new(i32::MAX, error_msg.clone()).to_value();

let service_result_agg_cid = exec_ctx
.cid_state
.insert_value(failed_value.into(), tetraplet.clone(), argument_hash.clone())
.map_err(UncatchableError::from)?;
let service_result_agg_cid = exec_ctx.cid_state.track_service_result(
failed_value.into(),
tetraplet.clone(),
argument_hash.clone(),
)?;
let error = CallResult::failed(service_result_agg_cid);

trace_ctx.meet_call_end(error);
Expand Down
Loading

0 comments on commit 6fd0385

Please sign in to comment.