diff --git a/js/augurs-transforms-js/src/lib.rs b/js/augurs-transforms-js/src/lib.rs index cd08773..8e0dc55 100644 --- a/js/augurs-transforms-js/src/lib.rs +++ b/js/augurs-transforms-js/src/lib.rs @@ -1,5 +1,7 @@ //! JavaScript bindings for augurs transformations, such as power transforms, scaling, etc. +use std::cell::RefCell; + use serde::{Deserialize, Serialize}; use tsify_next::Tsify; use wasm_bindgen::prelude::*; @@ -33,7 +35,8 @@ pub enum PowerTransformAlgorithm { #[wasm_bindgen] pub struct PowerTransform { inner: Transform, - scale_params: Option, + standardize: Standardize, + scale_params: RefCell>, } #[wasm_bindgen] @@ -43,61 +46,70 @@ impl PowerTransform { /// @experimental #[wasm_bindgen(constructor)] pub fn new(opts: PowerTransformOptions) -> Result { - let (scale_params, inner) = if opts.standardize { - let scale_params = StandardScaleParams::from_data(opts.data.iter().copied()); - let scaler = Transform::standard_scaler(scale_params.clone()); - let scaled: Vec<_> = scaler.transform(opts.data.iter().copied()).collect(); - ( - Some(scale_params), - Transform::power_transform(&scaled).map_err(|e| JsError::new(&e.to_string()))?, - ) - } else { - ( - None, - Transform::power_transform(&opts.data).map_err(|e| JsError::new(&e.to_string()))?, - ) - }; Ok(PowerTransform { - inner, - scale_params, + inner: Transform::power_transform(&opts.data) + .map_err(|e| JsError::new(&e.to_string()))?, + standardize: opts.standardize.unwrap_or_default(), + scale_params: RefCell::new(None), }) } /// Transform the given data. /// - /// The transformed data is then scaled using a standard scaler (unless - /// `standardize` was set to `false` in the constructor). + /// The data is also scaled either before or after being transformed as per the standardize + /// option. 
/// /// @experimental #[wasm_bindgen] pub fn transform(&self, data: VecF64) -> Result, JsError> { let data = data.convert()?; - if let Some(scale_params) = &self.scale_params { - let scaler = Transform::standard_scaler(scale_params.clone()); - let scaled: Vec<_> = scaler.transform(data.iter().copied()).collect(); - Ok(self.inner.transform(scaled.iter().copied()).collect()) - } else { - Ok(self.inner.transform(data.iter().copied()).collect()) - } + Ok(match self.standardize { + Standardize::None => self.inner.transform(data.iter().copied()).collect(), + Standardize::Before => { + let scale_params = StandardScaleParams::from_data(data.iter().copied()); + let scaler = Transform::standard_scaler(scale_params.clone()); + self.scale_params.replace(Some(scale_params)); + let scaled: Vec<_> = scaler.transform(data.iter().copied()).collect(); + self.inner.transform(scaled.iter().copied()).collect() + } + Standardize::After => { + let transformed: Vec<_> = self.inner.transform(data.iter().copied()).collect(); + + let scale_params = StandardScaleParams::from_data(transformed.iter().copied()); + let scaler = Transform::standard_scaler(scale_params.clone()); + self.scale_params.replace(Some(scale_params)); + scaler.transform(transformed.iter().copied()).collect() + } + }) } /// Inverse transform the given data. /// - /// The data is first scaled back to the original scale using the standard scaler - /// (unless `standardize` was set to `false` in the constructor), then the - /// inverse power transform is applied. + /// The data is also inversely scaled according to the standardize option. The ordering is + /// opposite the order done in transform, i.e. if transform scales first then transforms, then + /// inverse_transform transforms then scales. 
/// /// @experimental #[wasm_bindgen(js_name = "inverseTransform")] pub fn inverse_transform(&self, data: VecF64) -> Result, JsError> { let data = data.convert()?; - if let Some(scale_params) = &self.scale_params { - let scaler = Transform::standard_scaler(scale_params.clone()); - let inverse_transformed = self.inner.inverse_transform(data.iter().copied()); - Ok(scaler.inverse_transform(inverse_transformed).collect()) - } else { - Ok(self.inner.inverse_transform(data.iter().copied()).collect()) - } + Ok( + match (self.standardize, self.scale_params.borrow().as_ref()) { + (Standardize::Before, Some(scale_params)) => { + let inverse_transformed = self.inner.inverse_transform(data.iter().copied()); + let inverse_scaler = Transform::standard_scaler(scale_params.clone()); + inverse_scaler + .inverse_transform(inverse_transformed) + .collect() + } + (Standardize::After, Some(scale_params)) => { + let inverse_scaler = Transform::standard_scaler(scale_params.clone()); + let scaled = inverse_scaler.inverse_transform(data.iter().copied()); + self.inner.inverse_transform(scaled).collect() + } + _ => self.inner.inverse_transform(data.iter().copied()).collect(), + }, + ) } /// Get the algorithm used by the power transform. @@ -122,8 +134,19 @@ impl PowerTransform { } } -fn default_standardize() -> bool { - true +/// When to standardize the data. +#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Tsify)] +#[serde(rename_all = "camelCase")] +#[tsify(from_wasm_abi)] +pub enum Standardize { + /// Only run a power transform, do not standardize the data. + None, + /// Standardize the data before running the power transform. This may provide better results for data + /// with a non-zero floor. + Before, + /// Standardize the data after running the power transform. This matches the default in sklearn. + #[default] + After, } /// Options for the power transform. 
@@ -137,8 +160,7 @@ pub struct PowerTransformOptions { - /// Whether to standardize the data after applying the power transform. + /// Whether, and when, to standardize the data relative to the power transform. /// - /// This is generally recommended, and defaults to `true`. - #[serde(default = "default_standardize")] + /// This is generally recommended, and defaults to [`Standardize::After`] to match sklearn. #[tsify(optional)] - pub standardize: bool, + pub standardize: Option<Standardize>, }