Skip to content

Commit

Permalink
Closes #3917: copy function to match numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
ajpotts committed Dec 4, 2024
1 parent cc1797a commit 2297a4c
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 11 deletions.
17 changes: 17 additions & 0 deletions arkouda/pdarrayclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,23 @@ def max_bits(self, max_bits):
generic_msg(cmd="set_max_bits", args={"array": self, "max_bits": max_bits})
self._max_bits = max_bits

def copy(self) -> pdarray:
"""
Return an array copy of the given object.
Returns
-------
pdarray
A deep copy of the pdarray.
"""
from arkouda.pdarraycreation import array

ret = array(self, copy=True)
if isinstance(ret, pdarray):
return ret
else:
raise RuntimeError("Could not copy pdarray.")

def equals(self, other) -> bool_scalars:
"""
Whether pdarrays are the same size and all entries are equal.
Expand Down
24 changes: 22 additions & 2 deletions arkouda/pdarraycreation.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,9 @@ def from_series(series: pd.Series, dtype: Optional[Union[type, str]] = None) ->


def array(
a: Union[pdarray, np.ndarray, Iterable],
a: Union[pdarray, np.ndarray, Iterable, Strings],
dtype: Union[np.dtype, type, str, None] = None,
copy: bool = True,
max_bits: int = -1,
) -> Union[pdarray, Strings]:
"""
Expand All @@ -153,6 +154,10 @@ def array(
Rank-1 array of a supported dtype
dtype: np.dtype, type, or str
The target dtype to cast values to
copy: bool=True, optional
If True (default), then the array data is copied.
Note that any copy of the data is deep, which differs from numpy.
For False it raises a ValueError if a copy cannot be avoided. Default: True.
max_bits: int
Specifies the maximum number of bits; only used for bigint pdarrays
Expand All @@ -167,6 +172,8 @@ def array(
TypeError
Raised if a is not a pdarray, np.ndarray, or Python Iterable such as a
list, array, tuple, or deque
Raised if a Strings is called with dtype other than ak.str_
RuntimeError
Raised if a is not one-dimensional, nbytes > maxTransferBytes, a.dtype is
not supported (not in DTypes), or if the product of a size and
Expand Down Expand Up @@ -206,9 +213,22 @@ def array(
"""
from arkouda.numpy import cast as akcast

if copy is False:
if isinstance(a, (Strings, pdarray)):
return a
else:
raise ValueError(
"In ak.array, copy=False can only used with applied to pdarray objects."
)

if isinstance(a, Strings):
if dtype and dtype != "str_":
raise TypeError(f"Cannot cast Strings to dtype {dtype} in ak.array")
return a[:]

# If a is already a pdarray, do nothing
if isinstance(a, pdarray):
casted = a if dtype is None else akcast(a, dtype)
casted = a[:] if dtype is None else akcast(a, dtype)
if dtype == bigint and max_bits != -1:
casted.max_bits = max_bits
return casted
Expand Down
46 changes: 38 additions & 8 deletions arkouda/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import builtins
import json
from typing import TYPE_CHECKING, Sequence, Tuple, Union, cast
from typing import TYPE_CHECKING, Iterable, Sequence, Tuple, TypeVar, Union, cast
from warnings import warn

import numpy as np
from typeguard import typechecked

from arkouda.categorical import Categorical
from arkouda.client import generic_msg, get_config, get_mem_used
from arkouda.client_dtypes import BitVector, BitVectorizer, IPv4
from arkouda.groupbyclass import GroupBy, broadcast
Expand All @@ -24,12 +24,17 @@
from arkouda.pdarraysetops import unique
from arkouda.segarray import SegArray
from arkouda.sorting import coargsort
from arkouda.strings import Strings
from arkouda.timeclass import Datetime, Timedelta

if TYPE_CHECKING:
from arkouda.index import Index
from arkouda.series import Series
from arkouda.strings import Strings
from arkouda.categorical import Categorical
else:
Strings = TypeVar("Strings")
Series = TypeVar("Series")
Categorical = TypeVar("Categorical")


def identity(x):
Expand Down Expand Up @@ -192,6 +197,7 @@ def attach(name: str):
from arkouda.index import Index, MultiIndex
from arkouda.pdarrayclass import pdarray
from arkouda.series import Series
from arkouda.categorical import Categorical

rep_msg = json.loads(cast(str, generic_msg(cmd="attach", args={"name": name})))
rtn_obj = None
Expand Down Expand Up @@ -425,7 +431,7 @@ def convert_bytes(nbytes, unit="B"):


def is_numeric(
arry: Union[pdarray, Strings, Categorical, "Series", "Index"] # noqa: F821
arry: Union[pdarray, Strings, "Categorical", "Series", "Index"] # noqa: F821
) -> builtins.bool:
"""
Check if the dtype of the given array is numeric.
Expand Down Expand Up @@ -460,7 +466,7 @@ def is_numeric(
return False


def is_float(arry: Union[pdarray, Strings, Categorical, "Series", "Index"]): # noqa: F821
def is_float(arry: Union[pdarray, Strings, "Categorical", "Series", "Index"]): # noqa: F821
"""
Check if the dtype of the given array is float.
Expand Down Expand Up @@ -494,7 +500,7 @@ def is_float(arry: Union[pdarray, Strings, Categorical, "Series", "Index"]): #
return False


def is_int(arry: Union[pdarray, Strings, Categorical, "Series", "Index"]): # noqa: F821
def is_int(arry: Union[pdarray, Strings, "Categorical", "Series", "Index"]): # noqa: F821
"""
Check if the dtype of the given array is int.
Expand Down Expand Up @@ -529,9 +535,10 @@ def is_int(arry: Union[pdarray, Strings, Categorical, "Series", "Index"]): # no
return False


@typechecked
def map(
values: Union[pdarray, Strings, Categorical], mapping: Union[dict, "Series"]
) -> Union[pdarray, Strings]:
values: Union[pdarray, "Strings", "Categorical"], mapping: Union[dict, "Series"]
) -> Union[pdarray, "Strings"]:
"""
Map values of an array according to an input mapping.
Expand Down Expand Up @@ -573,7 +580,9 @@ def map(
import numpy as np

from arkouda import Series, array, broadcast, full
from arkouda.categorical import Categorical
from arkouda.pdarraysetops import in1d
from arkouda.strings import Strings

keys = values
gb = GroupBy(keys, dropna=False)
Expand Down Expand Up @@ -623,3 +632,24 @@ def _infer_shape_from_size(size):
shape = full_size
ndim = 1
return shape, ndim, full_size


@typechecked
def copy(a: Union[pdarray, np.ndarray, Iterable, "Strings"]) -> Union[pdarray, "Strings"]:
"""
Return an array copy of the given object.
Returns
-------
pdarray
Array interpretation of a.
"""
from arkouda.strings import Strings

if isinstance(a, Strings):
cpy = a[:]
return cpy

from arkouda.pdarraycreation import array

return array(a, copy=True)
26 changes: 25 additions & 1 deletion tests/pdarray_creation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import pytest

import arkouda as ak
from arkouda.testing import assert_arkouda_array_equal, assert_equivalent
from arkouda.testing import assert_arkouda_array_equal
from arkouda.testing import assert_equal as ak_assert_equal
from arkouda.testing import assert_equivalent

INT_SCALARS = list(ak.dtypes.int_scalars.__args__)
NUMERIC_SCALARS = list(ak.dtypes.numeric_scalars.__args__)
Expand Down Expand Up @@ -43,6 +45,15 @@ def test_array_creation(self, dtype):
assert len(pda) == fixed_size
assert dtype == pda.dtype

def test_array_creation_strings(self):
fixed_size = 100
pda = ak.array(ak.arange(fixed_size, dtype=ak.str_))
assert isinstance(pda, ak.Strings)
assert len(pda) == fixed_size

with pytest.raises(TypeError):
ak.array(ak.arange(fixed_size, dtype=ak.str_), dtype=ak.int64),

@pytest.mark.skip_if_max_rank_less_than(3)
@pytest.mark.parametrize("size", pytest.prob_size)
@pytest.mark.parametrize("dtype", [int, ak.int64, ak.uint64, float, ak.float64, bool, ak.bool_])
Expand Down Expand Up @@ -105,6 +116,19 @@ def test_array_creation_misc(self):
with pytest.raises(TypeError):
ak.array(list(list(0)))

@pytest.mark.parametrize("dtype", [ak.int64, ak.float64, ak.bool_, ak.bigint])
def test_array_copy(self, dtype):

a = ak.arange(100, dtype=dtype)

b = ak.array(a, copy=True)
assert not a is b
ak_assert_equal(a, b)

c = ak.array(a, copy=False)
assert a is c
ak_assert_equal(a, c)

@pytest.mark.skip_if_max_rank_less_than(2)
def test_array_creation_transpose_bug_reproducer(self):

Expand Down
11 changes: 11 additions & 0 deletions tests/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import arkouda as ak
from arkouda.util import is_float, is_int, is_numeric, map
import pytest


class TestUtil:
Expand Down Expand Up @@ -119,3 +120,13 @@ def test_map(self):

result = map(d, {"1": 7.0})
assert np.allclose(result.to_list(), [7.0, 7.0, np.nan, np.nan, np.nan], equal_nan=True)

@pytest.mark.parametrize("dtype", [ak.int64, ak.float64, ak.bool_, ak.bigint, ak.str_])
def test_copy(self, dtype):
a = ak.arange(10, dtype=dtype)
b = ak.util.copy(a)

from arkouda import assert_equal as ak_assert_equal

assert not a is b
ak_assert_equal(a, b)

0 comments on commit 2297a4c

Please sign in to comment.