From d19973bcf575fc0590ebc6618134acb1db94e21c Mon Sep 17 00:00:00 2001 From: Chris Marchbanks Date: Thu, 19 Dec 2024 15:54:25 -0700 Subject: [PATCH] Scale the data before or after a transformation Allow users to specify if they do not want to scale their data, or to scale it either before or after doing a power transformation. This allows both matching the sklearn behavior of scaling the data after the transformation, or scaling it before the transformation which can help with data that floors at non-zero values. --- js/augurs-transforms-js/src/lib.rs | 107 ++++++++++++++++++----------- 1 file changed, 66 insertions(+), 41 deletions(-) diff --git a/js/augurs-transforms-js/src/lib.rs b/js/augurs-transforms-js/src/lib.rs index cd08773..d393a47 100644 --- a/js/augurs-transforms-js/src/lib.rs +++ b/js/augurs-transforms-js/src/lib.rs @@ -1,5 +1,7 @@ //! JavaScript bindings for augurs transformations, such as power transforms, scaling, etc. +use std::cell::RefCell; + use serde::{Deserialize, Serialize}; use tsify_next::Tsify; use wasm_bindgen::prelude::*; @@ -33,7 +35,8 @@ pub enum PowerTransformAlgorithm { #[wasm_bindgen] pub struct PowerTransform { inner: Transform, - scale_params: Option, + standardize: Standardize, + scale_params: RefCell>, } #[wasm_bindgen] @@ -43,61 +46,73 @@ impl PowerTransform { /// @experimental #[wasm_bindgen(constructor)] pub fn new(opts: PowerTransformOptions) -> Result { - let (scale_params, inner) = if opts.standardize { - let scale_params = StandardScaleParams::from_data(opts.data.iter().copied()); - let scaler = Transform::standard_scaler(scale_params.clone()); - let scaled: Vec<_> = scaler.transform(opts.data.iter().copied()).collect(); - ( - Some(scale_params), - Transform::power_transform(&scaled).map_err(|e| JsError::new(&e.to_string()))?, - ) - } else { - ( - None, - Transform::power_transform(&opts.data).map_err(|e| JsError::new(&e.to_string()))?, - ) - }; Ok(PowerTransform { - inner, - scale_params, + inner: Transform::power_transform(&opts.data).map_err(|e| + JsError::new(&e.to_string()))?, + standardize: opts.standardize.unwrap_or_default(), + scale_params: RefCell::new(None), }) } /// Transform the given data. /// - /// The transformed data is then scaled using a standard scaler (unless - /// `standardize` was set to `false` in the constructor). + /// The data is also scaled either before or after being transformed as per the standardize + /// option. /// /// @experimental #[wasm_bindgen] pub fn transform(&self, data: VecF64) -> Result, JsError> { let data = data.convert()?; - if let Some(scale_params) = &self.scale_params { - let scaler = Transform::standard_scaler(scale_params.clone()); - let scaled: Vec<_> = scaler.transform(data.iter().copied()).collect(); - Ok(self.inner.transform(scaled.iter().copied()).collect()) - } else { - Ok(self.inner.transform(data.iter().copied()).collect()) - } + Ok(match self.standardize { + Standardize::None => self.inner + .transform(data.iter().copied()) + .collect(), + Standardize::Before => { + let scale_params = StandardScaleParams::from_data(data.iter().copied()); + let scaler = Transform::standard_scaler(scale_params.clone()); + self.scale_params.replace(Some(scale_params)); + let scaled: Vec<_> = scaler.transform(data.iter().copied()).collect(); + self.inner.transform(scaled.iter().copied()).collect() + } + Standardize::After => { + let transformed: Vec<_> = self + .inner + .transform(data.iter().copied()) + .collect(); + + let scale_params = StandardScaleParams::from_data(transformed.iter().copied()); + let scaler = Transform::standard_scaler(scale_params.clone()); + self.scale_params.replace(Some(scale_params)); + scaler.transform(transformed.iter().copied()).collect() + } + }) } /// Inverse transform the given data. /// - /// The data is first scaled back to the original scale using the standard scaler - /// (unless `standardize` was set to `false` in the constructor), then the - /// inverse power transform is applied. + /// The data is also inversely scaled according to the standardize option. The ordering is + /// opposite the order done in transform, i.e if transform scales first then transforms, then + /// inverse_transform transforms then scales. /// /// @experimental #[wasm_bindgen(js_name = "inverseTransform")] pub fn inverse_transform(&self, data: VecF64) -> Result, JsError> { let data = data.convert()?; - if let Some(scale_params) = &self.scale_params { - let scaler = Transform::standard_scaler(scale_params.clone()); - let inverse_transformed = self.inner.inverse_transform(data.iter().copied()); - Ok(scaler.inverse_transform(inverse_transformed).collect()) - } else { - Ok(self.inner.inverse_transform(data.iter().copied()).collect()) - } + Ok(match (self.standardize, self.scale_params.borrow().as_ref()) { + (Standardize::Before, Some(scale_params)) => { + let inverse_transformed = self.inner.inverse_transform(data.iter().copied()); + let inverse_scaler = Transform::standard_scaler(scale_params.clone()); + inverse_scaler.inverse_transform(inverse_transformed).collect() + } + (Standardize::After, Some(scale_params)) => { + let inverse_scaler = Transform::standard_scaler(scale_params.clone()); + let scaled = inverse_scaler.inverse_transform(data.iter().copied()); + self.inner.inverse_transform(scaled).collect() + } + _ => self.inner + .inverse_transform(data.iter().copied()) + .collect(), + }) } /// Get the algorithm used by the power transform. @@ -122,8 +137,19 @@ impl PowerTransform { } } -fn default_standardize() -> bool { - true +/// When to standardize the data. +#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Tsify)] +#[serde(rename_all = "camelCase")] +#[tsify(from_wasm_abi)] +pub enum Standardize { + /// Only run a power transform, do not standardize the data. + None, + /// Standardize the data before running the power transform. This may provide better results for data + /// with a non-zero floor. + Before, + /// Standardize the data after running the power transform. This matches the default in sklearn. + #[default] + After, } /// Options for the power transform. @@ -137,8 +163,7 @@ pub struct PowerTransformOptions { /// Whether to standardize the data after applying the power transform. /// - /// This is generally recommended, and defaults to `true`. - #[serde(default = "default_standardize")] + /// This is generally recommended, and defaults to [`Standardize::After`] to match sklearn. #[tsify(optional)] - pub standardize: bool, + pub standardize: Option, }