diff --git a/nativelink-scheduler/tests/utils/mock_scheduler.rs b/nativelink-scheduler/tests/utils/mock_scheduler.rs index fe0e37035..113a3d25c 100644 --- a/nativelink-scheduler/tests/utils/mock_scheduler.rs +++ b/nativelink-scheduler/tests/utils/mock_scheduler.rs @@ -17,12 +17,10 @@ use std::sync::Arc; use async_trait::async_trait; use nativelink_error::{make_input_err, Error}; use nativelink_metric::{MetricsComponent, RootMetricsComponent}; -use nativelink_util::{ - action_messages::{ActionInfo, OperationId}, - known_platform_property_provider::KnownPlatformPropertyProvider, - operation_state_manager::{ - ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter, - }, +use nativelink_util::action_messages::{ActionInfo, OperationId}; +use nativelink_util::known_platform_property_provider::KnownPlatformPropertyProvider; +use nativelink_util::operation_state_manager::{ + ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter, }; use tokio::sync::{mpsc, Mutex}; diff --git a/nativelink-util/BUILD.bazel b/nativelink-util/BUILD.bazel index b5598b872..5a47a82e6 100644 --- a/nativelink-util/BUILD.bazel +++ b/nativelink-util/BUILD.bazel @@ -17,6 +17,7 @@ rust_library( "src/connection_manager.rs", "src/default_store_key_subscribe.rs", "src/digest_hasher.rs", + "src/drop_guard.rs", "src/evicting_map.rs", "src/fastcdc.rs", "src/fs.rs", diff --git a/nativelink-util/src/drop_guard.rs b/nativelink-util/src/drop_guard.rs new file mode 100644 index 000000000..b1c8d0a30 --- /dev/null +++ b/nativelink-util/src/drop_guard.rs @@ -0,0 +1,42 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub struct DropGuard { + future: Option>>, +} + +impl DropGuard { + pub fn new(future: F) -> Self { + Self { + future: Some(Box::pin(future)), + } + } +} + +impl Future for DropGuard { + type Output = F::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if let Some(future) = self.future.as_mut() { + match future.as_mut().poll(cx) { + Poll::Ready(output) => { + self.future = None; // Set future to None after it completes + Poll::Ready(output) + } + Poll::Pending => Poll::Pending, + } + } else { + panic!("Future already completed"); + } + } +} + +impl Drop for DropGuard { + fn drop(&mut self) { + if let Some(future) = self.future.take() { + // Block on the future to ensure it completes. + futures::executor::block_on(future); + } + } +} diff --git a/nativelink-util/src/evicting_map.rs b/nativelink-util/src/evicting_map.rs index 90a1d5597..bc588edff 100644 --- a/nativelink-util/src/evicting_map.rs +++ b/nativelink-util/src/evicting_map.rs @@ -28,6 +28,7 @@ use nativelink_metric::MetricsComponent; use serde::{Deserialize, Serialize}; use tracing::{event, Level}; +use crate::drop_guard::DropGuard; use crate::instant_wrapper::InstantWrapper; use crate::metrics_utils::{Counter, CounterWithTime}; @@ -434,7 +435,11 @@ where data, }; - if let Some(old_item) = state.put(key, eviction_item).await { + let fut = state.put(key, eviction_item); + + let drop_guard = DropGuard::new(fut); + + if let Some(old_item) = drop_guard.await { replaced_items.push(old_item); } state.sum_store_size += new_item_size; diff --git a/nativelink-util/src/lib.rs b/nativelink-util/src/lib.rs index 69f6edaa2..6cb6772aa 100644 --- a/nativelink-util/src/lib.rs +++ b/nativelink-util/src/lib.rs @@ -20,6 +20,7 @@ pub mod common; pub mod connection_manager; pub mod default_store_key_subscribe; pub mod digest_hasher; +pub mod drop_guard; pub mod evicting_map; pub mod fastcdc; pub mod fs;