Skip to content

Commit

Permalink
Scale the data before or after a transformation
Browse files Browse the repository at this point in the history
Allow users to specify if they do not want to scale their data, or to
scale it either before or after doing a power transformation. This
allows both matching the sklearn behavior of scaling the data after the
transformation, or scaling it before the transformation which can help
with data that floors at non-zero values.
  • Loading branch information
csmarchbanks committed Dec 19, 2024
1 parent 0cb4f59 commit 7a8586f
Showing 1 changed file with 63 additions and 41 deletions.
104 changes: 63 additions & 41 deletions js/augurs-transforms-js/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
//! JavaScript bindings for augurs transformations, such as power transforms, scaling, etc.
use std::cell::RefCell;

use serde::{Deserialize, Serialize};
use tsify_next::Tsify;
use wasm_bindgen::prelude::*;
Expand Down Expand Up @@ -33,7 +35,8 @@ pub enum PowerTransformAlgorithm {
#[wasm_bindgen]
pub struct PowerTransform {
inner: Transform,
scale_params: Option<StandardScaleParams>,
standardize: Standardize,
scale_params: RefCell<Option<StandardScaleParams>>,
}

#[wasm_bindgen]
Expand All @@ -43,61 +46,70 @@ impl PowerTransform {
/// @experimental
#[wasm_bindgen(constructor)]
pub fn new(opts: PowerTransformOptions) -> Result<PowerTransform, JsError> {
let (scale_params, inner) = if opts.standardize {
let scale_params = StandardScaleParams::from_data(opts.data.iter().copied());
let scaler = Transform::standard_scaler(scale_params.clone());
let scaled: Vec<_> = scaler.transform(opts.data.iter().copied()).collect();
(
Some(scale_params),
Transform::power_transform(&scaled).map_err(|e| JsError::new(&e.to_string()))?,
)
} else {
(
None,
Transform::power_transform(&opts.data).map_err(|e| JsError::new(&e.to_string()))?,
)
};
Ok(PowerTransform {
inner,
scale_params,
inner: Transform::power_transform(&opts.data)
.map_err(|e| JsError::new(&e.to_string()))?,
standardize: opts.standardize.unwrap_or_default(),
scale_params: RefCell::new(None),
})
}

/// Transform the given data.
///
/// The transformed data is then scaled using a standard scaler (unless
/// `standardize` was set to `false` in the constructor).
/// The data is also scaled either before or after being transformed as per the standardize
/// option.
///
/// @experimental
#[wasm_bindgen]
pub fn transform(&self, data: VecF64) -> Result<Vec<f64>, JsError> {
let data = data.convert()?;
if let Some(scale_params) = &self.scale_params {
let scaler = Transform::standard_scaler(scale_params.clone());
let scaled: Vec<_> = scaler.transform(data.iter().copied()).collect();
Ok(self.inner.transform(scaled.iter().copied()).collect())
} else {
Ok(self.inner.transform(data.iter().copied()).collect())
}
Ok(match self.standardize {
Standardize::None => self.inner.transform(data.iter().copied()).collect(),
Standardize::Before => {
let scale_params = StandardScaleParams::from_data(data.iter().copied());
let scaler = Transform::standard_scaler(scale_params.clone());
self.scale_params.replace(Some(scale_params));
let scaled: Vec<_> = scaler.transform(data.iter().copied()).collect();
self.inner.transform(scaled.iter().copied()).collect()
}
Standardize::After => {
let transformed: Vec<_> = self.inner.transform(data.iter().copied()).collect();

let scale_params = StandardScaleParams::from_data(transformed.iter().copied());
let scaler = Transform::standard_scaler(scale_params.clone());
self.scale_params.replace(Some(scale_params));
scaler.transform(transformed.iter().copied()).collect()
}
})
}

/// Inverse transform the given data.
///
/// The data is first scaled back to the original scale using the standard scaler
/// (unless `standardize` was set to `false` in the constructor), then the
/// inverse power transform is applied.
/// The data is also inversely scaled according to the standardize option. The ordering is
/// opposite the order done in transform, i.e if transform scales first then transforms, then
/// inverse_transform transforms then scales.
///
/// @experimental
#[wasm_bindgen(js_name = "inverseTransform")]
pub fn inverse_transform(&self, data: VecF64) -> Result<Vec<f64>, JsError> {
let data = data.convert()?;
if let Some(scale_params) = &self.scale_params {
let scaler = Transform::standard_scaler(scale_params.clone());
let inverse_transformed = self.inner.inverse_transform(data.iter().copied());
Ok(scaler.inverse_transform(inverse_transformed).collect())
} else {
Ok(self.inner.inverse_transform(data.iter().copied()).collect())
}
Ok(
match (self.standardize, self.scale_params.borrow().as_ref()) {
(Standardize::Before, Some(scale_params)) => {
let inverse_transformed = self.inner.inverse_transform(data.iter().copied());
let inverse_scaler = Transform::standard_scaler(scale_params.clone());
inverse_scaler
.inverse_transform(inverse_transformed)
.collect()
}
(Standardize::After, Some(scale_params)) => {
let inverse_scaler = Transform::standard_scaler(scale_params.clone());
let scaled = inverse_scaler.inverse_transform(data.iter().copied());
self.inner.inverse_transform(scaled).collect()
}
_ => self.inner.inverse_transform(data.iter().copied()).collect(),
},
)
}

/// Get the algorithm used by the power transform.
Expand All @@ -122,8 +134,19 @@ impl PowerTransform {
}
}

fn default_standardize() -> bool {
true
/// When to standardize the data.
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Serialize, Deserialize, Tsify)]
#[serde(rename_all = "camelCase")]
#[tsify(from_wasm_abi)]
pub enum Standardize {
/// Only run a power transform, do not standardize the data.
None,
/// Standardize the data before running the power transform. This may provide better results for data
/// with a non-zero floor.
Before,
/// Standardize the data after running the power transform. This matches the default in sklearn.
#[default]
After,
}

/// Options for the power transform.
Expand All @@ -137,8 +160,7 @@ pub struct PowerTransformOptions {

/// Whether to standardize the data after applying the power transform.
///
/// This is generally recommended, and defaults to `true`.
#[serde(default = "default_standardize")]
/// This is generally recommended, and defaults to [`Standardize::After`] to match sklearn.
#[tsify(optional)]
pub standardize: bool,
pub standardize: Option<Standardize>,
}

0 comments on commit 7a8586f

Please sign in to comment.