From 8db332ea3553ecf1fd3f5f9ec95b6791bb418fc0 Mon Sep 17 00:00:00 2001 From: pradeep Date: Wed, 28 Jun 2017 18:24:04 +0530 Subject: [PATCH] Change Callback struct to hold fn() instead of &Fn() This removed the need to impl unsafe Send & Sync traits for Callback type. Also updated the error handler test to be more rusty. --- src/error.rs | 63 ++++++++++++++++++++++++++-------------------------- tests/lib.rs | 41 ++++++++++++++-------------------- 2 files changed, 48 insertions(+), 56 deletions(-) diff --git a/src/error.rs b/src/error.rs index b20b7673c..0a42ec793 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,31 +1,39 @@ use std::ops::{Deref, DerefMut}; use defines::AfError; use std::error::Error; -use std::marker::{Send, Sync}; use std::sync::RwLock; -/// Signature of callback function to be called to handle errors -pub type ErrorCallback = Fn(AfError); +/// Signature of error handling callback function +pub type ErrorCallback = fn(AfError); -/// Wrap ErrorCallback function pointer inside a structure -/// to enable implementing Send, Sync traits on it. -pub struct Callback<'cblifetime> { - ///Reference to a valid error callback function - ///Make sure this callback stays relevant throughout the lifetime of application. - pub cb: &'cblifetime ErrorCallback, +/// Structure holding handle to callback function +pub struct Callback { + cb: ErrorCallback, } -// Implement Send, Sync traits for Callback structure to -// enable the user of Callback function pointer in conjunction -// with threads using a mutex. -unsafe impl<'cblifetime> Send for Callback<'cblifetime> {} -unsafe impl<'cblifetime> Sync for Callback<'cblifetime> {} +impl Callback { + /// Associated function to create a new Callback object + pub fn new(callback: ErrorCallback) -> Self { + Callback {cb: callback} + } + + /// call invokes the error callback with `error_code`. + pub fn call(&self, error_code: AfError) { + (self.cb)(error_code) + } +} -pub const DEFAULT_HANDLE_ERROR: Callback<'static> = Callback{cb: &handle_error_general}; +/// Default error handling callback provided by ArrayFire crate +pub fn handle_error_general(error_code: AfError) { + match error_code { + AfError::SUCCESS => {}, /* No-op */ + _ => panic!("Error message: {}", error_code.description()), + } +} lazy_static! { - static ref ERROR_HANDLER_LOCK: RwLock< Callback<'static> > = - RwLock::new(DEFAULT_HANDLE_ERROR); + static ref ERROR_HANDLER_LOCK: RwLock< Callback > = + RwLock::new(Callback::new(handle_error_general)); } /// Register user provided error handler @@ -45,16 +53,17 @@ lazy_static! { /// } /// } /// -/// pub const ERR_HANDLE: Callback<'static> = Callback{ cb: &handleError}; -/// /// fn main() { -/// register_error_handler(ERR_HANDLE); +/// //Registering the error handler should be the first call +/// //before any other functions are called if your version +/// //of error is to be used for subsequent function calls +/// register_error_handler(Callback::new(handleError)); /// /// info(); /// } /// ``` #[allow(unused_must_use)] -pub fn register_error_handler(cb_value: Callback<'static>) { +pub fn register_error_handler(cb_value: Callback) { let mut gaurd = match ERROR_HANDLER_LOCK.write() { Ok(g) => g, Err(_)=> panic!("Failed to acquire lock to register error handler"), @@ -63,14 +72,6 @@ pub fn register_error_handler(cb_value: Callback<'static>) { *gaurd.deref_mut() = cb_value; } -/// Default error handling callback provided by ArrayFire crate -pub fn handle_error_general(error_code: AfError) { - match error_code { - AfError::SUCCESS => {}, /* No-op */ - _ => panic!("Error message: {}", error_code.description()), - } -} - #[allow(non_snake_case)] pub fn HANDLE_ERROR(error_code: AfError) { let gaurd = match ERROR_HANDLER_LOCK.read() { @@ -78,7 +79,5 @@ pub fn HANDLE_ERROR(error_code: AfError) { Err(_)=> panic!("Failed to acquire lock while handling FFI return value"), }; - let func = gaurd.deref().cb; - - func(error_code); + (*gaurd.deref()).call(error_code); } diff --git a/tests/lib.rs b/tests/lib.rs index 11990e92a..ba7365b21 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -2,51 +2,44 @@ extern crate arrayfire as af; use std::error::Error; use std::thread; -use std::time::Duration; use af::*; macro_rules! implement_handler { - ($fn_name:ident, $msg: expr) => ( - + ($fn_name:ident) => ( pub fn $fn_name(error_code: AfError) { - println!("{:?}", $msg); match error_code { AfError::SUCCESS => {}, /* No-op */ _ => panic!("Error message: {}", error_code.description()), } } - ) } -implement_handler!(handler_sample1, "Error Handler Sample1"); -implement_handler!(handler_sample2, "Error Handler Sample2"); -implement_handler!(handler_sample3, "Error Handler Sample3"); -implement_handler!(handler_sample4, "Error Handler Sample4"); - -pub const HANDLE1: Callback<'static> = Callback{ cb: &handler_sample1}; -pub const HANDLE2: Callback<'static> = Callback{ cb: &handler_sample2}; -pub const HANDLE3: Callback<'static> = Callback{ cb: &handler_sample3}; -pub const HANDLE4: Callback<'static> = Callback{ cb: &handler_sample4}; +implement_handler!(handler_sample1); +implement_handler!(handler_sample2); +implement_handler!(handler_sample3); +implement_handler!(handler_sample4); #[allow(unused_must_use)] #[test] fn check_error_handler_mutation() { - for i in 0..4 { + let children = (0..4).map(|i| { thread::Builder::new().name(format!("child {}",i+1).to_string()).spawn(move || { - println!("{:?}", thread::current()); + let target_device = i%af::device_count(); + println!("Thread {:?} 's target device is {}", thread::current(), target_device); match i { - 0 => register_error_handler(HANDLE1), - 1 => register_error_handler(HANDLE2), - 2 => register_error_handler(HANDLE3), - 3 => register_error_handler(HANDLE4), + 0 => register_error_handler(Callback::new(handler_sample1)), + 1 => register_error_handler(Callback::new(handler_sample2)), + 2 => register_error_handler(Callback::new(handler_sample3)), + 3 => register_error_handler(Callback::new(handler_sample4)), _ => panic!("Impossible scenario"), } - }); - } + }).ok().expect("Failed to launch a thread") + }).collect::< Vec<_> >(); - af::info(); - thread::sleep(Duration::from_millis(50)); + for c in children { + c.join(); + } }