cuda.parallel: In-memory caching of cuda.parallel build objects #3216

Open: wants to merge 8 commits into base: main
64 changes: 64 additions & 0 deletions python/cuda_parallel/cuda/parallel/experimental/_caching.py
@@ -0,0 +1,64 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import functools
from numba import cuda


def cache_with_key(key):
Contributor:
key_factory?

Contributor Author:
I went with key as it's the name of the similar argument in, e.g., sorted and cachetools.

"""
Decorator to cache the result of the decorated function. Uses the
provided `key` function to compute the key for cache lookup. `key`
receives all arguments passed to the function.

Notes
-----
The CUDA compute capability of the current device is appended to
the cache key returned by `key`.
"""

def deco(func):
cache = {}

@functools.wraps(func)
def inner(*args, **kwargs):
cc = cuda.get_current_device().compute_capability
cache_key = (key(*args, **kwargs), *cc)
if cache_key not in cache:
result = func(*args, **kwargs)
cache[cache_key] = result
return cache[cache_key]

return inner

return deco
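
A minimal usage sketch of cache_with_key, not part of the diff; build_something and _dtype_key below are hypothetical names:

# Hypothetical example: the key function receives the same arguments as the
# decorated function and keys only on dtypes; the decorator then appends the
# current device's compute capability before the cache lookup.
import numpy as np


def _dtype_key(d_in, h_init):
    return (np.dtype(d_in.__cuda_array_interface__["typestr"]), h_init.dtype)


@cache_with_key(_dtype_key)
def build_something(d_in, h_init):
    ...  # expensive build; runs at most once per (key, compute capability)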


class CachableFunction:
"""
A type that wraps a function and provides custom comparison
(__eq__) and hash (__hash__) implementations.

The purpose of this class is to enable caching and comparison of
functions based on their bytecode, constants, and closures, while
ignoring other attributes such as their names or docstrings.
"""

def __init__(self, func):
self._func = func
self._identity = (
self._func.__code__.co_code,
self._func.__code__.co_consts,
self._func.__closure__,
)

def __eq__(self, other):
return self._identity == other._identity

def __hash__(self):
return hash(self._identity)

def __repr__(self):
return str(self._func)
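
A brief sketch, not part of the diff, of the behaviour the docstring describes: two functions with identical bytecode, constants, and closures compare and hash as equal even though their names differ.

# Illustrative only: equality ignores names and docstrings, so these two
# differently named functions should compare and hash as equal.
def add_one(x):
    return x + 1


def plus_one(x):
    return x + 1


assert CachableFunction(add_one) == CachableFunction(plus_one)
assert hash(CachableFunction(add_one)) == hash(CachableFunction(plus_one))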
16 changes: 16 additions & 0 deletions python/cuda_parallel/cuda/parallel/experimental/_utils/cai.py
@@ -0,0 +1,16 @@
# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""
Utilities for extracting information from `__cuda_array_interface__`.
"""

import numpy as np

from ..typing import DeviceArrayLike


def get_dtype(arr: DeviceArrayLike) -> np.dtype:
return np.dtype(arr.__cuda_array_interface__["typestr"])
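
A small usage sketch, not part of the diff, assuming a CuPy array as the device array; the "typestr" '<f4' maps back to np.dtype('float32').

# Hypothetical usage of the helper from outside the module.
import cupy as cp
import numpy as np

from cuda.parallel.experimental._utils import cai

assert cai.get_dtype(cp.zeros(4, dtype=np.float32)) == np.dtype(np.float32)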
@@ -3,6 +3,8 @@
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations # TODO: required for Python 3.7 docs env

import ctypes
import numba
import numpy as np
@@ -12,6 +14,10 @@

from .. import _cccl as cccl
from .._bindings import get_paths, get_bindings
from .._caching import CachableFunction, cache_with_key
from ..typing import DeviceArrayLike
from ..iterators._iterators import IteratorBase
from .._utils import cai


class _Op:
@@ -41,12 +47,18 @@ def _dtype_validation(dt1, dt2):

class _Reduce:
# TODO: constructor shouldn't require concrete `d_in`, `d_out`:
def __init__(self, d_in, d_out, op: Callable, h_init: np.ndarray):
def __init__(
self,
d_in: DeviceArrayLike | IteratorBase,
Contributor Author:
Note to self: Python 3.7 isn't going to like this.

Contributor:
Are you reminding yourself to use Union[DeviceArrayLike, IteratorBase]?

Contributor Author:
Instead, I used from __future__ import annotations, which will make migrating easy.

A short sketch of this annotation behaviour follows this hunk.
d_out: DeviceArrayLike,
op: Callable,
h_init: np.ndarray,
):
d_in_cccl = cccl.to_cccl_iter(d_in)
self._ctor_d_in_cccl_type_enum_name = cccl.type_enum_as_name(
d_in_cccl.value_type.type.value
)
self._ctor_d_out_dtype = d_out.dtype
self._ctor_d_out_dtype = cai.get_dtype(d_out)
self._ctor_init_dtype = h_init.dtype
cc_major, cc_minor = cuda.get_current_device().compute_capability
cub_path, thrust_path, libcudacxx_path, cuda_include_path = get_paths()
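
A short sketch, not part of the diff, of the annotation point from the review thread above: on Python 3.7-3.9 the | union syntax in annotations is evaluated eagerly and raises TypeError, while under from __future__ import annotations (PEP 563) the annotation is stored as a string and never evaluated at definition time, so Union[DeviceArrayLike, IteratorBase] is not needed.

# Sketch only. Without the __future__ import, "def f(x: int | str): ..."
# raises TypeError on Python 3.7-3.9 because the annotation is evaluated
# when the function is defined; with it, the annotation stays a string.
from __future__ import annotations


def f(x: int | str) -> None:
    ...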
@@ -119,9 +131,28 @@ def __del__(self):
bindings.cccl_device_reduce_cleanup(ctypes.byref(self.build_result))


def make_cache_key(
d_in: DeviceArrayLike | IteratorBase,
d_out: DeviceArrayLike,
op: Callable,
h_init: np.ndarray,
):
d_in_key = d_in.kind if isinstance(d_in, IteratorBase) else cai.get_dtype(d_in)
d_out_key = cai.get_dtype(d_out)
op_key = CachableFunction(op)
h_init_key = h_init.dtype
return (d_in_key, d_out_key, op_key, h_init_key)
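
A sketch, not part of the diff, of how two calls with same-typed inputs and identically written operators produce equal cache keys; CuPy arrays stand in for the device arrays here.

# Hypothetical illustration: equal keys mean a second reduce_into call with
# these arguments would reuse the cached build object instead of recompiling.
import cupy as cp
import numpy as np

d_in = cp.zeros(1024, dtype=np.int32)
d_out = cp.zeros(1, dtype=np.int32)
h_init = np.zeros(1, dtype=np.int32)

key_a = make_cache_key(d_in, d_out, lambda a, b: a + b, h_init)
key_b = make_cache_key(d_in, d_out, lambda a, b: a + b, h_init)
assert key_a == key_b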


# TODO Figure out `sum` without operator and initial value
# TODO Accept stream
def reduce_into(d_in, d_out, op: Callable, h_init: np.ndarray):
@cache_with_key(make_cache_key)
def reduce_into(
d_in: DeviceArrayLike | IteratorBase,
d_out: DeviceArrayLike,
op: Callable,
h_init: np.ndarray,
):
"""Computes a device-wide reduction using the specified binary ``op`` functor and initial value ``init``.

Example: