Skip to content

Commit

Permalink
Add support for nested gRPC callouts. (#240)
Browse files Browse the repository at this point in the history
Signed-off-by: andytesti <[email protected]>
  • Loading branch information
andytesti authored Jul 19, 2024
1 parent 176e12d commit ec3ddd2
Showing 1 changed file with 42 additions and 35 deletions.
77 changes: 42 additions & 35 deletions src/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,8 @@ impl Dispatcher {
}

fn on_grpc_receive(&self, token_id: u32, response_size: usize) {
if let Some(context_id) = self.grpc_callouts.borrow_mut().remove(&token_id) {
let context_id = self.grpc_callouts.borrow_mut().remove(&token_id);
if let Some(context_id) = context_id {
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
Expand All @@ -467,24 +468,26 @@ impl Dispatcher {
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_call_response(token_id, 0, response_size);
}
} else if let Some(context_id) = self.grpc_streams.borrow_mut().get(&token_id) {
let context_id = *context_id;
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
http_stream.on_grpc_stream_message(token_id, response_size);
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
stream.on_grpc_stream_message(token_id, response_size);
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_stream_message(token_id, response_size);
}
} else {
// TODO: change back to a panic once underlying issue is fixed.
trace!("on_grpc_receive_initial_metadata: invalid token_id");
let context_id = self.grpc_streams.borrow().get(&token_id).cloned();
if let Some(context_id) = context_id {
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
http_stream.on_grpc_stream_message(token_id, response_size);
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
stream.on_grpc_stream_message(token_id, response_size);
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_stream_message(token_id, response_size);
}
} else {
// TODO: change back to a panic once underlying issue is fixed.
trace!("on_grpc_receive_initial_metadata: invalid token_id");
}
}
}

Expand Down Expand Up @@ -514,7 +517,8 @@ impl Dispatcher {
}

fn on_grpc_close(&self, token_id: u32, status_code: u32) {
if let Some(context_id) = self.grpc_callouts.borrow_mut().remove(&token_id) {
let context_id = self.grpc_callouts.borrow_mut().remove(&token_id);
if let Some(context_id) = context_id {
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
Expand All @@ -528,23 +532,26 @@ impl Dispatcher {
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_call_response(token_id, status_code, 0);
}
} else if let Some(context_id) = self.grpc_streams.borrow_mut().remove(&token_id) {
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
http_stream.on_grpc_stream_close(token_id, status_code)
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
stream.on_grpc_stream_close(token_id, status_code)
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_stream_close(token_id, status_code)
}
} else {
// TODO: change back to a panic once underlying issue is fixed.
trace!("on_grpc_close: invalid token_id, a non-connected stream has closed");
let context_id = self.grpc_streams.borrow_mut().remove(&token_id);
if let Some(context_id) = context_id {
if let Some(http_stream) = self.http_streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
http_stream.on_grpc_stream_close(token_id, status_code)
} else if let Some(stream) = self.streams.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
stream.on_grpc_stream_close(token_id, status_code)
} else if let Some(root) = self.roots.borrow_mut().get_mut(&context_id) {
self.active_id.set(context_id);
hostcalls::set_effective_context(context_id).unwrap();
root.on_grpc_stream_close(token_id, status_code)
}
} else {
// TODO: change back to a panic once underlying issue is fixed.
trace!("on_grpc_close: invalid token_id, a non-connected stream has closed");
}
}
}
}
Expand Down

0 comments on commit ec3ddd2

Please sign in to comment.