From 89c39282083d89e94a746765779a4c584e7f0179 Mon Sep 17 00:00:00 2001 From: Geoffry Song Date: Wed, 24 Jul 2024 09:47:13 -0700 Subject: [PATCH 1/3] Implement IntoFuture for winrt futures --- crates/libs/bindgen/src/rust/writer.rs | 25 +++ crates/libs/core/src/future.rs | 68 +++++++++ crates/libs/core/src/lib.rs | 2 + .../windows/src/Windows/Devices/Sms/mod.rs | 144 ++++++++++++++++++ .../windows/src/Windows/Foundation/mod.rs | 88 +++++++++++ .../Security/Authentication/OnlineId/mod.rs | 44 ++++++ .../src/Windows/Storage/Streams/mod.rs | 44 ++++++ crates/samples/windows/ocr/Cargo.toml | 3 + crates/samples/windows/ocr/src/main.rs | 14 +- crates/tests/winrt/Cargo.toml | 1 + crates/tests/winrt/tests/async.rs | 30 ++++ 11 files changed, 458 insertions(+), 5 deletions(-) create mode 100644 crates/libs/core/src/future.rs diff --git a/crates/libs/bindgen/src/rust/writer.rs b/crates/libs/bindgen/src/rust/writer.rs index 3da4429bcc..8e356eb847 100644 --- a/crates/libs/bindgen/src/rust/writer.rs +++ b/crates/libs/bindgen/src/rust/writer.rs @@ -697,6 +697,31 @@ impl Writer { self.GetResults() } } + #features + impl<#constraints> windows_core::AsyncOperation for #ident { + type Output = #return_type; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != #namespace AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&#namespace #handler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } + } + #features + impl<#constraints> std::future::IntoFuture for #ident { + type Output = windows_core::Result<#return_type>; + type IntoFuture = windows_core::FutureWrapper<#ident>; + + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } + } } } } diff --git a/crates/libs/core/src/future.rs b/crates/libs/core/src/future.rs new file mode 100644 index 0000000000..e6e3af0577 --- /dev/null +++ b/crates/libs/core/src/future.rs @@ -0,0 +1,68 @@ +#![cfg(feature = "std")] + +use std::{ + future::Future, + pin::Pin, + sync::{Arc, Mutex}, + task::{Poll, Waker}, +}; + +/// Wraps an `IAsyncOperation`, `IAsyncOperationWithProgress`, `IAsyncAction`, or `IAsyncActionWithProgress`. +/// Impls for this trait are generated automatically by windows-bindgen. +pub trait AsyncOperation { + /// The type produced when the operation finishes. + type Output; + /// Returns whether the operation is finished, in which case `self.get_results()` can be used to get the returned data. + /// Wraps `self.Status() != AsyncStatus::Started`. + fn is_complete(&self) -> crate::Result; + /// Register a callback that will be called once the operation is finished. + /// This can only be called once. + /// Wraps `self.SetCompleted(f)`. + fn set_completed(&self, f: impl Fn() + Send + 'static) -> crate::Result<()>; + /// Get the result value from a completed operation. + /// Wraps `self.GetResults()`. + fn get_results(&self) -> crate::Result; +} + +/// A wrapper around an `AsyncOperation` that implements `std::future::Future`. +/// This is used by generated `IntoFuture` impls. It shouldn't be necessary to use this type manually. +pub struct FutureWrapper { + inner: T, + waker: Option>>, +} + +impl FutureWrapper { + /// Creates a `FutureWrapper`, which implements `std::future::Future`. + pub fn new(inner: T) -> Self { + Self { + inner, + waker: None, + } + } +} + +impl Unpin for FutureWrapper {} + +impl Future for FutureWrapper { + type Output = crate::Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + if self.inner.is_complete()? { + Poll::Ready(self.inner.get_results()) + } else { + if let Some(saved_waker) = &self.waker { + // Update the saved waker, in case the future has been transferred to a different executor. + // (e.g. if using `select`.) + let mut saved_waker = saved_waker.lock().unwrap(); + saved_waker.clone_from(cx.waker()); + } else { + let saved_waker = Arc::new(Mutex::new(cx.waker().clone())); + self.waker = Some(saved_waker.clone()); + self.inner.set_completed(move || { + saved_waker.lock().unwrap().wake_by_ref(); + })?; + } + Poll::Pending + } + } +} diff --git a/crates/libs/core/src/lib.rs b/crates/libs/core/src/lib.rs index d8a2b77f87..1a5e8ae03c 100644 --- a/crates/libs/core/src/lib.rs +++ b/crates/libs/core/src/lib.rs @@ -24,6 +24,7 @@ pub mod imp; mod as_impl; mod com_object; +mod future; mod guid; mod inspectable; mod interface; @@ -41,6 +42,7 @@ mod weak; pub use as_impl::*; pub use com_object::*; +pub use future::*; pub use guid::*; pub use inspectable::*; pub use interface::*; diff --git a/crates/libs/windows/src/Windows/Devices/Sms/mod.rs b/crates/libs/windows/src/Windows/Devices/Sms/mod.rs index 64d785eda9..26b29068a6 100644 --- a/crates/libs/windows/src/Windows/Devices/Sms/mod.rs +++ b/crates/libs/windows/src/Windows/Devices/Sms/mod.rs @@ -1036,6 +1036,30 @@ impl DeleteSmsMessageOperation { } } #[cfg(feature = "deprecated")] +impl windows_core::AsyncOperation for DeleteSmsMessageOperation { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +#[cfg(feature = "deprecated")] +impl std::future::IntoFuture for DeleteSmsMessageOperation { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} +#[cfg(feature = "deprecated")] #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] pub struct DeleteSmsMessagesOperation(windows_core::IUnknown); @@ -1122,6 +1146,30 @@ impl DeleteSmsMessagesOperation { } } #[cfg(feature = "deprecated")] +impl windows_core::AsyncOperation for DeleteSmsMessagesOperation { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +#[cfg(feature = "deprecated")] +impl std::future::IntoFuture for DeleteSmsMessagesOperation { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} +#[cfg(feature = "deprecated")] #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] pub struct GetSmsDeviceOperation(windows_core::IUnknown); @@ -1211,6 +1259,30 @@ impl GetSmsDeviceOperation { } } #[cfg(feature = "deprecated")] +impl windows_core::AsyncOperation for GetSmsDeviceOperation { + type Output = SmsDevice; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +#[cfg(feature = "deprecated")] +impl std::future::IntoFuture for GetSmsDeviceOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} +#[cfg(feature = "deprecated")] #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] pub struct GetSmsMessageOperation(windows_core::IUnknown); @@ -1299,6 +1371,30 @@ impl GetSmsMessageOperation { self.GetResults() } } +#[cfg(feature = "deprecated")] +impl windows_core::AsyncOperation for GetSmsMessageOperation { + type Output = ISmsMessage; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +#[cfg(feature = "deprecated")] +impl std::future::IntoFuture for GetSmsMessageOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} #[cfg(all(feature = "Foundation_Collections", feature = "deprecated"))] #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] @@ -1407,6 +1503,30 @@ impl GetSmsMessagesOperation { self.GetResults() } } +#[cfg(all(feature = "Foundation_Collections", feature = "deprecated"))] +impl windows_core::AsyncOperation for GetSmsMessagesOperation { + type Output = super::super::Foundation::Collections::IVectorView; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncOperationWithProgressCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +#[cfg(all(feature = "Foundation_Collections", feature = "deprecated"))] +impl std::future::IntoFuture for GetSmsMessagesOperation { + type Output = windows_core::Result>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} #[cfg(feature = "deprecated")] #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] @@ -1493,6 +1613,30 @@ impl SendSmsMessageOperation { self.GetResults() } } +#[cfg(feature = "deprecated")] +impl windows_core::AsyncOperation for SendSmsMessageOperation { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +#[cfg(feature = "deprecated")] +impl std::future::IntoFuture for SendSmsMessageOperation { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} #[repr(transparent)] #[derive(PartialEq, Eq, Debug, Clone)] pub struct SmsAppMessage(windows_core::IUnknown); diff --git a/crates/libs/windows/src/Windows/Foundation/mod.rs b/crates/libs/windows/src/Windows/Foundation/mod.rs index a414916d61..11a548ce5a 100644 --- a/crates/libs/windows/src/Windows/Foundation/mod.rs +++ b/crates/libs/windows/src/Windows/Foundation/mod.rs @@ -78,6 +78,28 @@ impl IAsyncAction { self.GetResults() } } +impl windows_core::AsyncOperation for IAsyncAction { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&AsyncActionCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +impl std::future::IntoFuture for IAsyncAction { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for IAsyncAction {} unsafe impl Sync for IAsyncAction {} impl windows_core::RuntimeType for IAsyncAction { @@ -183,6 +205,28 @@ impl IAsyncActionWithProgress windows_core::AsyncOperation for IAsyncActionWithProgress { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&AsyncActionWithProgressCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +impl std::future::IntoFuture for IAsyncActionWithProgress { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper>; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for IAsyncActionWithProgress {} unsafe impl Sync for IAsyncActionWithProgress {} impl windows_core::RuntimeType for IAsyncActionWithProgress { @@ -338,6 +382,28 @@ impl IAsyncOperation { self.GetResults() } } +impl windows_core::AsyncOperation for IAsyncOperation { + type Output = TResult; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +impl std::future::IntoFuture for IAsyncOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper>; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for IAsyncOperation {} unsafe impl Sync for IAsyncOperation {} impl windows_core::RuntimeType for IAsyncOperation { @@ -455,6 +521,28 @@ impl windows_core::AsyncOperation for IAsyncOperationWithProgress { + type Output = TResult; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&AsyncOperationWithProgressCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +impl std::future::IntoFuture for IAsyncOperationWithProgress { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper>; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for IAsyncOperationWithProgress {} unsafe impl Sync for IAsyncOperationWithProgress {} impl windows_core::RuntimeType for IAsyncOperationWithProgress { diff --git a/crates/libs/windows/src/Windows/Security/Authentication/OnlineId/mod.rs b/crates/libs/windows/src/Windows/Security/Authentication/OnlineId/mod.rs index e82467d540..a5607a6c52 100644 --- a/crates/libs/windows/src/Windows/Security/Authentication/OnlineId/mod.rs +++ b/crates/libs/windows/src/Windows/Security/Authentication/OnlineId/mod.rs @@ -504,6 +504,28 @@ impl SignOutUserOperation { self.GetResults() } } +impl windows_core::AsyncOperation for SignOutUserOperation { + type Output = (); + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +impl std::future::IntoFuture for SignOutUserOperation { + type Output = windows_core::Result<()>; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for SignOutUserOperation {} unsafe impl Sync for SignOutUserOperation {} #[repr(transparent)] @@ -587,6 +609,28 @@ impl UserAuthenticationOperation { self.GetResults() } } +impl windows_core::AsyncOperation for UserAuthenticationOperation { + type Output = UserIdentity; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +impl std::future::IntoFuture for UserAuthenticationOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for UserAuthenticationOperation {} unsafe impl Sync for UserAuthenticationOperation {} #[repr(transparent)] diff --git a/crates/libs/windows/src/Windows/Storage/Streams/mod.rs b/crates/libs/windows/src/Windows/Storage/Streams/mod.rs index 51059b48e8..329b6bb094 100644 --- a/crates/libs/windows/src/Windows/Storage/Streams/mod.rs +++ b/crates/libs/windows/src/Windows/Storage/Streams/mod.rs @@ -1325,6 +1325,28 @@ impl DataReaderLoadOperation { self.GetResults() } } +impl windows_core::AsyncOperation for DataReaderLoadOperation { + type Output = u32; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +impl std::future::IntoFuture for DataReaderLoadOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for DataReaderLoadOperation {} unsafe impl Sync for DataReaderLoadOperation {} #[repr(transparent)] @@ -1593,6 +1615,28 @@ impl DataWriterStoreOperation { self.GetResults() } } +impl windows_core::AsyncOperation for DataWriterStoreOperation { + type Output = u32; + fn is_complete(&self) -> windows_core::Result { + Ok(self.Status()? != super::super::Foundation::AsyncStatus::Started) + } + fn set_completed(&self, f: impl Fn() + Send + 'static) -> windows_core::Result<()> { + self.SetCompleted(&super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| { + f(); + Ok(()) + })) + } + fn get_results(&self) -> windows_core::Result { + self.GetResults() + } +} +impl std::future::IntoFuture for DataWriterStoreOperation { + type Output = windows_core::Result; + type IntoFuture = windows_core::FutureWrapper; + fn into_future(self) -> Self::IntoFuture { + windows_core::FutureWrapper::new(self) + } +} unsafe impl Send for DataWriterStoreOperation {} unsafe impl Sync for DataWriterStoreOperation {} #[repr(transparent)] diff --git a/crates/samples/windows/ocr/Cargo.toml b/crates/samples/windows/ocr/Cargo.toml index 377ff0e03a..e75742439e 100644 --- a/crates/samples/windows/ocr/Cargo.toml +++ b/crates/samples/windows/ocr/Cargo.toml @@ -4,6 +4,9 @@ version = "0.0.0" edition = "2021" publish = false +[dependencies] +futures = "0.3.5" + [dependencies.windows] path = "../../../libs/windows" features = [ diff --git a/crates/samples/windows/ocr/src/main.rs b/crates/samples/windows/ocr/src/main.rs index 4dc85d8d06..7672ece94d 100644 --- a/crates/samples/windows/ocr/src/main.rs +++ b/crates/samples/windows/ocr/src/main.rs @@ -6,18 +6,22 @@ use windows::{ }; fn main() -> Result<()> { + futures::executor::block_on(main_async()) +} + +async fn main_async() -> Result<()> { let mut message = std::env::current_dir().unwrap(); message.push("message.png"); let file = - StorageFile::GetFileFromPathAsync(&HSTRING::from(message.to_str().unwrap()))?.get()?; - let stream = file.OpenAsync(FileAccessMode::Read)?.get()?; + StorageFile::GetFileFromPathAsync(&HSTRING::from(message.to_str().unwrap()))?.await?; + let stream = file.OpenAsync(FileAccessMode::Read)?.await?; - let decode = BitmapDecoder::CreateAsync(&stream)?.get()?; - let bitmap = decode.GetSoftwareBitmapAsync()?.get()?; + let decode = BitmapDecoder::CreateAsync(&stream)?.await?; + let bitmap = decode.GetSoftwareBitmapAsync()?.await?; let engine = OcrEngine::TryCreateFromUserProfileLanguages()?; - let result = engine.RecognizeAsync(&bitmap)?.get()?; + let result = engine.RecognizeAsync(&bitmap)?.await?; println!("{}", result.Text()?); Ok(()) diff --git a/crates/tests/winrt/Cargo.toml b/crates/tests/winrt/Cargo.toml index 10c6bf801f..fa09f0050f 100644 --- a/crates/tests/winrt/Cargo.toml +++ b/crates/tests/winrt/Cargo.toml @@ -29,4 +29,5 @@ features = [ ] [dev-dependencies] +futures = "0.3" helpers = { package = "test_helpers", path = "../helpers" } diff --git a/crates/tests/winrt/tests/async.rs b/crates/tests/winrt/tests/async.rs index 8ce9a224ae..44a843fd5c 100644 --- a/crates/tests/winrt/tests/async.rs +++ b/crates/tests/winrt/tests/async.rs @@ -23,3 +23,33 @@ fn async_get() -> windows::core::Result<()> { Ok(()) } + +async fn async_await() -> windows::core::Result<()> { + use windows::Storage::Streams::*; + + let stream = &InMemoryRandomAccessStream::new()?; + + let writer = DataWriter::CreateDataWriter(stream)?; + writer.WriteByte(1)?; + writer.WriteByte(2)?; + writer.WriteByte(3)?; + writer.StoreAsync()?.await?; + + stream.Seek(0)?; + let reader = DataReader::CreateDataReader(stream)?; + reader.LoadAsync(3)?.await?; + + let mut bytes: [u8; 3] = [0; 3]; + reader.ReadBytes(&mut bytes)?; + + assert!(bytes[0] == 1); + assert!(bytes[1] == 2); + assert!(bytes[2] == 3); + + Ok(()) +} + +#[test] +fn test_async_await() -> windows::core::Result<()> { + futures::executor::block_on(async_await()) +} From ff35ff3be524a95ce759aa3a8c9ca0fda7f89f7f Mon Sep 17 00:00:00 2001 From: Geoffry Song Date: Wed, 24 Jul 2024 10:48:39 -0700 Subject: [PATCH 2/3] Add a test for waker-updating behaviour --- crates/tests/winrt/Cargo.toml | 1 + crates/tests/winrt/tests/async.rs | 52 ++++++++++++++++++++++++++++--- 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/crates/tests/winrt/Cargo.toml b/crates/tests/winrt/Cargo.toml index fa09f0050f..45916940ed 100644 --- a/crates/tests/winrt/Cargo.toml +++ b/crates/tests/winrt/Cargo.toml @@ -23,6 +23,7 @@ features = [ "Foundation_Numerics", "Storage_Streams", "System", + "System_Threading", "UI_Composition", "Win32_System_Com", "Win32_System_WinRT", diff --git a/crates/tests/winrt/tests/async.rs b/crates/tests/winrt/tests/async.rs index 44a843fd5c..08323a6409 100644 --- a/crates/tests/winrt/tests/async.rs +++ b/crates/tests/winrt/tests/async.rs @@ -1,7 +1,13 @@ +use std::future::IntoFuture; + +use futures::{executor::LocalPool, future, task::SpawnExt}; +use windows::{ + Storage::Streams::*, + System::Threading::{ThreadPool, WorkItemHandler}, +}; + #[test] fn async_get() -> windows::core::Result<()> { - use windows::Storage::Streams::*; - let stream = &InMemoryRandomAccessStream::new()?; let writer = DataWriter::CreateDataWriter(stream)?; @@ -25,8 +31,6 @@ fn async_get() -> windows::core::Result<()> { } async fn async_await() -> windows::core::Result<()> { - use windows::Storage::Streams::*; - let stream = &InMemoryRandomAccessStream::new()?; let writer = DataWriter::CreateDataWriter(stream)?; @@ -53,3 +57,43 @@ async fn async_await() -> windows::core::Result<()> { fn test_async_await() -> windows::core::Result<()> { futures::executor::block_on(async_await()) } + +#[test] +fn test_async_updates_waker() -> windows::core::Result<()> { + let mut pool = LocalPool::new(); + + let (tx, rx) = std::sync::mpsc::channel::<()>(); + + let winrt_future = ThreadPool::RunAsync(&WorkItemHandler::new(move |_| { + rx.recv().unwrap(); + Ok(()) + }))? + .into_future(); + + let task = pool + .spawner() + .spawn_with_handle(async move { + // Poll the future once on a LocalPool task + match future::select(winrt_future, future::ready(())).await { + future::Either::Left(_) => panic!("threadpool action can't finish yet"), + future::Either::Right(((), future)) => future, + } + }) + .unwrap(); + let winrt_future = pool.run_until(task); + + pool.spawner() + .spawn(async move { + // Now run the future to completion on a *different* LocalPool task. + // This will hang unless winrt_future properly updates its saved waker to the new task. + let (result, ()) = future::join(winrt_future, async { + tx.send(()).unwrap(); + }) + .await; + result.unwrap(); + }) + .unwrap(); + pool.run(); + + Ok(()) +} From 5cd9fd9dd9d15a122df42a39a2eb45f7f92894ba Mon Sep 17 00:00:00 2001 From: Geoffry Song Date: Wed, 24 Jul 2024 10:57:40 -0700 Subject: [PATCH 3/3] Cancel futures on drop --- crates/libs/bindgen/src/rust/writer.rs | 3 +++ crates/libs/core/src/future.rs | 20 ++++++++++++------- .../windows/src/Windows/Devices/Sms/mod.rs | 18 +++++++++++++++++ .../windows/src/Windows/Foundation/mod.rs | 12 +++++++++++ .../Security/Authentication/OnlineId/mod.rs | 6 ++++++ .../src/Windows/Storage/Streams/mod.rs | 6 ++++++ 6 files changed, 58 insertions(+), 7 deletions(-) diff --git a/crates/libs/bindgen/src/rust/writer.rs b/crates/libs/bindgen/src/rust/writer.rs index 8e356eb847..a925515119 100644 --- a/crates/libs/bindgen/src/rust/writer.rs +++ b/crates/libs/bindgen/src/rust/writer.rs @@ -712,6 +712,9 @@ impl Writer { fn get_results(&self) -> windows_core::Result { self.GetResults() } + fn cancel(&self) { + let _ = self.Cancel(); + } } #features impl<#constraints> std::future::IntoFuture for #ident { diff --git a/crates/libs/core/src/future.rs b/crates/libs/core/src/future.rs index e6e3af0577..bbbc77a6b9 100644 --- a/crates/libs/core/src/future.rs +++ b/crates/libs/core/src/future.rs @@ -22,26 +22,26 @@ pub trait AsyncOperation { /// Get the result value from a completed operation. /// Wraps `self.GetResults()`. fn get_results(&self) -> crate::Result; + /// Attempts to cancel the operation. Any error is ignored. + /// Wraps `self.Cancel()`. + fn cancel(&self); } /// A wrapper around an `AsyncOperation` that implements `std::future::Future`. /// This is used by generated `IntoFuture` impls. It shouldn't be necessary to use this type manually. -pub struct FutureWrapper { +pub struct FutureWrapper { inner: T, waker: Option>>, } -impl FutureWrapper { +impl FutureWrapper { /// Creates a `FutureWrapper`, which implements `std::future::Future`. pub fn new(inner: T) -> Self { - Self { - inner, - waker: None, - } + Self { inner, waker: None } } } -impl Unpin for FutureWrapper {} +impl Unpin for FutureWrapper {} impl Future for FutureWrapper { type Output = crate::Result; @@ -66,3 +66,9 @@ impl Future for FutureWrapper { } } } + +impl Drop for FutureWrapper { + fn drop(&mut self) { + self.inner.cancel(); + } +} diff --git a/crates/libs/windows/src/Windows/Devices/Sms/mod.rs b/crates/libs/windows/src/Windows/Devices/Sms/mod.rs index 26b29068a6..8e7f3daa67 100644 --- a/crates/libs/windows/src/Windows/Devices/Sms/mod.rs +++ b/crates/libs/windows/src/Windows/Devices/Sms/mod.rs @@ -1050,6 +1050,9 @@ impl windows_core::AsyncOperation for DeleteSmsMessageOperation { fn get_results(&self) -> windows_core::Result { self.GetResults() } + fn cancel(&self) { + let _ = self.Cancel(); + } } #[cfg(feature = "deprecated")] impl std::future::IntoFuture for DeleteSmsMessageOperation { @@ -1160,6 +1163,9 @@ impl windows_core::AsyncOperation for DeleteSmsMessagesOperation { fn get_results(&self) -> windows_core::Result { self.GetResults() } + fn cancel(&self) { + let _ = self.Cancel(); + } } #[cfg(feature = "deprecated")] impl std::future::IntoFuture for DeleteSmsMessagesOperation { @@ -1273,6 +1279,9 @@ impl windows_core::AsyncOperation for GetSmsDeviceOperation { fn get_results(&self) -> windows_core::Result { self.GetResults() } + fn cancel(&self) { + let _ = self.Cancel(); + } } #[cfg(feature = "deprecated")] impl std::future::IntoFuture for GetSmsDeviceOperation { @@ -1386,6 +1395,9 @@ impl windows_core::AsyncOperation for GetSmsMessageOperation { fn get_results(&self) -> windows_core::Result { self.GetResults() } + fn cancel(&self) { + let _ = self.Cancel(); + } } #[cfg(feature = "deprecated")] impl std::future::IntoFuture for GetSmsMessageOperation { @@ -1518,6 +1530,9 @@ impl windows_core::AsyncOperation for GetSmsMessagesOperation { fn get_results(&self) -> windows_core::Result { self.GetResults() } + fn cancel(&self) { + let _ = self.Cancel(); + } } #[cfg(all(feature = "Foundation_Collections", feature = "deprecated"))] impl std::future::IntoFuture for GetSmsMessagesOperation { @@ -1628,6 +1643,9 @@ impl windows_core::AsyncOperation for SendSmsMessageOperation { fn get_results(&self) -> windows_core::Result { self.GetResults() } + fn cancel(&self) { + let _ = self.Cancel(); + } } #[cfg(feature = "deprecated")] impl std::future::IntoFuture for SendSmsMessageOperation { diff --git a/crates/libs/windows/src/Windows/Foundation/mod.rs b/crates/libs/windows/src/Windows/Foundation/mod.rs index 11a548ce5a..a08ea00950 100644 --- a/crates/libs/windows/src/Windows/Foundation/mod.rs +++ b/crates/libs/windows/src/Windows/Foundation/mod.rs @@ -92,6 +92,9 @@ impl windows_core::AsyncOperation for IAsyncAction { fn get_results(&self) -> windows_core::Result { self.GetResults() } + fn cancel(&self) { + let _ = self.Cancel(); + } } impl std::future::IntoFuture for IAsyncAction { type Output = windows_core::Result<()>; @@ -219,6 +222,9 @@ impl windows_core::AsyncOperatio fn get_results(&self) -> windows_core::Result { self.GetResults() } + fn cancel(&self) { + let _ = self.Cancel(); + } } impl std::future::IntoFuture for IAsyncActionWithProgress { type Output = windows_core::Result<()>; @@ -396,6 +402,9 @@ impl windows_core::AsyncOperation fn get_results(&self) -> windows_core::Result { self.GetResults() } + fn cancel(&self) { + let _ = self.Cancel(); + } } impl std::future::IntoFuture for IAsyncOperation { type Output = windows_core::Result; @@ -535,6 +544,9 @@ impl windows_core::Result { self.GetResults() } + fn cancel(&self) { + let _ = self.Cancel(); + } } impl std::future::IntoFuture for IAsyncOperationWithProgress { type Output = windows_core::Result; diff --git a/crates/libs/windows/src/Windows/Security/Authentication/OnlineId/mod.rs b/crates/libs/windows/src/Windows/Security/Authentication/OnlineId/mod.rs index a5607a6c52..7fd0df683f 100644 --- a/crates/libs/windows/src/Windows/Security/Authentication/OnlineId/mod.rs +++ b/crates/libs/windows/src/Windows/Security/Authentication/OnlineId/mod.rs @@ -518,6 +518,9 @@ impl windows_core::AsyncOperation for SignOutUserOperation { fn get_results(&self) -> windows_core::Result { self.GetResults() } + fn cancel(&self) { + let _ = self.Cancel(); + } } impl std::future::IntoFuture for SignOutUserOperation { type Output = windows_core::Result<()>; @@ -623,6 +626,9 @@ impl windows_core::AsyncOperation for UserAuthenticationOperation { fn get_results(&self) -> windows_core::Result { self.GetResults() } + fn cancel(&self) { + let _ = self.Cancel(); + } } impl std::future::IntoFuture for UserAuthenticationOperation { type Output = windows_core::Result; diff --git a/crates/libs/windows/src/Windows/Storage/Streams/mod.rs b/crates/libs/windows/src/Windows/Storage/Streams/mod.rs index 329b6bb094..3fa547fb79 100644 --- a/crates/libs/windows/src/Windows/Storage/Streams/mod.rs +++ b/crates/libs/windows/src/Windows/Storage/Streams/mod.rs @@ -1339,6 +1339,9 @@ impl windows_core::AsyncOperation for DataReaderLoadOperation { fn get_results(&self) -> windows_core::Result { self.GetResults() } + fn cancel(&self) { + let _ = self.Cancel(); + } } impl std::future::IntoFuture for DataReaderLoadOperation { type Output = windows_core::Result; @@ -1629,6 +1632,9 @@ impl windows_core::AsyncOperation for DataWriterStoreOperation { fn get_results(&self) -> windows_core::Result { self.GetResults() } + fn cancel(&self) { + let _ = self.Cancel(); + } } impl std::future::IntoFuture for DataWriterStoreOperation { type Output = windows_core::Result;