Skip to content

Commit

Permalink
Change Callback struct to hold fn() instead of &Fn()
Browse files Browse the repository at this point in the history
This removed the need to impl unsafe Send & Sync traits for
Callback type.

Also updated the error handler test to be more rusty.
  • Loading branch information
9prady9 committed Jun 28, 2017
1 parent e99ea54 commit 8db332e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 56 deletions.
63 changes: 31 additions & 32 deletions src/error.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,39 @@
use std::ops::{Deref, DerefMut};
use defines::AfError;
use std::error::Error;
use std::marker::{Send, Sync};
use std::sync::RwLock;

/// Signature of callback function to be called to handle errors
pub type ErrorCallback = Fn(AfError);
/// Signature of error handling callback function
pub type ErrorCallback = fn(AfError);

/// Wrap ErrorCallback function pointer inside a structure
/// to enable implementing Send, Sync traits on it.
pub struct Callback<'cblifetime> {
///Reference to a valid error callback function
///Make sure this callback stays relevant throughout the lifetime of application.
pub cb: &'cblifetime ErrorCallback,
/// Structure holding handle to callback function
pub struct Callback {
cb: ErrorCallback,
}

// Implement Send, Sync traits for Callback structure to
// enable the user of Callback function pointer in conjunction
// with threads using a mutex.
unsafe impl<'cblifetime> Send for Callback<'cblifetime> {}
unsafe impl<'cblifetime> Sync for Callback<'cblifetime> {}
impl Callback {
/// Associated function to create a new Callback object
pub fn new(callback: ErrorCallback) -> Self {
Callback {cb: callback}
}

/// call invokes the error callback with `error_code`.
pub fn call(&self, error_code: AfError) {
(self.cb)(error_code)
}
}

pub const DEFAULT_HANDLE_ERROR: Callback<'static> = Callback{cb: &handle_error_general};
/// Default error handling callback provided by ArrayFire crate
pub fn handle_error_general(error_code: AfError) {
match error_code {
AfError::SUCCESS => {}, /* No-op */
_ => panic!("Error message: {}", error_code.description()),
}
}

lazy_static! {
static ref ERROR_HANDLER_LOCK: RwLock< Callback<'static> > =
RwLock::new(DEFAULT_HANDLE_ERROR);
static ref ERROR_HANDLER_LOCK: RwLock< Callback > =
RwLock::new(Callback::new(handle_error_general));
}

/// Register user provided error handler
Expand All @@ -45,16 +53,17 @@ lazy_static! {
/// }
/// }
///
/// pub const ERR_HANDLE: Callback<'static> = Callback{ cb: &handleError};
///
/// fn main() {
/// register_error_handler(ERR_HANDLE);
/// //Registering the error handler should be the first call
/// //before any other functions are called if your version
/// //of error is to be used for subsequent function calls
/// register_error_handler(Callback::new(handleError));
///
/// info();
/// }
/// ```
#[allow(unused_must_use)]
pub fn register_error_handler(cb_value: Callback<'static>) {
pub fn register_error_handler(cb_value: Callback) {
let mut gaurd = match ERROR_HANDLER_LOCK.write() {
Ok(g) => g,
Err(_)=> panic!("Failed to acquire lock to register error handler"),
Expand All @@ -63,22 +72,12 @@ pub fn register_error_handler(cb_value: Callback<'static>) {
*gaurd.deref_mut() = cb_value;
}

/// Default error handling callback provided by ArrayFire crate
pub fn handle_error_general(error_code: AfError) {
match error_code {
AfError::SUCCESS => {}, /* No-op */
_ => panic!("Error message: {}", error_code.description()),
}
}

#[allow(non_snake_case)]
pub fn HANDLE_ERROR(error_code: AfError) {
let gaurd = match ERROR_HANDLER_LOCK.read() {
Ok(g) => g,
Err(_)=> panic!("Failed to acquire lock while handling FFI return value"),
};

let func = gaurd.deref().cb;

func(error_code);
(*gaurd.deref()).call(error_code);
}
41 changes: 17 additions & 24 deletions tests/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,51 +2,44 @@ extern crate arrayfire as af;

use std::error::Error;
use std::thread;
use std::time::Duration;
use af::*;

macro_rules! implement_handler {
($fn_name:ident, $msg: expr) => (

($fn_name:ident) => (
pub fn $fn_name(error_code: AfError) {
println!("{:?}", $msg);
match error_code {
AfError::SUCCESS => {}, /* No-op */
_ => panic!("Error message: {}", error_code.description()),
}
}

)
}

implement_handler!(handler_sample1, "Error Handler Sample1");
implement_handler!(handler_sample2, "Error Handler Sample2");
implement_handler!(handler_sample3, "Error Handler Sample3");
implement_handler!(handler_sample4, "Error Handler Sample4");

pub const HANDLE1: Callback<'static> = Callback{ cb: &handler_sample1};
pub const HANDLE2: Callback<'static> = Callback{ cb: &handler_sample2};
pub const HANDLE3: Callback<'static> = Callback{ cb: &handler_sample3};
pub const HANDLE4: Callback<'static> = Callback{ cb: &handler_sample4};
implement_handler!(handler_sample1);
implement_handler!(handler_sample2);
implement_handler!(handler_sample3);
implement_handler!(handler_sample4);

#[allow(unused_must_use)]
#[test]
fn check_error_handler_mutation() {

for i in 0..4 {
let children = (0..4).map(|i| {
thread::Builder::new().name(format!("child {}",i+1).to_string()).spawn(move || {
println!("{:?}", thread::current());
let target_device = i%af::device_count();
println!("Thread {:?} 's target device is {}", thread::current(), target_device);
match i {
0 => register_error_handler(HANDLE1),
1 => register_error_handler(HANDLE2),
2 => register_error_handler(HANDLE3),
3 => register_error_handler(HANDLE4),
0 => register_error_handler(Callback::new(handler_sample1)),
1 => register_error_handler(Callback::new(handler_sample2)),
2 => register_error_handler(Callback::new(handler_sample3)),
3 => register_error_handler(Callback::new(handler_sample4)),
_ => panic!("Impossible scenario"),
}
});
}
}).ok().expect("Failed to launch a thread")
}).collect::< Vec<_> >();

af::info();
thread::sleep(Duration::from_millis(50));
for c in children {
c.join();
}

}

0 comments on commit 8db332e

Please sign in to comment.