From 3b1d274e4cae6a275bcd03f4ced63b49e42fb5fa Mon Sep 17 00:00:00 2001 From: fynnbe Date: Thu, 15 Jul 2021 16:43:24 +0200 Subject: [PATCH] maks sure we don't change dtype --- tiktorch/runner/prediction_pipeline/_preprocessing.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tiktorch/runner/prediction_pipeline/_preprocessing.py b/tiktorch/runner/prediction_pipeline/_preprocessing.py index f5f9c8d7..952e299d 100644 --- a/tiktorch/runner/prediction_pipeline/_preprocessing.py +++ b/tiktorch/runner/prediction_pipeline/_preprocessing.py @@ -32,7 +32,11 @@ def zero_mean_unit_variance(tensor: xr.DataArray, axes=None, eps=1.0e-6, mode="p if mode != "per_sample": raise NotImplementedError(f"Unsupported mode for zero_mean_unit_variance: {mode}") - return (tensor - mean) / (std + 1.0e-6) + ret = (tensor - mean) / (std + 1.0e-6) + + # monkey patch: maks sure we don't change dtype + # todo: allow preprocessing to change dtype? + return ret.astype(tensor.dtype) def binarize(tensor: xr.DataArray, *, threshold) -> xr.DataArray: