From c3c80a6e7469d226c008a3023318f2fd7eaad77d Mon Sep 17 00:00:00 2001 From: Emil Melnikov Date: Mon, 22 Mar 2021 15:19:05 +0100 Subject: [PATCH] Add "scale_linear" preprocessing operation See https://github.com/ilastik/tiktorch/issues/152 --- .../test_prediction_pipeline/test_preprocessing.py | 9 +++++++++ tiktorch/server/prediction_pipeline/_preprocessing.py | 5 +++++ 2 files changed, 14 insertions(+) diff --git a/tests/test_server/test_prediction_pipeline/test_preprocessing.py b/tests/test_server/test_prediction_pipeline/test_preprocessing.py index 26f3647c..8185d7cb 100644 --- a/tests/test_server/test_prediction_pipeline/test_preprocessing.py +++ b/tests/test_server/test_prediction_pipeline/test_preprocessing.py @@ -6,6 +6,15 @@ from tiktorch.server.prediction_pipeline._preprocessing import ADD_BATCH_DIM, make_preprocessing +def test_scale_linear(): + spec = Preprocessing(name="scale_linear", kwargs={"offset": 42, "gain": 2}) + data = xr.DataArray(np.arange(4).reshape(2, 2), dims=("x", "y")) + expected = xr.DataArray(np.array([[42, 44], [46, 48]]), dims=("x", "y")) + preprocessing = make_preprocessing([spec]) + result = preprocessing(data) + xr.testing.assert_allclose(expected, result) + + def test_zero_mean_unit_variance_preprocessing(): zero_mean_spec = Preprocessing(name="zero_mean_unit_variance", kwargs={}) data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) diff --git a/tiktorch/server/prediction_pipeline/_preprocessing.py b/tiktorch/server/prediction_pipeline/_preprocessing.py index d0439af0..92bed2cb 100644 --- a/tiktorch/server/prediction_pipeline/_preprocessing.py +++ b/tiktorch/server/prediction_pipeline/_preprocessing.py @@ -12,6 +12,10 @@ def make_ensure_dtype_preprocessing(dtype): return Preprocessing(name="__tiktorch_ensure_dtype", kwargs={"dtype": dtype}) +def scale_linear(tensor: xr.DataArray, *, gain, offset) -> xr.DataArray: + return gain * tensor + offset + + def zero_mean_unit_variance(tensor: xr.DataArray, axes=None, eps=1.0e-6, mode="per_sample") -> xr.DataArray: if axes: axes = tuple(axes) @@ -40,6 +44,7 @@ def add_batch_dim(tensor: xr.DataArray): KNOWN_PREPROCESSING = { + "scale_linear": scale_linear, "zero_mean_unit_variance": zero_mean_unit_variance, "__tiktorch_add_batch_dim": add_batch_dim, "__tiktorch_ensure_dtype": ensure_dtype,