diff --git a/crates/augurs-forecaster/src/transforms.rs b/crates/augurs-forecaster/src/transforms.rs index 4d878bf..1bfa393 100644 --- a/crates/augurs-forecaster/src/transforms.rs +++ b/crates/augurs-forecaster/src/transforms.rs @@ -43,7 +43,7 @@ impl Transforms { } /// A transformation that can be applied to a time series. -#[derive(Debug)] +#[derive(Debug, Clone)] #[non_exhaustive] pub enum Transform { /// Linear interpolation. diff --git a/js/augurs-transforms-js/src/lib.rs b/js/augurs-transforms-js/src/lib.rs index 1dde91b..305715e 100644 --- a/js/augurs-transforms-js/src/lib.rs +++ b/js/augurs-transforms-js/src/lib.rs @@ -1,7 +1,5 @@ //! JavaScript bindings for augurs transformations, such as power transforms, scaling, etc. -use std::cell::RefCell; - use serde::{Deserialize, Serialize}; use tsify_next::Tsify; use wasm_bindgen::prelude::*; @@ -34,9 +32,9 @@ pub enum PowerTransformAlgorithm { #[derive(Debug)] #[wasm_bindgen] pub struct PowerTransform { - inner: Transform, - standardize: bool, - scale_params: RefCell>, + inner: Option, + standardize: Standardize, + scale_params: Option, } #[wasm_bindgen] @@ -47,58 +45,88 @@ impl PowerTransform { #[wasm_bindgen(constructor)] pub fn new(opts: PowerTransformOptions) -> Result { Ok(PowerTransform { - inner: Transform::power_transform(&opts.data) - .map_err(|e| JsError::new(&e.to_string()))?, - standardize: opts.standardize, - scale_params: RefCell::new(None), + inner: None, + standardize: opts.standardize.unwrap_or_default(), + scale_params: None, }) } /// Transform the given data. /// - /// The transformed data is then scaled using a standard scaler (unless - /// `standardize` was set to `false` in the constructor). + /// The data is also scaled either before or after being transformed as per the standardize + /// option. /// /// @experimental #[wasm_bindgen] - pub fn transform(&self, data: VecF64) -> Result, JsError> { - let transformed: Vec<_> = self - .inner - .transform(data.convert()?.iter().copied()) - .collect(); - if !self.standardize { - Ok(transformed) - } else { - let scale_params = StandardScaleParams::from_data(transformed.iter().copied()); - let scaler = Transform::standard_scaler(scale_params.clone()); - self.scale_params.replace(Some(scale_params)); - Ok(scaler.transform(transformed.iter().copied()).collect()) - } + pub fn transform(&mut self, data: VecF64) -> Result, JsError> { + let data = data.convert()?; + Ok(match self.standardize { + Standardize::None => { + let transform = + Transform::power_transform(&data).map_err(|e| JsError::new(&e.to_string()))?; + let result = transform.transform(data.iter().copied()).collect(); + self.inner = Some(transform); + result + } + Standardize::Before => { + let scale_params = StandardScaleParams::from_data(data.iter().copied()); + let scaler = Transform::standard_scaler(scale_params.clone()); + self.scale_params = Some(scale_params); + let scaled: Vec<_> = scaler.transform(data.iter().copied()).collect(); + + let transform = Transform::power_transform(&scaled) + .map_err(|e| JsError::new(&e.to_string()))?; + let result = transform.transform(scaled.iter().copied()).collect(); + self.inner = Some(transform); + + result + } + Standardize::After => { + let transform = + Transform::power_transform(&data).map_err(|e| JsError::new(&e.to_string()))?; + + let transformed: Vec<_> = transform.transform(data.iter().copied()).collect(); + self.inner = Some(transform); + + let scale_params = StandardScaleParams::from_data(transformed.iter().copied()); + let scaler = Transform::standard_scaler(scale_params.clone()); + self.scale_params = Some(scale_params); + scaler.transform(transformed.iter().copied()).collect() + } + }) } /// Inverse transform the given data. /// - /// The data is first scaled back to the original scale using the standard scaler - /// (unless `standardize` was set to `false` in the constructor), then the - /// inverse power transform is applied. + /// The data is also inversely scaled according to the standardize option. The ordering is + /// opposite the order done in transform, i.e if transform scales first then transforms, then + /// inverse_transform transforms then scales. /// /// @experimental #[wasm_bindgen(js_name = "inverseTransform")] pub fn inverse_transform(&self, data: VecF64) -> Result, JsError> { - match (self.standardize, self.scale_params.borrow().as_ref()) { - (true, Some(scale_params)) => { + let data = data.convert()?; + let transformer = self.inner.clone().unwrap(); + Ok(match (self.standardize, self.scale_params.clone()) { + (Standardize::Before, Some(scale_params)) => { + let inverse_transformed = transformer.inverse_transform(data.iter().copied()); + let inverse_scaler = Transform::standard_scaler(scale_params.clone()); + inverse_scaler + .inverse_transform(inverse_transformed) + .collect() + } + (Standardize::After, Some(scale_params)) => { let inverse_scaler = Transform::standard_scaler(scale_params.clone()); - let data = data.convert()?; let scaled = inverse_scaler.inverse_transform(data.iter().copied()); - Ok(self.inner.inverse_transform(scaled).collect()) + transformer.inverse_transform(scaled).collect() } - _ => Ok(self - .inner - .inverse_transform(data.convert()?.iter().copied()) - .collect()), - } + _ => transformer + .inverse_transform(data.iter().copied()) + .collect(), + }) } + /* /// Get the algorithm used by the power transform. /// /// @experimental @@ -119,10 +147,22 @@ impl PowerTransform { _ => unreachable!(), } } + */ } -fn default_standardize() -> bool { - true +/// When to standardize the data. +#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Tsify)] +#[serde(rename_all = "camelCase")] +#[tsify(from_wasm_abi)] +pub enum Standardize { + /// Only run a power transform, do not standardize the data. + None, + /// Standardize the data before running the power transform. This may provide better results for data + /// with a non-zero floor. + Before, + /// Standardize the data after running the power transform. This matches the default in sklearn. + #[default] + After, } /// Options for the power transform. @@ -136,8 +176,7 @@ pub struct PowerTransformOptions { /// Whether to standardize the data after applying the power transform. /// - /// This is generally recommended, and defaults to `true`. - #[serde(default = "default_standardize")] + /// This is generally recommended, and defaults to [`Standardize::After`] to match sklearn. #[tsify(optional)] - pub standardize: bool, + pub standardize: Option, }