From 066852ec72adef705e08404b48e44b76c97f0544 Mon Sep 17 00:00:00 2001 From: reuster986 Date: Wed, 23 Feb 2022 15:13:39 -0500 Subject: [PATCH] improve scalar dtype handling --- arkouda/pdarrayclass.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/arkouda/pdarrayclass.py b/arkouda/pdarrayclass.py index 44eaee9906..9ac9b7b9d0 100755 --- a/arkouda/pdarrayclass.py +++ b/arkouda/pdarrayclass.py @@ -210,7 +210,14 @@ def _binop(self, other : pdarray, op : str) -> pdarray: repMsg = generic_msg(cmd=cmd,args=args) return create_pdarray(repMsg) # pdarray binop scalar - dt = resolve_scalar_dtype(other) + if np.can_cast(other, self.dtype): + # If scalar can be losslessly cast to array dtype, + # do the cast so that return array will have same dtype + dt = self.dtype.name + other = self.dtype.type(other) + else: + # If scalar cannot be safely cast, server will infer the return dtype + dt = resolve_scalar_dtype(other) if dt not in DTypes: raise TypeError("Unhandled scalar type: {} ({})".format(other, type(other))) @@ -250,7 +257,14 @@ def _r_binop(self, other : pdarray, op : str) -> pdarray: if op not in self.BinOps: raise ValueError("bad operator {}".format(op)) # pdarray binop scalar - dt = resolve_scalar_dtype(other) + if np.can_cast(other, self.dtype): + # If scalar can be losslessly cast to array dtype, + # do the cast so that return array will have same dtype + dt = self.dtype.name + other = self.dtype.type(other) + else: + # If scalar cannot be safely cast, server will infer the return dtype + dt = resolve_scalar_dtype(other) if dt not in DTypes: raise TypeError("Unhandled scalar type: {} ({})".format(other, type(other)))