diff --git a/tests/test_server/test_prediction_pipeline/test_preprocessing.py b/tests/test_server/test_prediction_pipeline/test_preprocessing.py index 8185d7cb..01a8ca8d 100644 --- a/tests/test_server/test_prediction_pipeline/test_preprocessing.py +++ b/tests/test_server/test_prediction_pipeline/test_preprocessing.py @@ -51,6 +51,16 @@ def test_zero_mean_unit_across_axes(): xr.testing.assert_allclose(expected, result[0]) +def test_binarize(): + binarize_spec = Preprocessing(name="binarize", kwargs={"threshold": 14}) + data = xr.DataArray(np.arange(30).reshape(2, 3, 5), dims=("x", "y", "c")) + expected = xr.zeros_like(data) + expected[{"x": slice(1, None)}] = 1 + preprocessing = make_preprocessing([binarize_spec]) + result = preprocessing(data) + xr.testing.assert_allclose(expected, result) + + def test_unknown_preprocessing_should_raise(): mypreprocessing = Preprocessing(name="mycoolpreprocessing", kwargs={"axes": ("x", "y")}) with pytest.raises(NotImplementedError): diff --git a/tiktorch/server/prediction_pipeline/_preprocessing.py b/tiktorch/server/prediction_pipeline/_preprocessing.py index 92bed2cb..4c0b819e 100644 --- a/tiktorch/server/prediction_pipeline/_preprocessing.py +++ b/tiktorch/server/prediction_pipeline/_preprocessing.py @@ -29,6 +29,10 @@ def zero_mean_unit_variance(tensor: xr.DataArray, axes=None, eps=1.0e-6, mode="p return (tensor - mean) / (std + 1.0e-6) +def binarize(tensor: xr.DataArray, *, threshold) -> xr.DataArray: + return tensor > threshold + + def ensure_dtype(tensor: xr.DataArray, *, dtype): """ Convert array to a given datatype @@ -46,6 +50,7 @@ def add_batch_dim(tensor: xr.DataArray): KNOWN_PREPROCESSING = { "scale_linear": scale_linear, "zero_mean_unit_variance": zero_mean_unit_variance, + "binarize": binarize, "__tiktorch_add_batch_dim": add_batch_dim, "__tiktorch_ensure_dtype": ensure_dtype, }