Skip to content

Commit

Permalink
Init cudagraph
Browse files Browse the repository at this point in the history
Co-authored-by: Mit Kotak <[email protected]>
Co-authored-by: Gerlof Fokkema <[email protected]>
  • Loading branch information
3 people committed Sep 13, 2023
1 parent 44bff55 commit 9b3c043
Show file tree
Hide file tree
Showing 6 changed files with 745 additions and 22 deletions.
125 changes: 125 additions & 0 deletions doc/driver.rst
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,22 @@ Constants

.. attribute:: LAZY_ENABLE_PEER_ACCESS

.. class:: capture_mode

CUDA 10 and newer.

.. attribute:: GLOBAL
.. attribute:: THREAD_LOCAL
.. attribute:: RELAXED

.. class:: capture_status

CUDA 10 and newer.

.. attribute:: NONE
.. attribute:: ACTIVE
.. attribute:: INVALIDATED


Graphics-related constants
^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -845,6 +861,43 @@ Concurrency and Streams

.. versionadded:: 2011.1

.. method:: begin_capture(capture_mode=capture_mode.GLOBAL)

Begins graph stream capture on a stream.

When a stream is in capture mode, all operations pushed into the stream
will not be executed, but will instead be captured into a graph.

:arg capture_mode: A :class:`capture_mode` specifying mode for capturing graph.

CUDA 10 and above.

.. method:: end_capture()

Ends stream capture and returns a :class:`Graph` object.

CUDA 10 and above.

.. method:: get_capture_info_v2()

Query a stream's capture state.

Return a :class:`tuple` of (:class:`capture_status` capture status, :class:`int` id for the capture sequence,
:class:`Graph` the graph being captured into, a :class:`list` of :class:`GraphNode` specifying set of nodes the
next node to be captured in the stream will depend on)

CUDA 10 and above.

.. method:: update_capture_dependencies(dependencies, flags)

Modifies the dependency set of a capturing stream.
The dependency set is the set of nodes that the next captured node in the stream will depend on.

:arg dependencies: A :class:`list` of :class:`GraphNode` specifying the new list of dependencies.
:arg flags: A :class:`int` controlling whether the set passed to the API is added to the existing set or replaces it.

CUDA 11.3 and above.

.. class:: Event(flags=0)

An event is a temporal 'marker' in a :class:`Stream` that allows taking the time
Expand Down Expand Up @@ -895,6 +948,78 @@ Concurrency and Streams

.. versionadded: 2011.2
CUDAGraphs
----------

CUDA 10.0 and above

Launching a simple kernel using CUDAGraphs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. literalinclude:: ../examples/cudagraph_kernel.py

.. class:: GraphNode

An object representing a node on :class:`Graph`.

Wraps `cuGraphNode <https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1gc72514a94dacc85ed0617f979211079c>`

.. class:: GraphExec

An executable graph to be launched on a stream.

Wraps `cuGraphExec <https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1gf0abeceeaa9f0a39592fe36a538ea1f0>`_

.. method:: launch(stream_py=None)

Launches an executable graph in a stream.

:arg stream_py: :class:`Stream` object specifying device stream.
Will use default stream if *stream_py* is None.

.. method:: kernel_node_set_params(*args, kernel_node, func=None, block=(), grid=(), shared_mem_bytes=0)

Sets a kernel node's parameters. Refer to :meth:`add_kernel_node` for argument specifications.

.. class:: Graph()

A cudagraph is a data dependency graph meant to
serve as an alternative to :class:`Stream`.

Wraps `cuGraph <https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1g69f555c38df5b3fa1ed25efef794739a>`

.. method:: add_kernel_node(*args, func, block, grid=(1, ), dependencies=[], shared_mem_bytes=0)

Returns and adds a :class:`GraphNode` object specifying
kernel node to the graph.

Will be placed at the root of the graph if dependencies
are not specified.

:arg args: *arg1* through *argn* are the positional C arguments to the kernel.
See :meth:`Function.__call__` for more argument details.

:arg func: a :class:`Function`object specifying kernel function.
:arg block: a :class:`tuple` of up to three integer entries specifying the number
of thread blocks to launch, as a multi-dimensional grid.

:arg grid: a :class:`tuple` of up to three integer entries specifying the grid configuration.

:arg dependencies: A :class:`list` of :class:`GraphNode` objects specifying dependency nodes.

:arg shared_mem_bytes: A :class:`int` specifying size of shared memory.

.. method:: instantiate()

Returns and instantiates a :class:`GraphExec` object.

.. method:: debug_dot_print(path)

Returns a DOT file describing graph structure at specifed path.

:arg path: String specifying path for saving DOT file.

Memory
------

Expand Down
57 changes: 57 additions & 0 deletions examples/demo_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Sample source code from the Tutorial Introduction in the documentation.
import pycuda.driver as cuda
import pycuda.autoinit # noqa
from pycuda.compiler import SourceModule

mod = SourceModule("""
__global__ void plus(float *a, int num)
{
int idx = threadIdx.x + threadIdx.y*4;
a[idx] += num;
}
__global__ void times(float *a, float *b)
{
int idx = threadIdx.x + threadIdx.y*4;
a[idx] *= b[idx];
}
""")
func_plus = mod.get_function("plus")
func_times = mod.get_function("times")

import numpy
a = numpy.zeros((4, 4)).astype(numpy.float32)
a_gpu = cuda.mem_alloc_like(a)
b = numpy.zeros((4, 4)).astype(numpy.float32)
b_gpu = cuda.mem_alloc_like(b)
result = numpy.zeros_like(b)
b2_gpu = cuda.mem_alloc_like(b)

stream_1 = cuda.Stream()
stream_1.begin_capture()
cuda.memcpy_htod_async(a_gpu, a, stream_1)
cuda.memcpy_htod_async(b_gpu, b, stream_1)
cuda.memcpy_htod_async(b2_gpu, b, stream_1)
func_plus(a_gpu, numpy.int32(2), block=(4, 4, 1), stream=stream_1)
_, _, graph, deps = stream_1.get_capture_info_v2()
first_node = graph.add_kernel_node(b_gpu, numpy.int32(3), block=(4, 4, 1), func=func_plus, dependencies=deps)
stream_1.update_capture_dependencies([first_node], 1)

_, _, graph, deps = stream_1.get_capture_info_v2()
second_node = graph.add_kernel_node(a_gpu, b_gpu, block=(4, 4, 1), func=func_times, dependencies=deps)
stream_1.update_capture_dependencies([second_node], 1)
cuda.memcpy_dtoh_async(result, a_gpu, stream_1)

graph = stream_1.end_capture()
graph.debug_dot_print("test.dot") # print dotfile of graph
instance = graph.instantiate()

# Setting dynamic parameters
instance.kernel_node_set_params(b2_gpu, numpy.int32(100), block=(4, 4, 1), func=func_plus, kernel_node=first_node)
instance.kernel_node_set_params(a_gpu, b2_gpu, block=(4, 4, 1), func=func_times, kernel_node=second_node)
instance.launch()

print("original arrays:")
print(a)
print(b)
print("(0+2)x(0+100) = 200, using a kernel graph of 3 kernels:")
print(result)
73 changes: 73 additions & 0 deletions pycuda/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,79 @@ def new_func(*args, **kwargs):

_add_functionality()

def _build_arg_buf(args):
handlers = []

arg_data = []
format = ""
for i, arg in enumerate(args):
if isinstance(arg, np.number):
arg_data.append(arg)
format += arg.dtype.char
elif isinstance(arg, (DeviceAllocation, PooledDeviceAllocation)):
arg_data.append(int(arg))
format += "P"
elif isinstance(arg, ArgumentHandler):
handlers.append(arg)
arg_data.append(int(arg.get_device_alloc()))
format += "P"
elif isinstance(arg, np.ndarray):
if isinstance(arg.base, ManagedAllocationOrStub):
arg_data.append(int(arg.base))
format += "P"
else:
arg_data.append(arg)
format += "%ds" % arg.nbytes
elif isinstance(arg, np.void):
arg_data.append(_my_bytes(_memoryview(arg)))
format += "%ds" % arg.itemsize
else:
cai = getattr(arg, "__cuda_array_interface__", None)
if cai:
arg_data.append(cai["data"][0])
format += "P"
continue

try:
gpudata = np.uintp(arg.gpudata)
except AttributeError:
raise TypeError("invalid type on parameter #%d (0-based)" % i)
else:
# for gpuarrays
arg_data.append(int(gpudata))
format += "P"

from pycuda._pvt_struct import pack

return handlers, pack(format, *arg_data)

# {{{ cudagraph

def patch_cudagraph():
def graph_add_kernel_node_call(graph, *args, func, block, grid=(1, ), dependencies=[], shared_mem_bytes=0):
if func is None:
raise ValueError("must specify func")
if block is None:
raise ValueError("must specify block size")
_, arg_buf = _build_arg_buf(args)
return graph._add_kernel_node(dependencies, func, grid, block, arg_buf, shared_mem_bytes)

def exec_graph_set_kernel_node_call(exec_graph, *args, kernel_node, func, block, grid=(1, ), shared_mem_bytes=0):
if kernel_node is None:
raise ValueError("must specify kernel_node")
if func is None:
raise ValueError("must specify func")
if block is None:
raise ValueError("must specify block size")
_, arg_buf = _build_arg_buf(args)
return exec_graph._kernel_node_set_params(kernel_node, func, grid, block, arg_buf, shared_mem_bytes)

Graph.add_kernel_node = graph_add_kernel_node_call
GraphExec.kernel_node_set_params = exec_graph_set_kernel_node_call
if get_version() >= (10,):
patch_cudagraph()

# }}}

# {{{ pagelocked numpy arrays

Expand Down
Loading

0 comments on commit 9b3c043

Please sign in to comment.