cuda.parallel: In-memory caching of cuda.parallel build objects (#3216)
base: main
Commits: cb0eccc, da652e1, 221af5c, 0eea142, f198200, 56f2c61, c822da7, 58b6f69
@@ -0,0 +1,31 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import functools


def cache_with_key(key):
    """
    Decorator to cache the result of the decorated function. Uses the
    provided `key` function to compute the key for cache lookup. `key`
    receives all arguments passed to the function.
    """

    def deco(func):
        cache = {}

        @functools.wraps(func)
        def inner(*args, **kwargs):
            cache_key = key(*args, **kwargs)
            if cache_key not in cache:
                result = func(*args, **kwargs)
                cache[cache_key] = result
            # `cache_key` *must* be in `cache`, use `.get()`
            # as it is faster:
Reviewer: I was surprised to read that, and chatgpt does not agree. My prompt was: "If obj is a Python dict, is obj.get(key) faster than obj[key]?" I recommend keeping this code straightforward and intuitive.

Author: You're right. I think my memory served me poorly here, and the quick benchmark I did turns out to give inconsistent results. I changed it to just …
            return cache.get(cache_key)

        return inner

    return deco
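As a usage sketch of the decorator above (the `describe` function and its key function are made up for illustration), results are cached per computed key — here, the argument's type name:

```python
import functools


def cache_with_key(key):
    # The decorator from this file, reproduced for a self-contained example.
    def deco(func):
        cache = {}

        @functools.wraps(func)
        def inner(*args, **kwargs):
            cache_key = key(*args, **kwargs)
            if cache_key not in cache:
                cache[cache_key] = func(*args, **kwargs)
            return cache[cache_key]

        return inner

    return deco


calls = []


@cache_with_key(key=lambda x: type(x).__name__)
def describe(x):
    # Record each real invocation so cache hits are visible.
    calls.append(x)
    return f"a value of type {type(x).__name__}"


describe(1)        # miss: computed and cached under key "int"
describe(2)        # hit: same key "int", cached result returned
describe("hello")  # miss: new key "str"
```

Note that `describe(2)` returns the result cached for `describe(1)` — the key function, not the arguments, decides cache identity, which is exactly the property the build cache relies on.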
@@ -0,0 +1,16 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""
Utilities for extracting information from `__cuda_array_interface__`.
"""

import numpy as np

from ..typing import DeviceArrayLike


def get_dtype(arr: DeviceArrayLike) -> np.dtype:
    return np.dtype(arr.__cuda_array_interface__["typestr"])
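A quick sketch of what `get_dtype` reads from the interface dict; `FakeDeviceArray` here is a hypothetical stand-in for any object exposing `__cuda_array_interface__`:

```python
import numpy as np


def get_dtype(arr) -> np.dtype:
    # Same helper as above: read the dtype from the CAI "typestr" field.
    return np.dtype(arr.__cuda_array_interface__["typestr"])


class FakeDeviceArray:
    # Minimal object exposing the protocol; a real device array would
    # also carry "shape", "data", "version", etc.
    def __init__(self, typestr):
        self.__cuda_array_interface__ = {"typestr": typestr}


print(get_dtype(FakeDeviceArray("<f4")))  # little-endian 4-byte float -> float32
print(get_dtype(FakeDeviceArray("<i8")))  # little-endian 8-byte int -> int64
```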
@@ -12,6 +12,10 @@
from .. import _cccl as cccl
from .._bindings import get_paths, get_bindings
from .._caching import cache_with_key
from ..typing import DeviceArrayLike
from ..iterators._iterators import IteratorBase
from .._utils import cai as cai
Reviewer: Oversight? (delete …)

Author: Fixed in 56f2c61.
class _Op:

@@ -41,12 +45,18 @@ def _dtype_validation(dt1, dt2):
class _Reduce:
    # TODO: constructor shouldn't require concrete `d_in`, `d_out`:
    def __init__(self, d_in, d_out, op: Callable, h_init: np.ndarray):
    def __init__(
        self,
        d_in: DeviceArrayLike | IteratorBase,
Author (note to self): Python 3.7 isn't going to like this.

Reviewer: Are you reminding yourself to use …

Author: Instead, I used …
        d_out: DeviceArrayLike,
        op: Callable,
        h_init: np.ndarray,
    ):
        d_in_cccl = cccl.to_cccl_iter(d_in)
        self._ctor_d_in_cccl_type_enum_name = cccl.type_enum_as_name(
            d_in_cccl.value_type.type.value
        )
        self._ctor_d_out_dtype = d_out.dtype
        self._ctor_d_out_dtype = cai.get_dtype(d_out)
        self._ctor_init_dtype = h_init.dtype
        cc_major, cc_minor = cuda.get_current_device().compute_capability
        cub_path, thrust_path, libcudacxx_path, cuda_include_path = get_paths()

@@ -119,9 +129,28 @@ def __del__(self):
        bindings.cccl_device_reduce_cleanup(ctypes.byref(self.build_result))


def make_cache_key(
    d_in: DeviceArrayLike | IteratorBase,
    d_out: DeviceArrayLike,
    op: Callable,
    h_init: np.ndarray,
):
    d_in_key = d_in if isinstance(d_in, IteratorBase) else cai.get_dtype(d_in)
    d_out_key = d_out if isinstance(d_out, IteratorBase) else cai.get_dtype(d_out)
    op_key = (op.__code__.co_code, op.__code__.co_consts, op.__closure__)
Reviewer: The … It'll not be great if this code is copy-pasted as we add more algorithms. One idea would be to introduce helper functions, but that would only be slightly better. I wonder if we could do much better, right in your decorator. You could loop over …

Author: Ah, nice observation. I changed it to just … We could, but I'd prefer separating the concerns here, at the risk of a tiny bit of logic repetition across usages of …
    h_init_key = h_init.dtype
    return (d_in_key, d_out_key, op_key, h_init_key)
Reviewer (important): there's an implicit state affecting the key. As written, I'd get the same reducer object for devices of different architectures:

    cudaSetDevice(0)
    reducer_1 = reduce_into(d_in, d_out, ...)
    cudaSetDevice(1)
    reducer_2 = reduce_into(d_in, d_out, ...)

Let's incorporate …

Author: Good call. I did that in 56f2c61.
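The hazard described in this comment can be sketched in plain Python, with no CUDA required. Here `current_cc` is a hypothetical stand-in for `cuda.get_current_device().compute_capability`; folding it into the cache key keeps builds for different architectures apart:

```python
builds = {}          # cache: key -> build object
current_cc = (8, 0)  # stand-in for the active device's compute capability


def make_cache_key(dtype_name):
    # Folding the implicit device state into the key means a device
    # switch produces a different key, not a stale cache hit.
    return (dtype_name, current_cc)


def build(dtype_name):
    key = make_cache_key(dtype_name)
    if key not in builds:
        builds[key] = f"build<{dtype_name}, sm_{current_cc[0]}{current_cc[1]}>"
    return builds[key]


a = build("int32")   # built for sm_80
current_cc = (9, 0)  # "switch" devices
b = build("int32")   # distinct key -> a separate build, not the sm_80 one
```

Without `current_cc` in the key, the second call would silently return the sm_80 build object on the sm_90 device — exactly the bug the reviewer is pointing at.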
# TODO Figure out `sum` without operator and initial value
# TODO Accept stream
def reduce_into(d_in, d_out, op: Callable, h_init: np.ndarray):
@cache_with_key(make_cache_key)
def reduce_into(
    d_in: DeviceArrayLike | IteratorBase,
    d_out: DeviceArrayLike,
    op: Callable,
    h_init: np.ndarray,
):
    """Computes a device-wide reduction using the specified binary ``op`` functor and initial value ``init``.

    Example:
@@ -6,6 +6,7 @@
from llvmlite import ir
from numba.core.extending import intrinsic, overload
from numba.core.typing.ctypes_utils import to_ctypes
from numba.cuda.dispatcher import CUDADispatcher
from numba import cuda, types
import numba
import numpy as np

@@ -15,6 +16,19 @@
_DEVICE_POINTER_BITWIDTH = _DEVICE_POINTER_SIZE * 8
def _compare_funcs(func1, func2):
    # return True if the functions compare equal for
    # caching purposes, False otherwise
    code1 = func1.__code__
    code2 = func2.__code__

    return (
        code1.co_code == code2.co_code
        and code1.co_consts == code2.co_consts
        and func1.__closure__ == func2.__closure__
    )
@lru_cache(maxsize=256)  # TODO: what's a reasonable value?
def cached_compile(func, sig, abi_name=None, **kwargs):
    return cuda.compile(func, sig, abi_info={"abi_name": abi_name}, **kwargs)

@@ -24,12 +38,10 @@ class IteratorBase:
""" | ||
An Iterator is a wrapper around a pointer, and must define the following: | ||
|
||
- a `state` property that returns a `ctypes.c_void_p` object, representing | ||
a pointer to some data. | ||
- an `advance` (static) method that receives the state pointer and performs | ||
- an `advance` (static) method that receives the pointer and performs | ||
an action that advances the pointer by the offset `distance` | ||
(returns nothing). | ||
- a `dereference` (static) method that dereferences the state pointer | ||
- a `dereference` (static) method that dereferences the pointer | ||
and returns a value. | ||
|
||
Iterators are not meant to be used directly. They are constructed and passed | ||
|
@@ -38,18 +50,28 @@ class IteratorBase: | |
The `advance` and `dereference` must be compilable to device code by numba. | ||
""" | ||
|
    def __init__(self, numba_type: types.Type, value_type: types.Type, abi_name: str):
    def __init__(
        self,
        cvalue: ctypes.c_void_p,
        numba_type: types.Type,
        value_type: types.Type,
        abi_name: str,
    ):
        """
        Parameters
        ----------
        cvalue
            A ctypes type representing the object pointed to by the iterator.
        numba_type
            A numba type that specifies how to interpret the state pointer.
            A numba type representing the type of the input to the advance
            and dereference functions.
        value_type
            The numba type of the value returned by the dereference operation.
        abi_name
            A unique identifier that will determine the abi_names for the
            advance and dereference operations.
        """
        self.cvalue = cvalue
        self.numba_type = numba_type
        self.value_type = value_type
        self.abi_name = abi_name

@@ -81,7 +103,7 @@ def ltoirs(self) -> Dict[str, bytes]:
    @property
    def state(self) -> ctypes.c_void_p:
        raise NotImplementedError("Subclasses must override advance staticmethod")
        return ctypes.cast(ctypes.pointer(self.cvalue), ctypes.c_void_p)

    @staticmethod
    def advance(state, distance):

@@ -91,6 +113,21 @@ def advance(state, distance):
    def dereference(state):
        raise NotImplementedError("Subclasses must override dereference staticmethod")
    def __hash__(self):
        return hash(
            (self.cvalue.value, self.numba_type, self.value_type, self.abi_name)
        )

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (
            self.cvalue.value == other.cvalue.value
            and self.numba_type == other.numba_type
            and self.value_type == other.value_type
            and self.abi_name == other.abi_name
        )
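The `__hash__`/`__eq__` pair above exists so iterators can serve as dictionary keys in the build cache. A simplified sketch of that contract, using a hypothetical `SimpleIt` class in place of `IteratorBase`:

```python
class SimpleIt:
    # Mimics the iterator identity scheme: field-equal objects are equal
    # and hash equally, so dict lookups find previously cached builds.
    def __init__(self, value, abi_name):
        self.value = value
        self.abi_name = abi_name

    def __hash__(self):
        return hash((self.value, self.abi_name))

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self.value == other.value and self.abi_name == other.abi_name


cache = {SimpleIt(0, "it_int32"): "cached build"}

# A *distinct* but field-equal iterator object hits the same cache entry:
hit = cache[SimpleIt(0, "it_int32")]
```

Without both methods overridden, each freshly constructed iterator would hash by identity and every lookup would miss, defeating the cache.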
def sizeof_pointee(context, ptr):
    size = context.get_abi_sizeof(ptr.type.pointee)

@@ -125,10 +162,11 @@ def impl(ptr, offset):
class RawPointer(IteratorBase):
    def __init__(self, ptr: int, ntype: types.Type):
Reviewer: Maybe rename ntype to …

Author: Done in 56f2c61.
        value_type = ntype
        self._cvalue = ctypes.c_void_p(ptr)
        cvalue = ctypes.c_void_p(ptr)
        numba_type = types.CPointer(types.CPointer(value_type))
        abi_name = f"{self.__class__.__name__}_{str(value_type)}"
        super().__init__(
            cvalue=cvalue,
            numba_type=numba_type,
            value_type=value_type,
            abi_name=abi_name,

@@ -142,10 +180,6 @@ def advance(state, distance):
    def dereference(state):
        return state[0][0]

    @property
    def state(self) -> ctypes.c_void_p:
        return ctypes.cast(ctypes.pointer(self._cvalue), ctypes.c_void_p)


def pointer(container, ntype: types.Type) -> RawPointer:
    return RawPointer(container.__cuda_array_interface__["data"][0], ntype)

@@ -174,11 +208,12 @@ def codegen(context, builder, sig, args):
class CacheModifiedPointer(IteratorBase):
    def __init__(self, ptr: int, ntype: types.Type):
        self._cvalue = ctypes.c_void_p(ptr)
        cvalue = ctypes.c_void_p(ptr)
        value_type = ntype
        numba_type = types.CPointer(types.CPointer(value_type))
        abi_name = f"{self.__class__.__name__}_{str(value_type)}"
        super().__init__(
            cvalue=cvalue,
            numba_type=numba_type,
            value_type=value_type,
            abi_name=abi_name,

@@ -192,18 +227,15 @@ def advance(state, distance):
    def dereference(state):
        return load_cs(state[0])

    @property
    def state(self) -> ctypes.c_void_p:
        return ctypes.cast(ctypes.pointer(self._cvalue), ctypes.c_void_p)


class ConstantIterator(IteratorBase):
    def __init__(self, value: np.number):
        value_type = numba.from_dtype(value.dtype)
        self._cvalue = to_ctypes(value_type)(value)
        cvalue = to_ctypes(value_type)(value)
        numba_type = types.CPointer(value_type)
        abi_name = f"{self.__class__.__name__}_{str(value_type)}"
        super().__init__(
            cvalue=cvalue,
            numba_type=numba_type,
            value_type=value_type,
            abi_name=abi_name,

@@ -217,18 +249,15 @@ def advance(state, distance):
    def dereference(state):
        return state[0]

    @property
    def state(self) -> ctypes.c_void_p:
        return ctypes.cast(ctypes.pointer(self._cvalue), ctypes.c_void_p)


class CountingIterator(IteratorBase):
    def __init__(self, value: np.number):
        value_type = numba.from_dtype(value.dtype)
        self._cvalue = to_ctypes(value_type)(value)
        cvalue = to_ctypes(value_type)(value)
        numba_type = types.CPointer(value_type)
        abi_name = f"{self.__class__.__name__}_{str(value_type)}"
        super().__init__(
            cvalue=cvalue,
            numba_type=numba_type,
            value_type=value_type,
            abi_name=abi_name,

@@ -242,10 +271,6 @@ def advance(state, distance):
    def dereference(state):
        return state[0]

    @property
    def state(self) -> ctypes.c_void_p:
        return ctypes.cast(ctypes.pointer(self._cvalue), ctypes.c_void_p)


def make_transform_iterator(it, op: Callable):
    if hasattr(it, "__cuda_array_interface__"):

@@ -256,8 +281,9 @@ def make_transform_iterator(it, op: Callable):
    op = cuda.jit(op, device=True)
    class TransformIterator(IteratorBase):
        def __init__(self, it: IteratorBase, op):
        def __init__(self, it: IteratorBase, op: CUDADispatcher):
            self._it = it
            self._op = op
            numba_type = it.numba_type
            # TODO: the abi name below isn't unique enough when we have e.g.,
            # two identically named `op` functions with different

@@ -276,6 +302,7 @@ def __init__(self, it: IteratorBase, op):
            value_type = op_retty
            abi_name = f"{self.__class__.__name__}_{it.abi_name}_{op_abi_name}"
            super().__init__(
                cvalue=it.cvalue,
                numba_type=numba_type,
                value_type=value_type,
                abi_name=abi_name,

@@ -289,8 +316,20 @@ def advance(state, distance):
        def dereference(state):
            return op(it_dereference(state))
        @property
        def state(self) -> ctypes.c_void_p:
            return it.state

        def __hash__(self):
            return hash(
                (
                    self._it,
                    self._op.py_func.__code__.co_code,
                    self._op.py_func.__closure__,
                )
            )

        def __eq__(self, other):
            if not isinstance(other, IteratorBase):
                return NotImplemented
            return self._it == other._it and _compare_funcs(
                self._op.py_func, other._op.py_func
            )

    return TransformIterator(it, op)
@@ -0,0 +1,10 @@
from typing import Protocol


class DeviceArrayLike(Protocol):
Reviewer: … would be much more expressive.

Author: …
    """
    Objects representing a device array, having a `.__cuda_array_interface__`
    attribute.
    """

    __cuda_array_interface__: dict
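For illustration, a sketch of how a structural protocol like this behaves at runtime when additionally marked `runtime_checkable` (the `MyArray` class is hypothetical; the PR's protocol is not necessarily runtime-checkable):

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class DeviceArrayLike(Protocol):
    """
    Objects representing a device array, having a `.__cuda_array_interface__`
    attribute.
    """

    __cuda_array_interface__: dict


class MyArray:
    # Any class exposing the attribute satisfies the protocol structurally;
    # no inheritance from DeviceArrayLike is needed.
    def __init__(self):
        self.__cuda_array_interface__ = {"typestr": "<f4", "shape": (4,)}


print(isinstance(MyArray(), DeviceArrayLike))
print(isinstance(object(), DeviceArrayLike))
```

Structural typing is what lets `reduce_into` accept CuPy arrays, Numba device arrays, or any third-party container implementing the CUDA Array Interface without importing those libraries.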
Reviewer: `key_factory`?

Author: I went with `key` as it's the name for the similar argument in e.g., `sorted` and `cachetools`.