Skip to content

Commit

Permalink
Use the .kind to generate an abi_name
Browse files Browse the repository at this point in the history
  • Loading branch information
shwina committed Dec 24, 2024
1 parent c822da7 commit 86c58fb
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 55 deletions.
33 changes: 14 additions & 19 deletions python/cuda_parallel/cuda/parallel/experimental/_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,29 +48,24 @@ class CachableFunction:

def __init__(self, func):
self._func = func
self._identity = None

@property
def identity(self):
if self._identity is not None:
return self._identity
self._identity = (
self._func.__code__.co_code,
self._func.__code__.co_consts,
self._func.__closure__,
)
return self._identity

def __eq__(self, other):
func1, func2 = self._func, other._func

# return True if the functions compare equal for
# caching purposes, False otherwise
code1 = func1.__code__
code2 = func2.__code__

return (
code1.co_code == code2.co_code
and code1.co_consts == code2.co_consts
and func1.__closure__ == func2.__closure__
)
return self.identity == other.identity

def __hash__(self):
return hash(
(
self._func.__code__.co_code,
self._func.__code__.co_consts,
self._func.__closure__,
)
)
return hash(self.identity)

def __repr__(self):
return str(self._func)
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ctypes
import operator
import uuid
from functools import lru_cache
from typing import Dict, Callable

Expand All @@ -18,6 +19,14 @@
_DEVICE_POINTER_BITWIDTH = _DEVICE_POINTER_SIZE * 8


@lru_cache(maxsize=None)
def _get_abi_suffix(kind: "IteratorKind"):
# given an IteratorKind, return a UUID. The value
# is cached so that the same UUID is always returned
# for a given IteratorKind.
return uuid.uuid4().hex


@lru_cache(maxsize=256) # TODO: what's a reasonable value?
def cached_compile(func, sig, abi_name=None, **kwargs):
return cuda.compile(func, sig, abi_info={"abi_name": abi_name}, **kwargs)
Expand Down Expand Up @@ -60,7 +69,6 @@ def __init__(
cvalue: ctypes.c_void_p,
numba_type: types.Type,
value_type: types.Type,
abi_name: str,
):
"""
Parameters
Expand All @@ -72,14 +80,10 @@ def __init__(
and dereference functions.
value_type
The numba type of the value returned by the dereference operation.
abi_name
A unique identifier that will determine the abi_names for the
advance and dereference operations.
"""
self.cvalue = cvalue
self.numba_type = numba_type
self.value_type = value_type
self.abi_name = abi_name

@property
def kind(self):
Expand All @@ -90,8 +94,8 @@ def kind(self):
# needed.
@property
def ltoirs(self) -> Dict[str, bytes]:
advance_abi_name = self.abi_name + "_advance"
deref_abi_name = self.abi_name + "_dereference"
advance_abi_name = "advance_" + _get_abi_suffix(self.kind)
deref_abi_name = "dereference_" + _get_abi_suffix(self.kind)
advance_ltoir, _ = cached_compile(
self.__class__.advance,
(
Expand Down Expand Up @@ -123,18 +127,16 @@ def dereference(state):
raise NotImplementedError("Subclasses must override dereference staticmethod")

def __hash__(self):
return hash(
(self.cvalue.value, self.numba_type, self.value_type, self.abi_name)
)
return hash((self.kind, self.cvalue.value, self.numba_type, self.value_type))

def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (
self.cvalue.value == other.cvalue.value
self.kind == other.kind
and self.cvalue.value == other.cvalue.value
and self.numba_type == other.numba_type
and self.value_type == other.value_type
and self.abi_name == other.abi_name
)


Expand Down Expand Up @@ -178,12 +180,10 @@ class RawPointer(IteratorBase):
def __init__(self, ptr: int, value_type: types.Type):
cvalue = ctypes.c_void_p(ptr)
numba_type = types.CPointer(types.CPointer(value_type))
abi_name = f"{self.__class__.__name__}_{str(value_type)}"
super().__init__(
cvalue=cvalue,
numba_type=numba_type,
value_type=value_type,
abi_name=abi_name,
)

@staticmethod
Expand Down Expand Up @@ -231,12 +231,10 @@ def __init__(self, ptr: int, ntype: types.Type):
cvalue = ctypes.c_void_p(ptr)
value_type = ntype
numba_type = types.CPointer(types.CPointer(value_type))
abi_name = f"{self.__class__.__name__}_{str(value_type)}"
super().__init__(
cvalue=cvalue,
numba_type=numba_type,
value_type=value_type,
abi_name=abi_name,
)

@staticmethod
Expand All @@ -259,12 +257,10 @@ def __init__(self, value: np.number):
value_type = numba.from_dtype(value.dtype)
cvalue = to_ctypes(value_type)(value)
numba_type = types.CPointer(value_type)
abi_name = f"{self.__class__.__name__}_{str(value_type)}"
super().__init__(
cvalue=cvalue,
numba_type=numba_type,
value_type=value_type,
abi_name=abi_name,
)

@staticmethod
Expand All @@ -287,12 +283,10 @@ def __init__(self, value: np.number):
value_type = numba.from_dtype(value.dtype)
cvalue = to_ctypes(value_type)(value)
numba_type = types.CPointer(value_type)
abi_name = f"{self.__class__.__name__}_{str(value_type)}"
super().__init__(
cvalue=cvalue,
numba_type=numba_type,
value_type=value_type,
abi_name=abi_name,
)

@staticmethod
Expand Down Expand Up @@ -327,27 +321,20 @@ def __init__(self, it: IteratorBase, op: CUDADispatcher):
self._it = it
self._op = CachableFunction(op.py_func)
numba_type = it.numba_type
# TODO: the abi name below isn't unique enough when we have e.g.,
# two identically named `op` functions with different
# signatures, bytecodes, and/or closure variables.
op_abi_name = f"{self.__class__.__name__}_{op.py_func.__name__}"

# TODO: it would be nice to not need to compile `op` to get
# its return type, but there's nothing in the numba API
# to do that (yet),
_, op_retty = cached_compile(
op,
(self._it.value_type,),
abi_name=op_abi_name,
abi_name=f"{op.__name__}_{_get_abi_suffix(self.kind)}",
output="ltoir",
)
value_type = op_retty
abi_name = f"{self.__class__.__name__}_{it.abi_name}_{op_abi_name}"
super().__init__(
cvalue=it.cvalue,
numba_type=numba_type,
value_type=value_type,
abi_name=abi_name,
)

@property
Expand All @@ -363,16 +350,10 @@ def dereference(state):
return op(it_dereference(state))

def __hash__(self):
return hash(
(
self._it,
self._op._func.py_func.__code__.co_code,
self._op._func.py_func.__closure__,
)
)
return hash((self._it, self._op.identity))

def __eq__(self, other):
if not isinstance(other, IteratorBase):
if not isinstance(other.kind, TransformIteratorKind):
return NotImplemented
return self._it == other._it and self._op == other._op

Expand Down
2 changes: 2 additions & 0 deletions python/cuda_parallel/tests/test_iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def test_cache_modified_input_iterator_equality():

assert it1 == it2
assert it1 != it3

assert it1.kind == it2.kind == it3.kind
assert it1.kind != it4.kind

Expand All @@ -71,6 +72,7 @@ def op3(x):
assert it1 == it2
assert it1 != it3
assert it1 == it4

assert it1.kind == it2.kind == it4.kind

ary1 = cp.asarray([0, 1, 2])
Expand Down

0 comments on commit 86c58fb

Please sign in to comment.