Skip to content

Commit

Permalink
WIP: Fix errors with scaling before transformation
Browse files Browse the repository at this point in the history
If we scale before transforming, some values may become negative, in which
case a Box-Cox transformation is no longer valid. This commit waits to choose
which transformation algorithm to use until after potentially scaling
the data.
  • Loading branch information
csmarchbanks committed Dec 20, 2024
1 parent 8e460e4 commit 64d5d3d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
2 changes: 1 addition & 1 deletion crates/augurs-forecaster/src/transforms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl Transforms {
}

/// A transformation that can be applied to a time series.
#[derive(Debug)]
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Transform {
/// Linear interpolation.
Expand Down
38 changes: 29 additions & 9 deletions js/augurs-transforms-js/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub enum PowerTransformAlgorithm {
#[derive(Debug)]
#[wasm_bindgen]
pub struct PowerTransform {
inner: Transform,
inner: Option<Transform>,
standardize: Standardize,
scale_params: Option<StandardScaleParams>,
}
Expand All @@ -45,8 +45,7 @@ impl PowerTransform {
#[wasm_bindgen(constructor)]
pub fn new(opts: PowerTransformOptions) -> Result<PowerTransform, JsError> {
Ok(PowerTransform {
inner: Transform::power_transform(&opts.data)
.map_err(|e| JsError::new(&e.to_string()))?,
inner: None,
standardize: opts.standardize.unwrap_or_default(),
scale_params: None,
})
Expand All @@ -62,16 +61,32 @@ impl PowerTransform {
pub fn transform(&mut self, data: VecF64) -> Result<Vec<f64>, JsError> {
let data = data.convert()?;
Ok(match self.standardize {
Standardize::None => self.inner.transform(data.iter().copied()).collect(),
Standardize::None => {
let transform =
Transform::power_transform(&data).map_err(|e| JsError::new(&e.to_string()))?;
let result = transform.transform(data.iter().copied()).collect();
self.inner = Some(transform);
result
}
Standardize::Before => {
let scale_params = StandardScaleParams::from_data(data.iter().copied());
let scaler = Transform::standard_scaler(scale_params.clone());
self.scale_params = Some(scale_params);
let scaled: Vec<_> = scaler.transform(data.iter().copied()).collect();
self.inner.transform(scaled.iter().copied()).collect()

let transform = Transform::power_transform(&scaled)
.map_err(|e| JsError::new(&e.to_string()))?;
let result = transform.transform(scaled.iter().copied()).collect();
self.inner = Some(transform);

result
}
Standardize::After => {
let transformed: Vec<_> = self.inner.transform(data.iter().copied()).collect();
let transform =
Transform::power_transform(&data).map_err(|e| JsError::new(&e.to_string()))?;

let transformed: Vec<_> = transform.transform(data.iter().copied()).collect();
self.inner = Some(transform);

let scale_params = StandardScaleParams::from_data(transformed.iter().copied());
let scaler = Transform::standard_scaler(scale_params.clone());
Expand All @@ -91,9 +106,10 @@ impl PowerTransform {
#[wasm_bindgen(js_name = "inverseTransform")]
pub fn inverse_transform(&self, data: VecF64) -> Result<Vec<f64>, JsError> {
let data = data.convert()?;
let transformer = self.inner.clone().unwrap();
Ok(match (self.standardize, self.scale_params.clone()) {
(Standardize::Before, Some(scale_params)) => {
let inverse_transformed = self.inner.inverse_transform(data.iter().copied());
let inverse_transformed = transformer.inverse_transform(data.iter().copied());
let inverse_scaler = Transform::standard_scaler(scale_params.clone());
inverse_scaler
.inverse_transform(inverse_transformed)
Expand All @@ -102,12 +118,15 @@ impl PowerTransform {
(Standardize::After, Some(scale_params)) => {
let inverse_scaler = Transform::standard_scaler(scale_params.clone());
let scaled = inverse_scaler.inverse_transform(data.iter().copied());
self.inner.inverse_transform(scaled).collect()
transformer.inverse_transform(scaled).collect()
}
_ => self.inner.inverse_transform(data.iter().copied()).collect(),
_ => transformer
.inverse_transform(data.iter().copied())
.collect(),
})
}

/*
/// Get the algorithm used by the power transform.
///
/// @experimental
Expand All @@ -128,6 +147,7 @@ impl PowerTransform {
_ => unreachable!(),
}
}
*/
}

/// When to standardize the data.
Expand Down

0 comments on commit 64d5d3d

Please sign in to comment.