
cuda.parallel: In-memory caching of cuda.parallel build objects #3216

Status: Open. Wants to merge 8 commits into base: main. Showing changes from 4 commits.
31 changes: 31 additions & 0 deletions python/cuda_parallel/cuda/parallel/experimental/_caching.py
@@ -0,0 +1,31 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import functools


def cache_with_key(key):
[Comment — Contributor] key_factory?
[Reply — Author] I went with key because it's the name of the similar argument in, e.g., sorted and cachetools.

"""
Decorator to cache the result of the decorated function. Uses the
provided `key` function to compute the key for cache lookup. `key`
receives all arguments passed to the function.
"""

def deco(func):
cache = {}

@functools.wraps(func)
def inner(*args, **kwargs):
cache_key = key(*args, **kwargs)
if cache_key not in cache:
result = func(*args, **kwargs)
cache[cache_key] = result
# `cache_key` *must* be in `cache`, use `.get()`
# as it is faster:
[Comment — Contributor] I was surprised to read that, and ChatGPT does not agree. My prompt was: "If obj is a Python dict, is obj.get(key) faster than obj[key]?" I recommend keeping this code straightforward and intuitive: return cache[cache_key]
[Reply — Author] You're right; my memory served me poorly here, and my quick benchmark turns out to give inconsistent results. I changed it to just [].

return cache.get(cache_key)

return inner

return deco
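As a sanity check, the decorator can be exercised standalone. The sketch below is self-contained (it repeats the decorator, using the plain [] lookup the review above settled on); the key function and call counter are ours, for illustration only:

```python
import functools

def cache_with_key(key):
    """Cache the decorated function's result, using key(*args, **kwargs)
    to compute the cache-lookup key."""
    def deco(func):
        cache = {}

        @functools.wraps(func)
        def inner(*args, **kwargs):
            cache_key = key(*args, **kwargs)
            if cache_key not in cache:
                cache[cache_key] = func(*args, **kwargs)
            return cache[cache_key]

        return inner

    return deco

calls = []

@cache_with_key(key=lambda x: type(x).__name__)  # hypothetical key: cache per argument *type*
def describe(x):
    calls.append(x)
    return f"a {type(x).__name__}"

describe(1)
describe(2)        # same type as 1 -> cache hit, describe's body does not re-run
describe("hello")  # new type -> cache miss
assert calls == [1, "hello"]
```

Note that the cache key here deliberately ignores argument values; that is exactly the flexibility reduce_into needs, where dtypes (not array contents) determine which build object can be reused.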
Empty file.
16 changes: 16 additions & 0 deletions python/cuda_parallel/cuda/parallel/experimental/_utils/cai.py
@@ -0,0 +1,16 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""
Utilities for extracting information from `__cuda_array_interface__`.
"""

import numpy as np

from ..typing import DeviceArrayLike


def get_dtype(arr: DeviceArrayLike) -> np.dtype:
return np.dtype(arr.__cuda_array_interface__["typestr"])
@@ -12,6 +12,10 @@

from .. import _cccl as cccl
from .._bindings import get_paths, get_bindings
from .._caching import cache_with_key
from ..typing import DeviceArrayLike
from ..iterators._iterators import IteratorBase
from .._utils import cai as cai
[Comment — Contributor] Oversight? (delete the "as cai")
[Reply — Author] Fixed in 56f2c61.



class _Op:
@@ -41,12 +45,18 @@ def _dtype_validation(dt1, dt2):

class _Reduce:
# TODO: constructor shouldn't require concrete `d_in`, `d_out`:
def __init__(self, d_in, d_out, op: Callable, h_init: np.ndarray):
def __init__(
self,
d_in: DeviceArrayLike | IteratorBase,
[Comment — Author] Note to self: Python 3.7 isn't going to like this.
[Comment — Contributor] Are you reminding yourself to use Union[DeviceArrayLike, IteratorBase]?
[Reply — Author] Instead, I used from __future__ import annotations, which will make migrating easy.

d_out: DeviceArrayLike,
op: Callable,
h_init: np.ndarray,
):
d_in_cccl = cccl.to_cccl_iter(d_in)
self._ctor_d_in_cccl_type_enum_name = cccl.type_enum_as_name(
d_in_cccl.value_type.type.value
)
self._ctor_d_out_dtype = d_out.dtype
self._ctor_d_out_dtype = cai.get_dtype(d_out)
self._ctor_init_dtype = h_init.dtype
cc_major, cc_minor = cuda.get_current_device().compute_capability
cub_path, thrust_path, libcudacxx_path, cuda_include_path = get_paths()
@@ -119,9 +129,28 @@ def __del__(self):
bindings.cccl_device_reduce_cleanup(ctypes.byref(self.build_result))


def make_cache_key(
d_in: DeviceArrayLike | IteratorBase,
d_out: DeviceArrayLike,
op: Callable,
h_init: np.ndarray,
):
d_in_key = d_in if isinstance(d_in, IteratorBase) else cai.get_dtype(d_in)
d_out_key = d_out if isinstance(d_out, IteratorBase) else cai.get_dtype(d_out)
op_key = (op.__code__.co_code, op.__code__.co_consts, op.__closure__)
[Comment — Contributor] The d_out type hint is DeviceArrayLike only, but the rhs of the d_out_key expression tests for IteratorBase. It'll not be great if this code is copy-pasted as we add more algorithms. One idea would be to introduce helper functions, but that would only be slightly better. I wonder if we could do much better, right in your decorator: loop over args and kwargs, and use isinstance() (and potentially typing.get_type_hints()) to check for supported argument types; we only have a few. This could be fast and compact, and the entire make_cache_key() function wouldn't be needed.
[Reply — Author] Nice observation; I changed it to just cai.get_dtype(d_out) in 56f2c61. We could move this into the decorator, but I'd prefer separating the concerns here, at the risk of a tiny bit of logic repetition across usages of cache_with_key. It also keeps cache_with_key smaller, more generic, and easier to explain if we don't specialize it for DeviceArrayLike arguments.

h_init_key = h_init.dtype
return (d_in_key, d_out_key, op_key, h_init_key)
[Comment — Collaborator] Important: there's implicit state affecting the key. As written, I'd get the same reducer object for devices of different architectures:

    cudaSetDevice(0)
    reducer_1 = reduce_into(d_in, d_out, ...)
    cudaSetDevice(1)
    reducer_2 = reduce_into(d_in, d_out, ...)

Let's incorporate cc_major, cc_minor = cuda.get_current_device().compute_capability somewhere in the key to address that.
[Reply — Author] Good call. I did that in 56f2c61.
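The key construction above can be sketched in pure Python. In this illustration the compute-capability tuples are hard-coded stand-ins for cuda.get_current_device().compute_capability, and plain dtype-name strings stand in for the real dtype keys:

```python
def op_key(op):
    # Two functions compiled from identical source share bytecode and
    # constants, and (here) have no closure, so their key parts compare equal.
    return (op.__code__.co_code, op.__code__.co_consts, op.__closure__)

def make_key(op, dtype_name, cc):
    # cc (compute capability) must be part of the key; otherwise reducers
    # built for different device architectures would collide in the cache.
    return (op_key(op), dtype_name, cc)

f = lambda a, b: a + b
g = lambda a, b: a + b  # textually identical -> equal key components
h = lambda a, b: a - b  # different bytecode

assert make_key(f, "int32", (8, 0)) == make_key(g, "int32", (8, 0))
assert make_key(f, "int32", (8, 0)) != make_key(h, "int32", (8, 0))
assert make_key(f, "int32", (8, 0)) != make_key(f, "int32", (9, 0))  # different arch
```

This also shows why keying on the function object itself would be too strict: f and g are distinct objects, yet they can safely share a cached build.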



# TODO Figure out `sum` without operator and initial value
# TODO Accept stream
def reduce_into(d_in, d_out, op: Callable, h_init: np.ndarray):
@cache_with_key(make_cache_key)
def reduce_into(
d_in: DeviceArrayLike | IteratorBase,
d_out: DeviceArrayLike,
op: Callable,
h_init: np.ndarray,
):
"""Computes a device-wide reduction using the specified binary ``op`` functor and initial value ``init``.

Example:
@@ -6,6 +6,7 @@
from llvmlite import ir
from numba.core.extending import intrinsic, overload
from numba.core.typing.ctypes_utils import to_ctypes
from numba.cuda.dispatcher import CUDADispatcher
from numba import cuda, types
import numba
import numpy as np
@@ -15,6 +16,19 @@
_DEVICE_POINTER_BITWIDTH = _DEVICE_POINTER_SIZE * 8


def _compare_funcs(func1, func2):
# return True if the functions compare equal for
# caching purposes, False otherwise
code1 = func1.__code__
code2 = func2.__code__

return (
code1.co_code == code2.co_code
and code1.co_consts == code2.co_consts
and func1.__closure__ == func2.__closure__
)


@lru_cache(maxsize=256) # TODO: what's a reasonable value?
def cached_compile(func, sig, abi_name=None, **kwargs):
return cuda.compile(func, sig, abi_info={"abi_name": abi_name}, **kwargs)
@@ -24,12 +38,10 @@ class IteratorBase:
"""
An Iterator is a wrapper around a pointer, and must define the following:

- a `state` property that returns a `ctypes.c_void_p` object, representing
a pointer to some data.
- an `advance` (static) method that receives the state pointer and performs
- an `advance` (static) method that receives the pointer and performs
an action that advances the pointer by the offset `distance`
(returns nothing).
- a `dereference` (static) method that dereferences the state pointer
- a `dereference` (static) method that dereferences the pointer
and returns a value.

Iterators are not meant to be used directly. They are constructed and passed
@@ -38,18 +50,28 @@ class IteratorBase:
The `advance` and `dereference` must be compilable to device code by numba.
"""

def __init__(self, numba_type: types.Type, value_type: types.Type, abi_name: str):
def __init__(
self,
cvalue: ctypes.c_void_p,
numba_type: types.Type,
value_type: types.Type,
abi_name: str,
):
"""
Parameters
----------
cvalue
A ctypes type representing the object pointed to by the iterator.
numba_type
A numba type that specifies how to interpret the state pointer.
A numba type representing the type of the input to the advance
and dereference functions.
value_type
The numba type of the value returned by the dereference operation.
abi_name
A unique identifier that will determine the abi_names for the
advance and dereference operations.
"""
self.cvalue = cvalue
self.numba_type = numba_type
self.value_type = value_type
self.abi_name = abi_name
@@ -81,7 +103,7 @@ def ltoirs(self) -> Dict[str, bytes]:

@property
def state(self) -> ctypes.c_void_p:
raise NotImplementedError("Subclasses must override advance staticmethod")
return ctypes.cast(ctypes.pointer(self.cvalue), ctypes.c_void_p)

@staticmethod
def advance(state, distance):
@@ -91,6 +113,21 @@ def advance(state, distance):
def dereference(state):
raise NotImplementedError("Subclasses must override dereference staticmethod")

def __hash__(self):
return hash(
(self.cvalue.value, self.numba_type, self.value_type, self.abi_name)
)

def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (
self.cvalue.value == other.cvalue.value
and self.numba_type == other.numba_type
and self.value_type == other.value_type
and self.abi_name == other.abi_name
)
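The __hash__/__eq__ contract added here (equal iterators must hash equally, so they collapse to one cache entry) can be sketched without numba; the class below is a simplified stand-in that mirrors the fields compared by IteratorBase, with plain strings in place of numba types:

```python
import ctypes

class FakeIterator:
    """Simplified stand-in mirroring IteratorBase's __hash__/__eq__ contract."""

    def __init__(self, ptr, numba_type, value_type, abi_name):
        self.cvalue = ctypes.c_void_p(ptr)
        self.numba_type = numba_type
        self.value_type = value_type
        self.abi_name = abi_name

    def __hash__(self):
        return hash(
            (self.cvalue.value, self.numba_type, self.value_type, self.abi_name)
        )

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return (
            self.cvalue.value == other.cvalue.value
            and self.numba_type == other.numba_type
            and self.value_type == other.value_type
            and self.abi_name == other.abi_name
        )

a = FakeIterator(0x1000, "int32*", "int32", "RawPointer_int32")
b = FakeIterator(0x1000, "int32*", "int32", "RawPointer_int32")
c = FakeIterator(0x1000, "int32*", "int32", "RawPointer_int64")

assert a == b and hash(a) == hash(b)  # equal fields -> equal, same hash
assert a != c                         # abi_name differs
assert len({a, b, c}) == 2            # a and b collapse to one cache entry
```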


def sizeof_pointee(context, ptr):
size = context.get_abi_sizeof(ptr.type.pointee)
@@ -125,10 +162,11 @@ def impl(ptr, offset):
class RawPointer(IteratorBase):
def __init__(self, ptr: int, ntype: types.Type):
[Comment — Contributor] Maybe rename ntype to value_type everywhere while you're at it? And then inline the assignments right in the super().__init__() calls.
[Reply — Author] Done in 56f2c61.

value_type = ntype
self._cvalue = ctypes.c_void_p(ptr)
cvalue = ctypes.c_void_p(ptr)
numba_type = types.CPointer(types.CPointer(value_type))
abi_name = f"{self.__class__.__name__}_{str(value_type)}"
super().__init__(
cvalue=cvalue,
numba_type=numba_type,
value_type=value_type,
abi_name=abi_name,
@@ -142,10 +180,6 @@ def advance(state, distance):
def dereference(state):
return state[0][0]

@property
def state(self) -> ctypes.c_void_p:
return ctypes.cast(ctypes.pointer(self._cvalue), ctypes.c_void_p)


def pointer(container, ntype: types.Type) -> RawPointer:
return RawPointer(container.__cuda_array_interface__["data"][0], ntype)
@@ -174,11 +208,12 @@ def codegen(context, builder, sig, args):

class CacheModifiedPointer(IteratorBase):
def __init__(self, ptr: int, ntype: types.Type):
self._cvalue = ctypes.c_void_p(ptr)
cvalue = ctypes.c_void_p(ptr)
value_type = ntype
numba_type = types.CPointer(types.CPointer(value_type))
abi_name = f"{self.__class__.__name__}_{str(value_type)}"
super().__init__(
cvalue=cvalue,
numba_type=numba_type,
value_type=value_type,
abi_name=abi_name,
@@ -192,18 +227,15 @@ def advance(state, distance):
def dereference(state):
return load_cs(state[0])

@property
def state(self) -> ctypes.c_void_p:
return ctypes.cast(ctypes.pointer(self._cvalue), ctypes.c_void_p)


class ConstantIterator(IteratorBase):
def __init__(self, value: np.number):
value_type = numba.from_dtype(value.dtype)
self._cvalue = to_ctypes(value_type)(value)
cvalue = to_ctypes(value_type)(value)
numba_type = types.CPointer(value_type)
abi_name = f"{self.__class__.__name__}_{str(value_type)}"
super().__init__(
cvalue=cvalue,
numba_type=numba_type,
value_type=value_type,
abi_name=abi_name,
@@ -217,18 +249,15 @@ def advance(state, distance):
def dereference(state):
return state[0]

@property
def state(self) -> ctypes.c_void_p:
return ctypes.cast(ctypes.pointer(self._cvalue), ctypes.c_void_p)


class CountingIterator(IteratorBase):
def __init__(self, value: np.number):
value_type = numba.from_dtype(value.dtype)
self._cvalue = to_ctypes(value_type)(value)
cvalue = to_ctypes(value_type)(value)
numba_type = types.CPointer(value_type)
abi_name = f"{self.__class__.__name__}_{str(value_type)}"
super().__init__(
cvalue=cvalue,
numba_type=numba_type,
value_type=value_type,
abi_name=abi_name,
@@ -242,10 +271,6 @@ def dereference(state):
def dereference(state):
return state[0]

@property
def state(self) -> ctypes.c_void_p:
return ctypes.cast(ctypes.pointer(self._cvalue), ctypes.c_void_p)


def make_transform_iterator(it, op: Callable):
if hasattr(it, "__cuda_array_interface__"):
@@ -256,8 +281,9 @@ def make_transform_iterator(it, op: Callable):
op = cuda.jit(op, device=True)

class TransformIterator(IteratorBase):
def __init__(self, it: IteratorBase, op):
def __init__(self, it: IteratorBase, op: CUDADispatcher):
self._it = it
self._op = op
numba_type = it.numba_type
# TODO: the abi name below isn't unique enough when we have e.g.,
# two identically named `op` functions with different
@@ -276,6 +302,7 @@ def __init__(self, it: IteratorBase, op):
value_type = op_retty
abi_name = f"{self.__class__.__name__}_{it.abi_name}_{op_abi_name}"
super().__init__(
cvalue=it.cvalue,
numba_type=numba_type,
value_type=value_type,
abi_name=abi_name,
@@ -289,8 +316,20 @@ def dereference(state):
def dereference(state):
return op(it_dereference(state))

@property
def state(self) -> ctypes.c_void_p:
return it.state
def __hash__(self):
return hash(
(
self._it,
self._op.py_func.__code__.co_code,
self._op.py_func.__closure__,
)
)

def __eq__(self, other):
if not isinstance(other, IteratorBase):
return NotImplemented
return self._it == other._it and _compare_funcs(
self._op.py_func, other._op.py_func
)

return TransformIterator(it, op)
10 changes: 10 additions & 0 deletions python/cuda_parallel/cuda/parallel/experimental/typing.py
@@ -0,0 +1,10 @@
from typing import Protocol


class DeviceArrayLike(Protocol):
[Comment — Contributor] TypeWithCUDAArrayInterface would be much more expressive.
[Reply — Author] I used DeviceArrayLike to match NumPy's ArrayLike protocol. Not tying the name to CAI also lets us later extend this to a union type covering objects that support other protocols (like DLPack).

"""
Objects representing a device array, having a `.__cuda_array_interface__`
attribute.
"""

__cuda_array_interface__: dict
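Because DeviceArrayLike is a structural Protocol, any object carrying the attribute satisfies it without inheriting from anything. The sketch below pairs it with the get_dtype helper from _utils/cai.py; the FakeDeviceArray class and the runtime_checkable decoration are ours, for illustration:

```python
import numpy as np
from typing import Protocol, runtime_checkable

@runtime_checkable
class DeviceArrayLike(Protocol):
    """Objects representing a device array, having a
    `.__cuda_array_interface__` attribute."""

    __cuda_array_interface__: dict

def get_dtype(arr: DeviceArrayLike) -> np.dtype:
    # "typestr" follows the CUDA Array Interface spec,
    # e.g. "<f4" is little-endian float32
    return np.dtype(arr.__cuda_array_interface__["typestr"])

class FakeDeviceArray:
    # Hypothetical stand-in: no device memory, just the interface dict
    def __init__(self, typestr: str):
        self.__cuda_array_interface__ = {"typestr": typestr, "version": 3}

arr = FakeDeviceArray("<f4")
assert isinstance(arr, DeviceArrayLike)  # structural check, no inheritance
assert get_dtype(arr) == np.dtype(np.float32)
```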