Skip to content

Commit

Permalink
Merge pull request #1141 from Bears-R-Us/improve-scalar-dtype
Browse files Browse the repository at this point in the history
Improve scalar dtype handling
  • Loading branch information
reuster986 authored Feb 23, 2022
2 parents c306460 + 066852e commit 899e312
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions arkouda/pdarrayclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,14 @@ def _binop(self, other : pdarray, op : str) -> pdarray:
repMsg = generic_msg(cmd=cmd,args=args)
return create_pdarray(repMsg)
# pdarray binop scalar
dt = resolve_scalar_dtype(other)
if np.can_cast(other, self.dtype):
# If scalar can be losslessly cast to array dtype,
# do the cast so that return array will have same dtype
dt = self.dtype.name
other = self.dtype.type(other)
else:
# If scalar cannot be safely cast, server will infer the return dtype
dt = resolve_scalar_dtype(other)
if dt not in DTypes:
raise TypeError("Unhandled scalar type: {} ({})".format(other,
type(other)))
Expand Down Expand Up @@ -250,7 +257,14 @@ def _r_binop(self, other : pdarray, op : str) -> pdarray:
if op not in self.BinOps:
raise ValueError("bad operator {}".format(op))
# pdarray binop scalar
dt = resolve_scalar_dtype(other)
if np.can_cast(other, self.dtype):
# If scalar can be losslessly cast to array dtype,
# do the cast so that return array will have same dtype
dt = self.dtype.name
other = self.dtype.type(other)
else:
# If scalar cannot be safely cast, server will infer the return dtype
dt = resolve_scalar_dtype(other)
if dt not in DTypes:
raise TypeError("Unhandled scalar type: {} ({})".format(other,
type(other)))
Expand Down

0 comments on commit 899e312

Please sign in to comment.