Skip to content

Commit

Permalink
improve scalar dtype handling
Browse files Browse the repository at this point in the history
  • Loading branch information
reuster986 authored Feb 23, 2022
1 parent 980154c commit 066852e
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions arkouda/pdarrayclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,14 @@ def _binop(self, other : pdarray, op : str) -> pdarray:
repMsg = generic_msg(cmd=cmd,args=args)
return create_pdarray(repMsg)
# pdarray binop scalar
dt = resolve_scalar_dtype(other)
if np.can_cast(other, self.dtype):
# If scalar can be losslessly cast to array dtype,
# do the cast so that return array will have same dtype
dt = self.dtype.name
other = self.dtype.type(other)
else:
# If scalar cannot be safely cast, server will infer the return dtype
dt = resolve_scalar_dtype(other)
if dt not in DTypes:
raise TypeError("Unhandled scalar type: {} ({})".format(other,
type(other)))
Expand Down Expand Up @@ -250,7 +257,14 @@ def _r_binop(self, other : pdarray, op : str) -> pdarray:
if op not in self.BinOps:
raise ValueError("bad operator {}".format(op))
# pdarray binop scalar
dt = resolve_scalar_dtype(other)
if np.can_cast(other, self.dtype):
# If scalar can be losslessly cast to array dtype,
# do the cast so that return array will have same dtype
dt = self.dtype.name
other = self.dtype.type(other)
else:
# If scalar cannot be safely cast, server will infer the return dtype
dt = resolve_scalar_dtype(other)
if dt not in DTypes:
raise TypeError("Unhandled scalar type: {} ({})".format(other,
type(other)))
Expand Down

0 comments on commit 066852e

Please sign in to comment.