From 51ac166a79b1d68985eab703c3f646cf4c76b3c9 Mon Sep 17 00:00:00 2001 From: k-dominik Date: Mon, 22 Mar 2021 15:44:06 +0100 Subject: [PATCH] added binarize preprocessing function --- .../test_prediction_pipeline/test_preprocessing.py | 10 ++++++++++ tiktorch/server/prediction_pipeline/_preprocessing.py | 5 +++++ 2 files changed, 15 insertions(+) diff --git a/tests/test_server/test_prediction_pipeline/test_preprocessing.py b/tests/test_server/test_prediction_pipeline/test_preprocessing.py index 26f3647c..5e9ffd54 100644 --- a/tests/test_server/test_prediction_pipeline/test_preprocessing.py +++ b/tests/test_server/test_prediction_pipeline/test_preprocessing.py @@ -42,6 +42,16 @@ def test_zero_mean_unit_across_axes(): xr.testing.assert_allclose(expected, result[0]) +def test_binarize(): + binarize_spec = Preprocessing(name="binarize", kwargs={"threshold": 14}) + data = xr.DataArray(np.arange(30).reshape(2, 3, 5), dims=("x", "y", "c")) + expected = xr.zeros_like(data) + expected[{"x": slice(1, None)}] = 1 + preprocessing = make_preprocessing([binarize_spec]) + result = preprocessing(data) + xr.testing.assert_allclose(expected, result) + + def test_unknown_preprocessing_should_raise(): mypreprocessing = Preprocessing(name="mycoolpreprocessing", kwargs={"axes": ("x", "y")}) with pytest.raises(NotImplementedError): diff --git a/tiktorch/server/prediction_pipeline/_preprocessing.py b/tiktorch/server/prediction_pipeline/_preprocessing.py index d0439af0..3f52a4b6 100644 --- a/tiktorch/server/prediction_pipeline/_preprocessing.py +++ b/tiktorch/server/prediction_pipeline/_preprocessing.py @@ -25,6 +25,10 @@ def zero_mean_unit_variance(tensor: xr.DataArray, axes=None, eps=1.0e-6, mode="p return (tensor - mean) / (std + 1.0e-6) +def binarize(tensor: xr.DataArray, *, threshold) -> xr.DataArray: + return tensor > threshold + + def ensure_dtype(tensor: xr.DataArray, *, dtype): """ Convert array to a given datatype @@ -41,6 +45,7 @@ def add_batch_dim(tensor: xr.DataArray): KNOWN_PREPROCESSING = { "zero_mean_unit_variance": zero_mean_unit_variance, + "binarize": binarize, "__tiktorch_add_batch_dim": add_batch_dim, "__tiktorch_ensure_dtype": ensure_dtype, }