From 533da7513de42c5eb6395ea28edb594f60d685a0 Mon Sep 17 00:00:00 2001 From: Liam Gray Date: Fri, 16 Aug 2024 12:16:01 -0700 Subject: [PATCH] fix(random): pass dtype argument to numpy random Generator numpy random Generator methods do not check the datatype of an output array, so if the output dtype is not the default (np.float64), the dtype must be explicitly passed to the generator. Closes #288 --- draco/util/random.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/draco/util/random.py b/draco/util/random.py index 6ff1edd72..0c7b262e9 100644 --- a/draco/util/random.py +++ b/draco/util/random.py @@ -418,11 +418,16 @@ def _call(*args, **kwargs): # A worker method for each thread to fill its part of the array with the # random numbers def _fill(gen: np.random.Generator, local_array: np.ndarray) -> None: + if has_dtype: + kwargs["dtype"] = dtype if has_out: + if out.dtype != dtype: + raise TypeError( + f"Output array of type f{local_array.dtype} does not " + f"match dtype argument {dtype}." + ) method(gen, *args, **kwargs, out=local_array) else: - if has_dtype: - kwargs["dtype"] = dtype local_array[:] = method( gen, *args,