diff --git a/crates/augurs-forecaster/src/transforms.rs b/crates/augurs-forecaster/src/transforms.rs index 4d878bf..1bfa393 100644 --- a/crates/augurs-forecaster/src/transforms.rs +++ b/crates/augurs-forecaster/src/transforms.rs @@ -43,7 +43,7 @@ impl Transforms { } /// A transformation that can be applied to a time series. -#[derive(Debug)] +#[derive(Debug, Clone)] #[non_exhaustive] pub enum Transform { /// Linear interpolation. diff --git a/js/augurs-transforms-js/src/lib.rs b/js/augurs-transforms-js/src/lib.rs index a5e38d2..305715e 100644 --- a/js/augurs-transforms-js/src/lib.rs +++ b/js/augurs-transforms-js/src/lib.rs @@ -32,7 +32,7 @@ pub enum PowerTransformAlgorithm { #[derive(Debug)] #[wasm_bindgen] pub struct PowerTransform { - inner: Transform, + inner: Option, standardize: Standardize, scale_params: Option, } @@ -45,8 +45,7 @@ impl PowerTransform { #[wasm_bindgen(constructor)] pub fn new(opts: PowerTransformOptions) -> Result { Ok(PowerTransform { - inner: Transform::power_transform(&opts.data) - .map_err(|e| JsError::new(&e.to_string()))?, + inner: None, standardize: opts.standardize.unwrap_or_default(), scale_params: None, }) @@ -62,16 +61,32 @@ impl PowerTransform { pub fn transform(&mut self, data: VecF64) -> Result, JsError> { let data = data.convert()?; Ok(match self.standardize { - Standardize::None => self.inner.transform(data.iter().copied()).collect(), + Standardize::None => { + let transform = + Transform::power_transform(&data).map_err(|e| JsError::new(&e.to_string()))?; + let result = transform.transform(data.iter().copied()).collect(); + self.inner = Some(transform); + result + } Standardize::Before => { let scale_params = StandardScaleParams::from_data(data.iter().copied()); let scaler = Transform::standard_scaler(scale_params.clone()); self.scale_params = Some(scale_params); let scaled: Vec<_> = scaler.transform(data.iter().copied()).collect(); - self.inner.transform(scaled.iter().copied()).collect() + + let transform = Transform::power_transform(&scaled) + .map_err(|e| JsError::new(&e.to_string()))?; + let result = transform.transform(scaled.iter().copied()).collect(); + self.inner = Some(transform); + + result } Standardize::After => { - let transformed: Vec<_> = self.inner.transform(data.iter().copied()).collect(); + let transform = + Transform::power_transform(&data).map_err(|e| JsError::new(&e.to_string()))?; + + let transformed: Vec<_> = transform.transform(data.iter().copied()).collect(); + self.inner = Some(transform); let scale_params = StandardScaleParams::from_data(transformed.iter().copied()); let scaler = Transform::standard_scaler(scale_params.clone()); @@ -91,9 +106,10 @@ impl PowerTransform { #[wasm_bindgen(js_name = "inverseTransform")] pub fn inverse_transform(&self, data: VecF64) -> Result, JsError> { let data = data.convert()?; + let transformer = self.inner.clone().unwrap(); Ok(match (self.standardize, self.scale_params.clone()) { (Standardize::Before, Some(scale_params)) => { - let inverse_transformed = self.inner.inverse_transform(data.iter().copied()); + let inverse_transformed = transformer.inverse_transform(data.iter().copied()); let inverse_scaler = Transform::standard_scaler(scale_params.clone()); inverse_scaler .inverse_transform(inverse_transformed) @@ -102,12 +118,15 @@ impl PowerTransform { (Standardize::After, Some(scale_params)) => { let inverse_scaler = Transform::standard_scaler(scale_params.clone()); let scaled = inverse_scaler.inverse_transform(data.iter().copied()); - self.inner.inverse_transform(scaled).collect() + transformer.inverse_transform(scaled).collect() } - _ => self.inner.inverse_transform(data.iter().copied()).collect(), + _ => transformer + .inverse_transform(data.iter().copied()) + .collect(), }) } + /* /// Get the algorithm used by the power transform. /// /// @experimental @@ -128,6 +147,7 @@ impl PowerTransform { _ => unreachable!(), } } + */ } /// When to standardize the data.