Skip to content

Commit

Permalink
Merge pull request #155 from ilastik/scale-linear
Browse files Browse the repository at this point in the history
Add "scale_linear" preprocessing operation
  • Loading branch information
m-novikov authored Mar 26, 2021
2 parents 2d6803c + c3c80a6 commit a261fec
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@
from tiktorch.server.prediction_pipeline._preprocessing import ADD_BATCH_DIM, make_preprocessing


def test_scale_linear():
spec = Preprocessing(name="scale_linear", kwargs={"offset": 42, "gain": 2})
data = xr.DataArray(np.arange(4).reshape(2, 2), dims=("x", "y"))
expected = xr.DataArray(np.array([[42, 44], [46, 48]]), dims=("x", "y"))
preprocessing = make_preprocessing([spec])
result = preprocessing(data)
xr.testing.assert_allclose(expected, result)


def test_zero_mean_unit_variance_preprocessing():
zero_mean_spec = Preprocessing(name="zero_mean_unit_variance", kwargs={})
data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y"))
Expand Down
5 changes: 5 additions & 0 deletions tiktorch/server/prediction_pipeline/_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ def make_ensure_dtype_preprocessing(dtype):
return Preprocessing(name="__tiktorch_ensure_dtype", kwargs={"dtype": dtype})


def scale_linear(tensor: xr.DataArray, *, gain, offset) -> xr.DataArray:
return gain * tensor + offset


def zero_mean_unit_variance(tensor: xr.DataArray, axes=None, eps=1.0e-6, mode="per_sample") -> xr.DataArray:
if axes:
axes = tuple(axes)
Expand Down Expand Up @@ -40,6 +44,7 @@ def add_batch_dim(tensor: xr.DataArray):


KNOWN_PREPROCESSING = {
"scale_linear": scale_linear,
"zero_mean_unit_variance": zero_mean_unit_variance,
"__tiktorch_add_batch_dim": add_batch_dim,
"__tiktorch_ensure_dtype": ensure_dtype,
Expand Down

0 comments on commit a261fec

Please sign in to comment.