diff --git a/Cargo.toml b/Cargo.toml
index b76c118..1a25cae 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,7 +15,6 @@ members = [
 
 exclude = [
     "crates/optix/examples/common",
-    "crates/cuda_std_cg",
 ]
 
 [profile.dev.package.rustc_codegen_nvvm]
diff --git a/Justfile b/Justfile
deleted file mode 100644
index af59104..0000000
--- a/Justfile
+++ /dev/null
@@ -1,7 +0,0 @@
-build_cuda_std_cg:
-    #!/usr/bin/env bash
-    set -euxo pipefail
-    nvcc --ptx -arch=sm_75 \
-        -I crates/cuda_std_cg/src -I${CUDA_ROOT}/include \
-        --device-c crates/cuda_std_cg/src/cg_bridge.cu \
-        -o crates/cuda_std_cg/cg_bridge.ptx
diff --git a/crates/cuda_builder/Cargo.toml b/crates/cuda_builder/Cargo.toml
index 53de4fd..1ab6a51 100644
--- a/crates/cuda_builder/Cargo.toml
+++ b/crates/cuda_builder/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "cuda_builder"
-version = "0.3.0"
+version = "0.4.0"
 edition = "2021"
 authors = ["Riccardo D'Ambrosio <rdambrosio016@gmail.com>", "The rust-gpu Authors"]
 license = "MIT OR Apache-2.0"
@@ -9,8 +9,16 @@ repository = "https://github.com/Rust-GPU/Rust-CUDA"
 readme = "../../README.md"
 
 [dependencies]
+anyhow = "1"
+thiserror = "1"
+cc = { version = "1", default-features = false, optional = true }
+cust = { path = "../cust", optional = true }
 rustc_codegen_nvvm = { version = "0.3", path = "../rustc_codegen_nvvm" }
 nvvm = { path = "../nvvm", version = "0.1" }
 serde = { version = "1.0.130", features = ["derive"] }
 serde_json = "1.0.68"
 find_cuda_helper = { version = "0.2", path = "../find_cuda_helper" }
+
+[features]
+default = []
+cooperative_groups = ["cc", "cust"]
diff --git a/crates/cuda_std_cg/src/cg_bridge.cu b/crates/cuda_builder/cg/cg_bridge.cu
similarity index 95%
rename from crates/cuda_std_cg/src/cg_bridge.cu
rename to crates/cuda_builder/cg/cg_bridge.cu
index 110d9f2..51a5869 100644
--- a/crates/cuda_std_cg/src/cg_bridge.cu
+++ b/crates/cuda_builder/cg/cg_bridge.cu
@@ -1,6 +1,5 @@
 #include "cooperative_groups.h"
 #include "cg_bridge.cuh"
-// #include <cstdio>
 namespace cg = cooperative_groups;
 
 __device__ GridGroup this_grid()
@@ -24,7 +23,6 @@ __device__ bool GridGroup_is_valid(GridGroup gg)
 
 __device__ void GridGroup_sync(GridGroup gg)
 {
-    // std::printf("calling sync from bridge");
     GridGroupWrapper* g = static_cast<GridGroupWrapper*>(gg);
     return g->gg.sync();
 }
diff --git a/crates/cuda_std_cg/src/cg_bridge.cuh b/crates/cuda_builder/cg/cg_bridge.cuh
similarity index 100%
rename from crates/cuda_std_cg/src/cg_bridge.cuh
rename to crates/cuda_builder/cg/cg_bridge.cuh
diff --git a/crates/cuda_builder/src/cg.rs b/crates/cuda_builder/src/cg.rs
new file mode 100644
index 0000000..186708e
--- /dev/null
+++ b/crates/cuda_builder/src/cg.rs
@@ -0,0 +1,174 @@
+//! Cooperative Groups compilation and linking.
+
+use std::path::{Path, PathBuf};
+
+use anyhow::Context;
+
+use crate::{CudaBuilderError, CudaBuilderResult};
+
+/// An artifact which may be linked together with the Cooperative Groups API bridge PTX code.
+pub enum LinkableArtifact {
+    /// A PTX artifact.
+    Ptx(PathBuf),
+    /// A cubin artifact.
+    Cubin(PathBuf),
+    /// A fatbin artifact.
+    Fatbin(PathBuf),
+}
+
+impl LinkableArtifact {
+    /// Add this artifact to the given linker.
+    fn link_artifact(&self, linker: &mut cust::link::Linker) -> CudaBuilderResult<()> {
+        match &self {
+            LinkableArtifact::Ptx(path) => {
+                let mut data = std::fs::read_to_string(&path).with_context(|| {
+                    format!("error reading PTX file for linking, file={:?}", path)
+                })?;
+                if !data.ends_with('\0') {
+                    // If the PTX is not null-terminated, then linking will fail. Only required for PTX.
+                    data.push('\0');
+                }
+                linker
+                    .add_ptx(&data)
+                    .with_context(|| format!("error linking PTX file={:?}", path))?;
+            }
+            LinkableArtifact::Cubin(path) => {
+                let data = std::fs::read(&path).with_context(|| {
+                    format!("error reading cubin file for linking, file={:?}", path)
+                })?;
+                linker
+                    .add_cubin(&data)
+                    .with_context(|| format!("error linking cubin file={:?}", path))?;
+            }
+            LinkableArtifact::Fatbin(path) => {
+                let data = std::fs::read(&path).with_context(|| {
+                    format!("error reading fatbin file for linking, file={:?}", path)
+                })?;
+                linker
+                    .add_fatbin(&data)
+                    .with_context(|| format!("error linking fatbin file={:?}", path))?;
+            }
+        }
+        Ok(())
+    }
+}
+
+/// A builder which will compile the Cooperative Groups API bridging code, and will then link it
+/// together with any other artifacts provided to this builder.
+///
+/// The result of this process will be a `cubin` file containing the linked Cooperative Groups
+/// PTX code along with any other linked artifacts provided to this builder. The output `cubin`
+/// may then be loaded via `cust::module::Module::from_cubin(..)` and used as normal.
+#[derive(Default)]
+pub struct CooperativeGroups {
+    /// Artifacts to be linked together with the Cooperative Groups bridge code.
+    artifacts: Vec<LinkableArtifact>,
+    /// Flags to pass to nvcc for Cooperative Groups API bridge compilation.
+    nvcc_flags: Vec<String>,
+}
+
+impl CooperativeGroups {
+    /// Construct a new instance.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Add the artifact at the given path for linking.
+    ///
+    /// This only applies to linking with the Cooperative Groups API bridge code. Typically,
+    /// this will be the PTX of your main program which has already been built via `CudaBuilder`.
+    pub fn link(mut self, artifact: LinkableArtifact) -> Self {
+        self.artifacts.push(artifact);
+        self
+    }
+
+    /// Add a flag to be passed along to `nvcc` during compilation of the Cooperative Groups API bridge code.
+    ///
+    /// This provides maximum flexibility for code generation. If needed, multiple architectures
+    /// may be generated by adding the appropriate flags to the `nvcc` call.
+    ///
+    /// By default, `nvcc` will generate code for `sm_52`. Override this by specifying any of the
+    /// `--gpu-architecture`, `--gpu-code`, or `--generate-code` flags.
+    ///
+    /// Regardless of the flags added via this method, this builder will always add the following flags:
+    /// - `-I<cuda_root>/include`: ensuring `cooperative_groups.h` can be found.
+    /// - `-I<cuda_builder>/cg`: ensuring the bridging header can be found.
+    /// - `--ptx`: forces the compiled output to be in PTX form.
+    /// - `--device-c`: compiles the bridging code as relocatable device code.
+    /// - `cg/cg_bridge.cu` will be added as the code to be compiled, which generates the
+    ///   Cooperative Groups API bridge.
+    ///
+    /// Docs: https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#command-option-description
+    pub fn nvcc_flag(mut self, val: impl AsRef<str>) -> Self {
+        self.nvcc_flags.push(val.as_ref().to_string());
+        self
+    }
+
+    /// Compile the Cooperative Groups API bridging code, and then link it together
+    /// with any other artifacts provided to this builder.
+    ///
+    /// - `cg_out` specifies the output location for the Cooperative Groups API bridge PTX.
+    /// - `cubin_out` specifies the output location for the fully linked `cubin`.
+    ///
+    /// ## Errors
+    /// - At least one artifact must be provided to this builder for linking.
+    /// - Any errors which take place from the `nvcc` compilation of the Cooperative Groups bridging
+    ///   code, or any errors which take place during module linking.
+    pub fn compile(
+        mut self,
+        cg_out: impl AsRef<Path>,
+        cubin_out: impl AsRef<Path>,
+    ) -> CudaBuilderResult<()> {
+        // Perform some initial validation.
+        if self.artifacts.is_empty() {
+            return Err(anyhow::anyhow!("must provide at least 1 ptx/cubin/fatbin artifact to be linked with the Cooperative Groups API bridge code").into());
+        }
+
+        // Find the cuda installation directory for compilation of CG API.
+        let cuda_root =
+            find_cuda_helper::find_cuda_root().ok_or(CudaBuilderError::CudaRootNotFound)?;
+        let cuda_include = cuda_root.join("include");
+        let cg_src = std::path::Path::new(std::file!())
+            .parent()
+            .context("error accessing parent dir cuda_builder/src")?
+            .parent()
+            .context("error accessing parent dir cuda_builder")?
+            .join("cg")
+            .canonicalize()
+            .context("error taking canonical path to cooperative groups API bridge code")?;
+        let cg_bridge_cu = cg_src.join("cg_bridge.cu");
+
+        // Build up the `nvcc` invocation and then build the bridging code.
+        let mut nvcc = std::process::Command::new("nvcc");
+        nvcc.arg(format!("-I{:?}", &cuda_include).as_str())
+            .arg(format!("-I{:?}", &cg_src).as_str())
+            .arg("--ptx")
+            .arg("-o")
+            .arg(cg_out.as_ref().to_string_lossy().as_ref())
+            .arg("--device-c")
+            .arg(cg_bridge_cu.to_string_lossy().as_ref());
+        for flag in self.nvcc_flags.iter() {
+            nvcc.arg(flag.as_str());
+        }
+        nvcc.status()
+            .context("error calling nvcc for Cooperative Groups API bridge compilation")?;
+
+        // Link together the bridging code with any given PTX/cubin/fatbin artifacts.
+        let _ctx = cust::quick_init().context("error building cuda context")?;
+        let mut linker = cust::link::Linker::new().context("error building cust linker")?;
+        self.artifacts
+            .push(LinkableArtifact::Ptx(cg_out.as_ref().to_path_buf()));
+        for artifact in self.artifacts.iter() {
+            artifact.link_artifact(&mut linker)?;
+        }
+        let linked_cubin = linker
+            .complete()
+            .context("error linking artifacts with Cooperative Groups API bridge PTX")?;
+
+        // Write finalized cubin.
+        std::fs::write(&cubin_out, &linked_cubin)
+            .with_context(|| format!("error writing linked cubin to {:?}", cubin_out.as_ref()))?;
+
+        Ok(())
+    }
+}
diff --git a/crates/cuda_builder/src/lib.rs b/crates/cuda_builder/src/lib.rs
index e5b1e60..8552484 100644
--- a/crates/cuda_builder/src/lib.rs
+++ b/crates/cuda_builder/src/lib.rs
@@ -1,36 +1,37 @@
 //! Utility crate for easily building CUDA crates using rustc_codegen_nvvm. Derived from rust-gpu's spirv_builder.
 
+#[cfg(feature = "cooperative_groups")]
+pub mod cg;
+
 pub use nvvm::*;
 use serde::Deserialize;
 use std::{
     borrow::Borrow,
     env,
     ffi::OsString,
-    fmt,
     path::{Path, PathBuf},
     process::{Command, Stdio},
 };
 
-#[derive(Debug)]
+/// Cuda builder result type.
+pub type CudaBuilderResult<T> = Result<T, CudaBuilderError>;
+
+/// Cuda builder error type.
+#[derive(thiserror::Error, Debug)] #[non_exhaustive] pub enum CudaBuilderError { + #[error("crate path {0} does not exist")] CratePathDoesntExist(PathBuf), - FailedToCopyPtxFile(std::io::Error), + #[error("build failed")] BuildFailed, -} - -impl fmt::Display for CudaBuilderError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - CudaBuilderError::CratePathDoesntExist(path) => { - write!(f, "Crate path {} does not exist", path.display()) - } - CudaBuilderError::BuildFailed => f.write_str("Build failed"), - CudaBuilderError::FailedToCopyPtxFile(err) => { - f.write_str(&format!("Failed to copy PTX file: {:?}", err)) - } - } - } + #[error("failed to copy PTX file: {0:?}")] + FailedToCopyPtxFile(#[from] std::io::Error), + #[cfg(feature = "cooperative_groups")] + #[error("could not find cuda root installation dir")] + CudaRootNotFound, + #[cfg(feature = "cooperative_groups")] + #[error("compilation of the Cooperative Groups API bridge code failed: {0}")] + CGError(#[from] anyhow::Error), } #[derive(Debug, Clone, Copy, PartialEq)] diff --git a/crates/cuda_std/Cargo.toml b/crates/cuda_std/Cargo.toml index 0a25f7e..ec6893e 100644 --- a/crates/cuda_std/Cargo.toml +++ b/crates/cuda_std/Cargo.toml @@ -13,3 +13,10 @@ cuda_std_macros = { version = "0.2", path = "../cuda_std_macros" } half = "1.7.1" bitflags = "1.3.2" paste = "1.0.5" + +[features] +default = [] +cooperative_groups = [] + +[package.metadata.docs.rs] +all-features = true diff --git a/crates/cuda_std/src/cg.rs b/crates/cuda_std/src/cg.rs index 5265798..cb01e21 100644 --- a/crates/cuda_std/src/cg.rs +++ b/crates/cuda_std/src/cg.rs @@ -1,3 +1,5 @@ +//! Cuda Cooperative Groups API interface. + use crate::gpu_only; mod ffi { diff --git a/crates/cuda_std/src/lib.rs b/crates/cuda_std/src/lib.rs index a2f38d4..e8c43a0 100644 --- a/crates/cuda_std/src/lib.rs +++ b/crates/cuda_std/src/lib.rs @@ -46,6 +46,7 @@ pub mod misc; // pub mod rt; pub mod atomic; pub mod cfg; +#[cfg(feature = "cooperative_groups")] pub mod cg; pub mod ptr; pub mod shared; diff --git a/crates/cuda_std_cg/cg_bridge.ptx b/crates/cuda_std_cg/cg_bridge.ptx deleted file mode 100644 index dd84c8c..0000000 --- a/crates/cuda_std_cg/cg_bridge.ptx +++ /dev/null @@ -1,304 +0,0 @@ -// -// Generated by NVIDIA NVVM Compiler -// -// Compiler Build ID: CL-31442593 -// Cuda compilation tools, release 11.7, V11.7.99 -// Based on NVVM 7.0.1 -// - -.version 7.7 -.target sm_75 -.address_size 64 - - // .globl this_grid -.extern .func (.param .b64 func_retval0) malloc -( - .param .b64 malloc_param_0 -) -; -.extern .func free -( - .param .b64 free_param_0 -) -; -.weak .global .align 4 .b8 _ZZN4cuda3std3__48__detail21__stronger_order_cudaEiiE7__xform[16] = {3, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 3, 0, 0, 0}; - -.visible .func (.param .b64 func_retval0) this_grid() -{ - .reg .pred %p<2>; - .reg .b32 %r<3>; - .reg .b64 %rd<9>; - - - // begin inline asm - mov.u32 %r1, %envreg2; - // end inline asm - cvt.u64.u32 %rd5, %r1; - // begin inline asm - mov.u32 %r2, %envreg1; - // end inline asm - cvt.u64.u32 %rd6, %r2; - bfi.b64 %rd1, %rd6, %rd5, 32, 32; - mov.u64 %rd7, 16; - { // callseq 0, 0 - .reg .b32 temp_param_reg; - .param .b64 param0; - st.param.b64 [param0+0], %rd7; - .param .b64 retval0; - call.uni (retval0), - malloc, - ( - param0 - ); - ld.param.b64 %rd2, [retval0+0]; - } // callseq 0 - setp.eq.s64 %p1, %rd2, 0; - mov.u64 %rd8, 0; - @%p1 bra $L__BB0_2; - - st.u64 [%rd2], %rd1; - mov.u64 %rd8, %rd2; - -$L__BB0_2: - st.param.b64 [func_retval0+0], %rd8; - ret; - -} - // 
.globl GridGroup_destroy -.visible .func GridGroup_destroy( - .param .b64 GridGroup_destroy_param_0 -) -{ - .reg .b64 %rd<2>; - - - ld.param.u64 %rd1, [GridGroup_destroy_param_0]; - { // callseq 1, 0 - .reg .b32 temp_param_reg; - .param .b64 param0; - st.param.b64 [param0+0], %rd1; - call.uni - free, - ( - param0 - ); - } // callseq 1 - ret; - -} - // .globl GridGroup_is_valid -.visible .func (.param .b32 func_retval0) GridGroup_is_valid( - .param .b64 GridGroup_is_valid_param_0 -) -{ - .reg .pred %p<2>; - .reg .b32 %r<2>; - .reg .b64 %rd<3>; - - - ld.param.u64 %rd1, [GridGroup_is_valid_param_0]; - ld.u64 %rd2, [%rd1]; - setp.ne.s64 %p1, %rd2, 0; - selp.u32 %r1, 1, 0, %p1; - st.param.b32 [func_retval0+0], %r1; - ret; - -} - // .globl GridGroup_sync -.visible .func GridGroup_sync( - .param .b64 GridGroup_sync_param_0 -) -{ - .reg .pred %p<5>; - .reg .b32 %r<24>; - .reg .b64 %rd<9>; - - - ld.param.u64 %rd1, [GridGroup_sync_param_0]; - ld.u64 %rd8, [%rd1]; - setp.ne.s64 %p1, %rd8, 0; - @%p1 bra $L__BB3_2; - - // begin inline asm - trap; - // end inline asm - ld.u64 %rd8, [%rd1]; - -$L__BB3_2: - mov.u32 %r2, %tid.y; - mov.u32 %r3, %tid.x; - add.s32 %r4, %r3, %r2; - mov.u32 %r5, %tid.z; - neg.s32 %r6, %r5; - setp.ne.s32 %p2, %r4, %r6; - bar.sync 0; - @%p2 bra $L__BB3_6; - - add.s64 %rd6, %rd8, 4; - mov.u32 %r9, %ctaid.z; - neg.s32 %r10, %r9; - mov.u32 %r11, %ctaid.x; - mov.u32 %r12, %ctaid.y; - add.s32 %r13, %r11, %r12; - setp.eq.s32 %p3, %r13, %r10; - mov.u32 %r14, %nctaid.z; - mov.u32 %r15, %nctaid.x; - mov.u32 %r16, %nctaid.y; - mul.lo.s32 %r17, %r15, %r16; - mul.lo.s32 %r18, %r17, %r14; - mov.u32 %r19, -2147483647; - sub.s32 %r20, %r19, %r18; - selp.b32 %r8, %r20, 1, %p3; - membar.gl; - // begin inline asm - atom.add.release.gpu.u32 %r7,[%rd6],%r8; - // end inline asm - -$L__BB3_4: - ld.volatile.u32 %r21, [%rd6]; - xor.b32 %r22, %r21, %r7; - setp.gt.s32 %p4, %r22, -1; - @%p4 bra $L__BB3_4; - - // begin inline asm - ld.acquire.gpu.u32 %r23,[%rd6]; - // end inline asm - -$L__BB3_6: - bar.sync 0; - ret; - -} - // .globl GridGroup_size -.visible .func (.param .b64 func_retval0) GridGroup_size( - .param .b64 GridGroup_size_param_0 -) -{ - .reg .b32 %r<10>; - .reg .b64 %rd<4>; - - - mov.u32 %r1, %nctaid.x; - mov.u32 %r2, %nctaid.y; - mov.u32 %r3, %nctaid.z; - mul.lo.s32 %r4, %r2, %r3; - mul.wide.u32 %rd1, %r4, %r1; - mov.u32 %r5, %ntid.x; - mov.u32 %r6, %ntid.y; - mul.lo.s32 %r7, %r5, %r6; - mov.u32 %r8, %ntid.z; - mul.lo.s32 %r9, %r7, %r8; - cvt.u64.u32 %rd2, %r9; - mul.lo.s64 %rd3, %rd1, %rd2; - st.param.b64 [func_retval0+0], %rd3; - ret; - -} - // .globl GridGroup_thread_rank -.visible .func (.param .b64 func_retval0) GridGroup_thread_rank( - .param .b64 GridGroup_thread_rank_param_0 -) -{ - .reg .b32 %r<16>; - .reg .b64 %rd<12>; - - - mov.u32 %r1, %ctaid.x; - mov.u32 %r2, %ctaid.y; - mov.u32 %r3, %ctaid.z; - mov.u32 %r4, %nctaid.x; - mov.u32 %r5, %nctaid.y; - mul.wide.u32 %rd1, %r5, %r3; - cvt.u64.u32 %rd2, %r4; - cvt.u64.u32 %rd3, %r2; - add.s64 %rd4, %rd1, %rd3; - mul.lo.s64 %rd5, %rd4, %rd2; - cvt.u64.u32 %rd6, %r1; - add.s64 %rd7, %rd5, %rd6; - mov.u32 %r6, %ntid.x; - mov.u32 %r7, %ntid.y; - mul.lo.s32 %r8, %r6, %r7; - mov.u32 %r9, %ntid.z; - mul.lo.s32 %r10, %r8, %r9; - cvt.u64.u32 %rd8, %r10; - mul.lo.s64 %rd9, %rd7, %rd8; - mov.u32 %r11, %tid.x; - mov.u32 %r12, %tid.y; - mov.u32 %r13, %tid.z; - mad.lo.s32 %r14, %r7, %r13, %r12; - mad.lo.s32 %r15, %r14, %r6, %r11; - cvt.u64.u32 %rd10, %r15; - add.s64 %rd11, %rd9, %rd10; - st.param.b64 [func_retval0+0], %rd11; - ret; - -} - // .globl 
GridGroup_num_threads -.visible .func (.param .b64 func_retval0) GridGroup_num_threads( - .param .b64 GridGroup_num_threads_param_0 -) -{ - .reg .b32 %r<10>; - .reg .b64 %rd<4>; - - - mov.u32 %r1, %nctaid.x; - mov.u32 %r2, %nctaid.y; - mov.u32 %r3, %nctaid.z; - mul.lo.s32 %r4, %r2, %r3; - mul.wide.u32 %rd1, %r4, %r1; - mov.u32 %r5, %ntid.x; - mov.u32 %r6, %ntid.y; - mul.lo.s32 %r7, %r5, %r6; - mov.u32 %r8, %ntid.z; - mul.lo.s32 %r9, %r7, %r8; - cvt.u64.u32 %rd2, %r9; - mul.lo.s64 %rd3, %rd1, %rd2; - st.param.b64 [func_retval0+0], %rd3; - ret; - -} - // .globl GridGroup_num_blocks -.visible .func (.param .b64 func_retval0) GridGroup_num_blocks( - .param .b64 GridGroup_num_blocks_param_0 -) -{ - .reg .b32 %r<5>; - .reg .b64 %rd<2>; - - - mov.u32 %r1, %nctaid.x; - mov.u32 %r2, %nctaid.y; - mov.u32 %r3, %nctaid.z; - mul.lo.s32 %r4, %r2, %r3; - mul.wide.u32 %rd1, %r4, %r1; - st.param.b64 [func_retval0+0], %rd1; - ret; - -} - // .globl GridGroup_block_rank -.visible .func (.param .b64 func_retval0) GridGroup_block_rank( - .param .b64 GridGroup_block_rank_param_0 -) -{ - .reg .b32 %r<6>; - .reg .b64 %rd<8>; - - - mov.u32 %r1, %ctaid.x; - mov.u32 %r2, %ctaid.y; - mov.u32 %r3, %ctaid.z; - mov.u32 %r4, %nctaid.x; - mov.u32 %r5, %nctaid.y; - mul.wide.u32 %rd1, %r5, %r3; - cvt.u64.u32 %rd2, %r4; - cvt.u64.u32 %rd3, %r2; - add.s64 %rd4, %rd1, %rd3; - mul.lo.s64 %rd5, %rd4, %rd2; - cvt.u64.u32 %rd6, %r1; - add.s64 %rd7, %rd5, %rd6; - st.param.b64 [func_retval0+0], %rd7; - ret; - -} - diff --git a/crates/cust/src/error.rs b/crates/cust/src/error.rs index 8857c5b..9e7224a 100644 --- a/crates/cust/src/error.rs +++ b/crates/cust/src/error.rs @@ -215,10 +215,7 @@ impl ToResult for cudaError_enum { } cudaError_enum::CUDA_ERROR_NOT_PERMITTED => Err(CudaError::NotPermitted), cudaError_enum::CUDA_ERROR_NOT_SUPPORTED => Err(CudaError::NotSupported), - err => { - println!("error encountered: {:?}", err); - Err(CudaError::UnknownError) - } + err => Err(CudaError::UnknownError), } } } diff --git a/crates/cust/src/function.rs b/crates/cust/src/function.rs index 3f6faf2..07c06ca 100644 --- a/crates/cust/src/function.rs +++ b/crates/cust/src/function.rs @@ -545,3 +545,38 @@ macro_rules! launch { } }; } + +/// Launch a cooperative kernel function asynchronously. +/// +/// This macro is the same as `launch!`, except that it will launch kernels using the driver API +/// `cuLaunchCooperativeKernel` function. +#[macro_export] +macro_rules! launch_cooperative { + ($module:ident . 
$function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* $(,)?)) => {
+        {
+            let function = $module.get_function(stringify!($function));
+            match function {
+                Ok(f) => launch_cooperative!(f<<<$grid, $block, $shared, $stream>>>( $($arg),* ) ),
+                Err(e) => Err(e),
+            }
+        }
+    };
+    ($function:ident <<<$grid:expr, $block:expr, $shared:expr, $stream:ident>>>( $( $arg:expr),* $(,)?)) => {
+        {
+            fn assert_impl_devicecopy<T: $crate::memory::DeviceCopy>(_val: T) {}
+            if false {
+                $(
+                    assert_impl_devicecopy($arg);
+                )*
+            };
+
+            $stream.launch_cooperative(&$function, $grid, $block, $shared,
+                &[
+                    $(
+                        &$arg as *const _ as *mut ::std::ffi::c_void,
+                    )*
+                ]
+            )
+        }
+    };
+}
diff --git a/crates/cust/src/stream.rs b/crates/cust/src/stream.rs
index 6c83703..aedb4a3 100644
--- a/crates/cust/src/stream.rs
+++ b/crates/cust/src/stream.rs
@@ -262,49 +262,38 @@ impl Stream {
         let grid_size: GridSize = grid_size.into();
         let block_size: BlockSize = block_size.into();
 
-        // cuda::cuLaunchKernel(
-        //     f: CUfunction,
-        //     gridDimX: ::std::os::raw::c_uint,
-        //     gridDimY: ::std::os::raw::c_uint,
-        //     gridDimZ: ::std::os::raw::c_uint,
-        //     blockDimX: ::std::os::raw::c_uint,
-        //     blockDimY: ::std::os::raw::c_uint,
-        //     blockDimZ: ::std::os::raw::c_uint,
-        //     sharedMemBytes: ::std::os::raw::c_uint,
-        //     hStream: CUstream,
-        //     kernelParams: *mut *mut ::std::os::raw::c_void,
-        //     extra: *mut *mut ::std::os::raw::c_void,
-        // ).to_result();
-
-        // cuda::cuLaunchCooperativeKernel(
-        //     f: CUfunction,
-        //     gridDimX: ::std::os::raw::c_uint,
-        //     gridDimY: ::std::os::raw::c_uint,
-        //     gridDimZ: ::std::os::raw::c_uint,
-        //     blockDimX: ::std::os::raw::c_uint,
-        //     blockDimY: ::std::os::raw::c_uint,
-        //     blockDimZ: ::std::os::raw::c_uint,
-        //     sharedMemBytes: ::std::os::raw::c_uint,
-        //     hStream: CUstream,
-        //     kernelParams: *mut *mut ::std::os::raw::c_void,
-        // ).to_result();
-
-        // cuda::cuLaunchKernel(
-        //     func.to_raw(),
-        //     grid_size.x,
-        //     grid_size.y,
-        //     grid_size.z,
-        //     block_size.x,
-        //     block_size.y,
-        //     block_size.z,
-        //     shared_mem_bytes,
-        //     self.inner,
-        //     args.as_ptr() as *mut _,
-        //     ptr::null_mut(),
-        // )
-        // .to_result()
+        cuda::cuLaunchKernel(
+            func.to_raw(),
+            grid_size.x,
+            grid_size.y,
+            grid_size.z,
+            block_size.x,
+            block_size.y,
+            block_size.z,
+            shared_mem_bytes,
+            self.inner,
+            args.as_ptr() as *mut _,
+            ptr::null_mut(),
+        )
+        .to_result()
+    }
 
-        // TODO: make this configurable based on invocation patterns. For now, just testing.
+    // Hidden implementation detail function. Highly unsafe. Use the `launch_cooperative!` macro instead.
+    #[doc(hidden)]
+    pub unsafe fn launch_cooperative<G, B>(
+        &self,
+        func: &Function,
+        grid_size: G,
+        block_size: B,
+        shared_mem_bytes: u32,
+        args: &[*mut c_void],
+    ) -> CudaResult<()>
+    where
+        G: Into<GridSize>,
+        B: Into<BlockSize>,
+    {
+        let grid_size: GridSize = grid_size.into();
+        let block_size: BlockSize = block_size.into();
 
         cuda::cuLaunchCooperativeKernel(
             func.to_raw(),
diff --git a/crates/nvvm/src/lib.rs b/crates/nvvm/src/lib.rs
index e8bae63..c92c668 100644
--- a/crates/nvvm/src/lib.rs
+++ b/crates/nvvm/src/lib.rs
@@ -254,6 +254,8 @@ impl FromStr for NvvmOption {
                 "72" => NvvmArch::Compute72,
                 "75" => NvvmArch::Compute75,
                 "80" => NvvmArch::Compute80,
+                "86" => NvvmArch::Compute86,
+                "87" => NvvmArch::Compute87,
                 _ => return Err("unknown arch"),
            };
            Self::Arch(arch)
@@ -278,6 +280,8 @@ pub enum NvvmArch {
     Compute72,
     Compute75,
     Compute80,
+    Compute86,
+    Compute87,
 }
 
 impl Display for NvvmArch {
@@ -432,6 +436,8 @@ mod tests {
             "-arch=compute_72",
             "-arch=compute_75",
             "-arch=compute_80",
+            "-arch=compute_86",
+            "-arch=compute_87",
             "-ftz=1",
             "-prec-sqrt=0",
             "-prec-div=0",
@@ -453,6 +459,8 @@ mod tests {
             Arch(Compute72),
             Arch(Compute75),
             Arch(Compute80),
+            Arch(Compute86),
+            Arch(Compute87),
             Ftz,
             FastSqrt,
             FastDiv,
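
Usage sketch (editor's note, not part of the diff): the two snippets below show how the pieces added here are intended to fit together. The crate path `../kernels`, the kernel name `my_kernel`, and all output paths are hypothetical; only `CooperativeGroups`, `LinkableArtifact`, and `launch_cooperative!` come from this change, and the exact `cust` module-loading call should be checked against the crate.

```rust
// build.rs (sketch): build the GPU crate to PTX, then compile the Cooperative Groups
// bridge with nvcc and link everything into a single cubin. Requires the
// `cooperative_groups` feature of `cuda_builder`; note that `compile` initializes a
// CUDA context internally, so the build machine needs a working driver and device.
use cuda_builder::cg::{CooperativeGroups, LinkableArtifact};
use cuda_builder::CudaBuilder;

fn main() {
    // 1. Build the Rust GPU crate to PTX as usual.
    CudaBuilder::new("../kernels")
        .copy_to("../resources/kernels.ptx")
        .build()
        .unwrap();

    // 2. Compile the CG bridge and link it with the kernel PTX into a cubin.
    //    The arch flag should match the arch the kernel was built for.
    CooperativeGroups::new()
        .link(LinkableArtifact::Ptx("../resources/kernels.ptx".into()))
        .nvcc_flag("--gpu-architecture=sm_75")
        .compile("../resources/cg_bridge.ptx", "../resources/kernels.cubin")
        .unwrap();
}
```

On the host side, the linked cubin is loaded as a module (the builder docs above point at `Module::from_cubin`) and kernels are launched through the new macro, which routes the call to `cuLaunchCooperativeKernel`:

```rust
// Host side (sketch): launch a kernel from the already-loaded cubin module cooperatively.
// `my_kernel`, the launch shape, and the buffer size are placeholders.
use cust::error::CudaResult;
use cust::launch_cooperative;
use cust::memory::DeviceBuffer;
use cust::module::Module;
use cust::stream::Stream;

fn grid_wide_launch(module: &Module, stream: &Stream) -> CudaResult<()> {
    let mut data = DeviceBuffer::from_slice(&[0u32; 1024])?;
    unsafe {
        // Cooperative launches require the whole grid to be resident on the device at
        // once, so the grid size must stay within the device's cooperative launch limits.
        launch_cooperative!(module.my_kernel<<<8, 128, 0, stream>>>(
            data.as_device_ptr(),
            data.len()
        ))?;
    }
    stream.synchronize()
}
```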