Skip to content

Commit

Permalink
Add update_capture_dependencies flags
Browse files (browse the repository at this point in the history)
  • Loading branch information
nimlgen committed Oct 20, 2023
1 parent 66d7e36 commit f232a64
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
4 changes: 2 additions & 2 deletions examples/demo_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@
func_plus(a_gpu, numpy.int32(2), block=(4, 4, 1), stream=stream_1)
_, _, graph, deps = stream_1.get_capture_info_v2()
first_node = graph.add_kernel_node(b_gpu, numpy.int32(3), block=(4, 4, 1), func=func_plus, dependencies=deps)
-stream_1.update_capture_dependencies([first_node], 1)
+stream_1.update_capture_dependencies([first_node], cuda.update_capture_dependencies_flags.SET_CAPTURE_DEPENDENCIES)

_, _, graph, deps = stream_1.get_capture_info_v2()
second_node = graph.add_kernel_node(a_gpu, b_gpu, block=(4, 4, 1), func=func_times, dependencies=deps)
-stream_1.update_capture_dependencies([second_node], 1)
+stream_1.update_capture_dependencies([second_node], cuda.update_capture_dependencies_flags.SET_CAPTURE_DEPENDENCIES)
cuda.memcpy_dtoh_async(result, a_gpu, stream_1)

graph = stream_1.end_capture()
Expand Down
10 changes: 9 additions & 1 deletion src/wrapper/wrap_cudadrv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1277,6 +1277,12 @@ BOOST_PYTHON_MODULE(_driver)
.value("ACTIVE", CU_STREAM_CAPTURE_STATUS_ACTIVE)
.value("INVALIDATED", CU_STREAM_CAPTURE_STATUS_INVALIDATED)
;
+#endif
+#if CUDAPP_CUDA_VERSION >= 11030
+py::enum_<CUstreamUpdateCaptureDependencies_flags>("update_capture_dependencies_flags")
+.value("ADD_CAPTURE_DEPENDENCIES", CU_STREAM_ADD_CAPTURE_DEPENDENCIES)
+.value("SET_CAPTURE_DEPENDENCIES", CU_STREAM_SET_CAPTURE_DEPENDENCIES)
+;
#endif
{
typedef stream cl;
Expand All @@ -1294,7 +1300,9 @@ BOOST_PYTHON_MODULE(_driver)
py::return_value_policy<py::manage_new_object>())
.def("get_capture_info_v2", &cl::get_capture_info_v2)
#if CUDAPP_CUDA_VERSION >= 11030
-.def("update_capture_dependencies", &cl::update_capture_dependencies)
+.def("update_capture_dependencies", &cl::update_capture_dependencies,
+(py::arg("dependencies"),
+py::arg("flags") = CU_STREAM_ADD_CAPTURE_DEPENDENCIES))
#endif
#endif
.add_property("handle", &cl::handle_int)
Expand Down
6 changes: 3 additions & 3 deletions test/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_dynamic_params(self):
assert stat == drv.capture_status.ACTIVE, "Capture should be active"
assert len(deps) == 0, "Nothing on deps"
newnode = x_graph.add_kernel_node(a_gpu, numpy.int32(3), block=(4, 4, 1), func=func_plus, dependencies=deps)
-stream_1.update_capture_dependencies([newnode], 1)
+stream_1.update_capture_dependencies([newnode], cuda.update_capture_dependencies_flags.SET_CAPTURE_DEPENDENCIES)
drv.memcpy_dtoh_async(result, a_gpu, stream_1) # Capture a copy as well.
graph = stream_1.end_capture()
assert graph == x_graph, "Should be the same"
Expand Down Expand Up @@ -110,11 +110,11 @@ def test_many_dynamic_params(self):
assert stat == drv.capture_status.ACTIVE, "Capture should be active"
assert len(deps) == 0, "Nothing on deps"
newnode = x_graph.add_kernel_node(a_gpu, numpy.int32(3), block=(4, 4, 1), func=func_plus, dependencies=deps)
-stream_1.update_capture_dependencies([newnode], 1)
+stream_1.update_capture_dependencies([newnode], cuda.update_capture_dependencies_flags.SET_CAPTURE_DEPENDENCIES)
_, _, x_graph, deps = stream_1.get_capture_info_v2()
assert deps == [newnode], "Call to update_capture_dependencies should set newnode as the only dep"
newnode2 = x_graph.add_kernel_node(b_gpu, numpy.int32(3), block=(4, 4, 1), func=func_plus, dependencies=deps)
-stream_1.update_capture_dependencies([newnode2], 1)
+stream_1.update_capture_dependencies([newnode2], cuda.update_capture_dependencies_flags.SET_CAPTURE_DEPENDENCIES)

# Static capture
func_times(a_gpu, b_gpu, block=(4, 4, 1), stream=stream_1)
Expand Down

0 comments on commit f232a64

Please sign in to comment.