Skip to content

Commit

Permalink
fix(random): pass dtype argument to numpy random Generator
Browse files Browse the repository at this point in the history
numpy random Generator methods do not check the datatype of an output
array, so if the output dtype is not the default (np.float64), the
dtype must be explicitly passed to the generator.

Closes #288
  • Loading branch information
ljgray committed Aug 16, 2024
1 parent bae3d8c commit 533da75
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions draco/util/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,11 +418,16 @@ def _call(*args, **kwargs):
# A worker method for each thread to fill its part of the array with the
# random numbers
def _fill(gen: np.random.Generator, local_array: np.ndarray) -> None:
if has_dtype:
kwargs["dtype"] = dtype
if has_out:
if out.dtype != dtype:
raise TypeError(
f"Output array of type f{local_array.dtype} does not "
f"match dtype argument {dtype}."
)
method(gen, *args, **kwargs, out=local_array)
else:
if has_dtype:
kwargs["dtype"] = dtype
local_array[:] = method(
gen,
*args,
Expand Down

0 comments on commit 533da75

Please sign in to comment.