diff --git a/tiktorch/runner/prediction_pipeline/_preprocessing.py b/tiktorch/runner/prediction_pipeline/_preprocessing.py index f5f9c8d7..7abd2ce0 100644 --- a/tiktorch/runner/prediction_pipeline/_preprocessing.py +++ b/tiktorch/runner/prediction_pipeline/_preprocessing.py @@ -32,7 +32,11 @@ def zero_mean_unit_variance(tensor: xr.DataArray, axes=None, eps=1.0e-6, mode="p if mode != "per_sample": raise NotImplementedError(f"Unsupported mode for zero_mean_unit_variance: {mode}") - return (tensor - mean) / (std + 1.0e-6) + ret = (tensor - mean) / (std + 1.0e-6) + + # monkey patch: maks sure we don't change dtype + # todo: allow preprocessing to change dtype? + return ret.astype("float32") def binarize(tensor: xr.DataArray, *, threshold) -> xr.DataArray: