-
Notifications
You must be signed in to change notification settings - Fork 170
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
cuda.parallel: In-memory caching of cuda.parallel
build objects
#3216
base: main
Are you sure you want to change the base?
Changes from all commits
cb0eccc
da652e1
221af5c
0eea142
f198200
56f2c61
c822da7
58b6f69
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. | ||
# | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
import functools | ||
from numba import cuda | ||
|
||
|
||
def cache_with_key(key): | ||
""" | ||
Decorator to cache the result of the decorated function. Uses the | ||
provided `key` function to compute the key for cache lookup. `key` | ||
receives all arguments passed to the function. | ||
|
||
Notes | ||
----- | ||
The CUDA compute capability of the current device is appended to | ||
the cache key returned by `key`. | ||
""" | ||
|
||
def deco(func): | ||
cache = {} | ||
|
||
@functools.wraps(func) | ||
def inner(*args, **kwargs): | ||
cc = cuda.get_current_device().compute_capability | ||
cache_key = (key(*args, **kwargs), *cc) | ||
if cache_key not in cache: | ||
result = func(*args, **kwargs) | ||
cache[cache_key] = result | ||
return cache[cache_key] | ||
|
||
return inner | ||
|
||
return deco | ||
|
||
|
||
class CachableFunction: | ||
""" | ||
A type that wraps a function and provides custom comparison | ||
(__eq__) and hash (__hash__) implementations. | ||
|
||
The purpose of this class is to enable caching and comparison of | ||
functions based on their bytecode, constants, and closures, while | ||
ignoring other attributes such as their names or docstrings. | ||
""" | ||
|
||
def __init__(self, func): | ||
self._func = func | ||
self._identity = ( | ||
self._func.__code__.co_code, | ||
self._func.__code__.co_consts, | ||
self._func.__closure__, | ||
) | ||
|
||
def __eq__(self, other): | ||
return self._identity == other._identity | ||
|
||
def __hash__(self): | ||
return hash(self._identity) | ||
|
||
def __repr__(self): | ||
return str(self._func) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. | ||
# | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
""" | ||
Utilities for extracting information from `__cuda_array_interface__`. | ||
""" | ||
|
||
import numpy as np | ||
|
||
from ..typing import DeviceArrayLike | ||
|
||
|
||
def get_dtype(arr: DeviceArrayLike) -> np.dtype: | ||
return np.dtype(arr.__cuda_array_interface__["typestr"]) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,8 @@ | |
# | ||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
from __future__ import annotations # TODO: required for Python 3.7 docs env | ||
|
||
import ctypes | ||
import numba | ||
import numpy as np | ||
|
@@ -12,6 +14,10 @@ | |
|
||
from .. import _cccl as cccl | ||
from .._bindings import get_paths, get_bindings | ||
from .._caching import CachableFunction, cache_with_key | ||
from ..typing import DeviceArrayLike | ||
from ..iterators._iterators import IteratorBase | ||
from .._utils import cai | ||
|
||
|
||
class _Op: | ||
|
@@ -41,12 +47,18 @@ def _dtype_validation(dt1, dt2): | |
|
||
class _Reduce: | ||
# TODO: constructor shouldn't require concrete `d_in`, `d_out`: | ||
def __init__(self, d_in, d_out, op: Callable, h_init: np.ndarray): | ||
def __init__( | ||
self, | ||
d_in: DeviceArrayLike | IteratorBase, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note to self: Python 3.7 isn't going to like this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you reminding yourself to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead, I used |
||
d_out: DeviceArrayLike, | ||
op: Callable, | ||
h_init: np.ndarray, | ||
): | ||
d_in_cccl = cccl.to_cccl_iter(d_in) | ||
self._ctor_d_in_cccl_type_enum_name = cccl.type_enum_as_name( | ||
d_in_cccl.value_type.type.value | ||
) | ||
self._ctor_d_out_dtype = d_out.dtype | ||
self._ctor_d_out_dtype = cai.get_dtype(d_out) | ||
self._ctor_init_dtype = h_init.dtype | ||
cc_major, cc_minor = cuda.get_current_device().compute_capability | ||
cub_path, thrust_path, libcudacxx_path, cuda_include_path = get_paths() | ||
|
@@ -119,9 +131,28 @@ def __del__(self): | |
bindings.cccl_device_reduce_cleanup(ctypes.byref(self.build_result)) | ||
|
||
|
||
def make_cache_key( | ||
d_in: DeviceArrayLike | IteratorBase, | ||
d_out: DeviceArrayLike, | ||
op: Callable, | ||
h_init: np.ndarray, | ||
): | ||
d_in_key = d_in.kind if isinstance(d_in, IteratorBase) else cai.get_dtype(d_in) | ||
d_out_key = cai.get_dtype(d_out) | ||
op_key = CachableFunction(op) | ||
h_init_key = h_init.dtype | ||
return (d_in_key, d_out_key, op_key, h_init_key) | ||
|
||
|
||
# TODO Figure out `sum` without operator and initial value | ||
# TODO Accept stream | ||
def reduce_into(d_in, d_out, op: Callable, h_init: np.ndarray): | ||
@cache_with_key(make_cache_key) | ||
def reduce_into( | ||
d_in: DeviceArrayLike | IteratorBase, | ||
d_out: DeviceArrayLike, | ||
op: Callable, | ||
h_init: np.ndarray, | ||
): | ||
"""Computes a device-wide reduction using the specified binary ``op`` functor and initial value ``init``. | ||
|
||
Example: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
key_factory
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I went with
key
as it's the name for the similar argument in e.g.,sorted
andcachetools
.