diff --git a/arkouda/pdarrayclass.py b/arkouda/pdarrayclass.py index 44eaee9906..9ac9b7b9d0 100755 --- a/arkouda/pdarrayclass.py +++ b/arkouda/pdarrayclass.py @@ -210,7 +210,14 @@ def _binop(self, other : pdarray, op : str) -> pdarray: repMsg = generic_msg(cmd=cmd,args=args) return create_pdarray(repMsg) # pdarray binop scalar - dt = resolve_scalar_dtype(other) + if np.can_cast(other, self.dtype): + # If scalar can be losslessly cast to array dtype, + # do the cast so that return array will have same dtype + dt = self.dtype.name + other = self.dtype.type(other) + else: + # If scalar cannot be safely cast, server will infer the return dtype + dt = resolve_scalar_dtype(other) if dt not in DTypes: raise TypeError("Unhandled scalar type: {} ({})".format(other, type(other))) @@ -250,7 +257,14 @@ def _r_binop(self, other : pdarray, op : str) -> pdarray: if op not in self.BinOps: raise ValueError("bad operator {}".format(op)) # pdarray binop scalar - dt = resolve_scalar_dtype(other) + if np.can_cast(other, self.dtype): + # If scalar can be losslessly cast to array dtype, + # do the cast so that return array will have same dtype + dt = self.dtype.name + other = self.dtype.type(other) + else: + # If scalar cannot be safely cast, server will infer the return dtype + dt = resolve_scalar_dtype(other) if dt not in DTypes: raise TypeError("Unhandled scalar type: {} ({})".format(other, type(other)))