Skip to content

Commit

Permalink
Improve custom URL scheme handler robustness and error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
UdaraJay authored Dec 16, 2024
1 parent 4740c0b commit ca15c3e
Showing 1 changed file with 82 additions and 53 deletions.
135 changes: 82 additions & 53 deletions src/wkwebview/class/url_scheme_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ extern "C" fn start_task(
) {
unsafe {
#[cfg(feature = "tracing")]
let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty)
.entered();
let span = tracing::info_span!(parent: None, "wry::custom_protocol::handle", uri = tracing::field::Empty)
.entered();

let task_key = task.hash(); // hash by task object address
let task_uuid = webview.add_custom_task_key(task_key);
Expand Down Expand Up @@ -122,7 +122,6 @@ extern "C" fn start_task(
if let Some(all_headers) = all_headers {
for current_header in all_headers.allKeys().to_vec() {
let header_value = all_headers.valueForKey(current_header).unwrap();

// inject the header into the request
http_request = http_request.header(current_header.to_string(), header_value.to_string());
}
Expand All @@ -145,62 +144,74 @@ extern "C" fn start_task(
task.didFinish();
};

fn check_webview_id_valid(webview_id: &str) -> crate::Result<()> {
if !WEBVIEW_IDS.lock().unwrap().contains(webview_id) {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
Ok(())
}

/// Task may not live longer than async custom protocol handler.
///
/// There are roughly 2 ways to cause segfault:
/// 1. Task has stopped. pointer of the task not valid anymore.
/// 2. Task had stopped, but the pointer of the task has allocated to a new task.
/// Outdated custom handler may call to the new task instance and cause segfault.
fn check_task_is_valid(
webview: &WryWebView,
task_key: usize,
current_uuid: Retained<NSUUID>,
) -> crate::Result<()> {
let latest_task_uuid = webview.get_custom_task_uuid(task_key);
if let Some(latest_uuid) = latest_task_uuid {
if latest_uuid != current_uuid {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
} else {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
Ok(())
}

// send response
match http_request.body(sent_form_body) {
Ok(final_request) => {
let responder: Box<dyn FnOnce(HttpResponse<Cow<'static, [u8]>>)> =
Box::new(move |sent_response| {
fn check_webview_id_valid(webview_id: &str) -> crate::Result<()> {
if !WEBVIEW_IDS.lock().unwrap().contains(webview_id) {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
Ok(())
}
/// Task may not live longer than async custom protocol handler.
///
/// There are roughly 2 ways to cause segfault:
/// 1. Task has stopped. pointer of the task not valid anymore.
/// 2. Task had stopped, but the pointer of the task has allocated to a new task.
/// Outdated custom handler may call to the new task instance and cause segfault.
fn check_task_is_valid(
webview: &WryWebView,
task_key: usize,
current_uuid: Retained<NSUUID>,
) -> crate::Result<()> {
let latest_task_uuid = webview.get_custom_task_uuid(task_key);
if let Some(latest_uuid) = latest_task_uuid {
if latest_uuid != current_uuid {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
} else {
return Err(crate::Error::CustomProtocolTaskInvalid);
}
// Consolidate checks before calling into `did*` methods.
let validate = || -> crate::Result<()> {
check_webview_id_valid(webview_id)?;
check_task_is_valid(webview, task_key, task_uuid.clone())?;
Ok(())
};

// Perform an upfront validation
if let Err(e) = validate() {
#[cfg(feature = "tracing")]
tracing::warn!("Task invalid before sending response: {:?}", e);
return; // If invalid, return early without calling task methods.
}

// ...
unsafe fn response(
// FIXME: though we give it a static lifetime, it's not guaranteed to be valid.
task: &'static ProtocolObject<dyn WKURLSchemeTask>,
// FIXME: though we give it a static lifetime, it's not guaranteed to be valid.
webview: &'static mut WryWebView,
task_key: usize,
task_uuid: Retained<NSUUID>,
webview_id: &str,
url: Retained<NSURL>,
sent_response: HttpResponse<Cow<'_, [u8]>>,
) -> crate::Result<()> {
check_task_is_valid(&*webview, task_key, task_uuid.clone())?;
// Validate
check_webview_id_valid(webview_id)?;
check_task_is_valid(webview, task_key, task_uuid.clone())?;

let content = sent_response.body();
// default: application/octet-stream, but should be provided by the client
let wanted_mime = sent_response.headers().get(CONTENT_TYPE);
// default to 200
let wanted_status_code = sent_response.status().as_u16() as i32;
// default to HTTP/1.1
let wanted_version = format!("{:#?}", sent_response.version());

let mut headers = NSMutableDictionary::new();

if let Some(mime) = wanted_mime {
headers.insert_id(
NSString::from_str(CONTENT_TYPE.as_str()).as_ref(),
Expand All @@ -212,7 +223,6 @@ extern "C" fn start_task(
NSString::from_str(&content.len().to_string()),
);

// add headers
for (name, value) in sent_response.headers().iter() {
if let Ok(value) = value.to_str() {
headers.insert_id(
Expand All @@ -232,48 +242,66 @@ extern "C" fn start_task(
)
.unwrap();

// Re-validate before calling didReceiveResponse
check_webview_id_valid(webview_id)?;
check_task_is_valid(&*webview, task_key, task_uuid.clone())?;
check_task_is_valid(webview, task_key, task_uuid.clone())?;

// Use map_err to convert Option<Retained<Exception>> to crate::Error
objc2::exception::catch(AssertUnwindSafe(|| {
task.didReceiveResponse(&response);
}))
.unwrap();
.map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;

// Send data
let bytes = content.as_ptr() as *mut c_void;
let data = NSData::alloc();
// MIGRATE NOTE: we copied the content to the NSData because content will be freed
// when out of scope but NSData will also free the content when it's done and cause doube free.
let data = NSData::initWithBytes_length(data, bytes, content.len());
let data = NSData::initWithBytes_length(
data,
content.as_ptr() as *mut c_void,
content.len(),
);

// Check validity again
check_webview_id_valid(webview_id)?;
check_task_is_valid(&*webview, task_key, task_uuid.clone())?;
check_task_is_valid(webview, task_key, task_uuid.clone())?;

objc2::exception::catch(AssertUnwindSafe(|| {
task.didReceiveData(&data);
}))
.unwrap();
.map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;

// Finish
check_webview_id_valid(webview_id)?;
check_task_is_valid(&*webview, task_key, task_uuid.clone())?;
check_task_is_valid(webview, task_key, task_uuid.clone())?;

objc2::exception::catch(AssertUnwindSafe(|| {
task.didFinish();
}))
.unwrap();

webview.remove_custom_task_key(task_key);
Ok(())
.map_err(|_e| crate::Error::CustomProtocolTaskInvalid)?;

{
let ids = WEBVIEW_IDS.lock().unwrap();
if ids.contains(webview_id) {
webview.remove_custom_task_key(task_key);
Ok(())
} else {
Err(crate::Error::CustomProtocolTaskInvalid)
}
}
}

let _ = response(
#[cfg(feature = "tracing")]
let _span = tracing::info_span!("wry::custom_protocol::call_handler").entered();

if let Err(e) = response(
task,
webview,
task_key,
task_uuid,
webview_id,
url.clone(),
sent_response,
);
) {
#[cfg(feature = "tracing")]
tracing::error!("Error responding to task: {:?}", e);
}
});

#[cfg(feature = "tracing")]
Expand All @@ -294,6 +322,7 @@ extern "C" fn start_task(
}
}
}

extern "C" fn stop_task(
_this: &ProtocolObject<dyn WKURLSchemeHandler>,
_sel: objc2::runtime::Sel,
Expand Down

0 comments on commit ca15c3e

Please sign in to comment.