Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: invert power transform/scaling order #207

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/augurs-forecaster/src/transforms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl Transforms {
}

/// A transformation that can be applied to a time series.
#[derive(Debug)]
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Transform {
/// Linear interpolation.
Expand Down
121 changes: 80 additions & 41 deletions js/augurs-transforms-js/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
//! JavaScript bindings for augurs transformations, such as power transforms, scaling, etc.

use std::cell::RefCell;

use serde::{Deserialize, Serialize};
use tsify_next::Tsify;
use wasm_bindgen::prelude::*;
Expand Down Expand Up @@ -34,9 +32,9 @@ pub enum PowerTransformAlgorithm {
#[derive(Debug)]
#[wasm_bindgen]
pub struct PowerTransform {
inner: Transform,
standardize: bool,
scale_params: RefCell<Option<StandardScaleParams>>,
inner: Option<Transform>,
standardize: Standardize,
scale_params: Option<StandardScaleParams>,
}

#[wasm_bindgen]
Expand All @@ -47,58 +45,88 @@ impl PowerTransform {
#[wasm_bindgen(constructor)]
pub fn new(opts: PowerTransformOptions) -> Result<PowerTransform, JsError> {
Ok(PowerTransform {
inner: Transform::power_transform(&opts.data)
.map_err(|e| JsError::new(&e.to_string()))?,
standardize: opts.standardize,
scale_params: RefCell::new(None),
inner: None,
standardize: opts.standardize.unwrap_or_default(),
scale_params: None,
})
}

/// Transform the given data.
///
/// The transformed data is then scaled using a standard scaler (unless
/// `standardize` was set to `false` in the constructor).
/// The data is also scaled either before or after being transformed as per the standardize
/// option.
///
/// @experimental
#[wasm_bindgen]
pub fn transform(&self, data: VecF64) -> Result<Vec<f64>, JsError> {
let transformed: Vec<_> = self
.inner
.transform(data.convert()?.iter().copied())
.collect();
if !self.standardize {
Ok(transformed)
} else {
let scale_params = StandardScaleParams::from_data(transformed.iter().copied());
let scaler = Transform::standard_scaler(scale_params.clone());
self.scale_params.replace(Some(scale_params));
Ok(scaler.transform(transformed.iter().copied()).collect())
}
pub fn transform(&mut self, data: VecF64) -> Result<Vec<f64>, JsError> {
let data = data.convert()?;
Ok(match self.standardize {
Standardize::None => {
let transform =
Transform::power_transform(&data).map_err(|e| JsError::new(&e.to_string()))?;
let result = transform.transform(data.iter().copied()).collect();
self.inner = Some(transform);
result
}
Standardize::Before => {
let scale_params = StandardScaleParams::from_data(data.iter().copied());
let scaler = Transform::standard_scaler(scale_params.clone());
self.scale_params = Some(scale_params);
let scaled: Vec<_> = scaler.transform(data.iter().copied()).collect();

let transform = Transform::power_transform(&scaled)
.map_err(|e| JsError::new(&e.to_string()))?;
let result = transform.transform(scaled.iter().copied()).collect();
self.inner = Some(transform);

result
}
Standardize::After => {
let transform =
Transform::power_transform(&data).map_err(|e| JsError::new(&e.to_string()))?;

let transformed: Vec<_> = transform.transform(data.iter().copied()).collect();
self.inner = Some(transform);

let scale_params = StandardScaleParams::from_data(transformed.iter().copied());
let scaler = Transform::standard_scaler(scale_params.clone());
self.scale_params = Some(scale_params);
scaler.transform(transformed.iter().copied()).collect()
}
})
}

/// Inverse transform the given data.
///
/// The data is first scaled back to the original scale using the standard scaler
/// (unless `standardize` was set to `false` in the constructor), then the
/// inverse power transform is applied.
/// The data is also inversely scaled according to the standardize option. The steps are applied
/// in the reverse of the order used by transform, i.e. if transform scales first and then applies
/// the power transform, inverse_transform inverts the power transform first and then the scaling.
///
/// @experimental
#[wasm_bindgen(js_name = "inverseTransform")]
pub fn inverse_transform(&self, data: VecF64) -> Result<Vec<f64>, JsError> {
match (self.standardize, self.scale_params.borrow().as_ref()) {
(true, Some(scale_params)) => {
let data = data.convert()?;
let transformer = self.inner.clone().unwrap();
Ok(match (self.standardize, self.scale_params.clone()) {
(Standardize::Before, Some(scale_params)) => {
let inverse_transformed = transformer.inverse_transform(data.iter().copied());
let inverse_scaler = Transform::standard_scaler(scale_params.clone());
inverse_scaler
.inverse_transform(inverse_transformed)
.collect()
}
(Standardize::After, Some(scale_params)) => {
let inverse_scaler = Transform::standard_scaler(scale_params.clone());
let data = data.convert()?;
let scaled = inverse_scaler.inverse_transform(data.iter().copied());
Ok(self.inner.inverse_transform(scaled).collect())
transformer.inverse_transform(scaled).collect()
}
_ => Ok(self
.inner
.inverse_transform(data.convert()?.iter().copied())
.collect()),
}
_ => transformer
.inverse_transform(data.iter().copied())
.collect(),
})
}

/*
/// Get the algorithm used by the power transform.
///
/// @experimental
Expand All @@ -119,10 +147,22 @@ impl PowerTransform {
_ => unreachable!(),
}
}
*/
}

fn default_standardize() -> bool {
true
/// When to standardize the data relative to the power transform.
///
/// Variants are (de)serialized in camelCase, i.e. `"none"`, `"before"` and `"after"`.
/// The default variant is [`Standardize::After`].
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Tsify)]
#[serde(rename_all = "camelCase")]
#[tsify(from_wasm_abi)]
pub enum Standardize {
/// Only run a power transform, do not standardize the data.
None,
/// Standardize the data before running the power transform. This may provide better results for data
/// with a non-zero floor.
Before,
/// Standardize the data after running the power transform. This matches the default in sklearn.
///
/// This is the variant produced by `Standardize::default()`.
#[default]
After,
}

/// Options for the power transform.
Expand All @@ -136,8 +176,7 @@ pub struct PowerTransformOptions {

/// Whether (and when) to standardize the data relative to applying the power transform.
///
/// This is generally recommended, and defaults to `true`.
#[serde(default = "default_standardize")]
/// This is generally recommended, and defaults to [`Standardize::After`] to match sklearn.
#[tsify(optional)]
pub standardize: bool,
pub standardize: Option<Standardize>,
}
Loading