-
Notifications
You must be signed in to change notification settings - Fork 123
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Cooperative Groups API integration
This works as follows:
- Users build their CUDA code via `CudaBuilder` as normal.
- If they want to use the Cooperative Groups API, then in their `build.rs`, just after building their PTX, they will:
  - Create a `cuda_builder::cg::CooperativeGroups` instance,
  - Add any needed options for building the Cooperative Groups API bridge code (`-arch=sm_*` and so on),
  - Add their newly built PTX code to be linked with the CG API, which can include multiple PTX, cubin or fatbin files,
  - Call `.compile(..)`, which will produce a fully linked `cubin`.
- In the user's main application code, instead of using `launch!` to schedule their GPU work, they will now use `launch_cooperative!`.
- Loading branch information
Showing
14 changed files
with
464 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[package] | ||
name = "cuda_builder" | ||
version = "0.3.0" | ||
version = "0.4.0" | ||
edition = "2021" | ||
authors = ["Riccardo D'Ambrosio <[email protected]>", "The rust-gpu Authors"] | ||
license = "MIT OR Apache-2.0" | ||
|
@@ -9,8 +9,16 @@ repository = "https://github.com/Rust-GPU/Rust-CUDA" | |
readme = "../../README.md" | ||
|
||
[dependencies] | ||
anyhow = "1" | ||
thiserror = "1" | ||
cc = { version = "1", default-features = false, optional = true } | ||
cust = { path = "../cust", optional = true } | ||
rustc_codegen_nvvm = { version = "0.3", path = "../rustc_codegen_nvvm" } | ||
nvvm = { path = "../nvvm", version = "0.1" } | ||
serde = { version = "1.0.130", features = ["derive"] } | ||
serde_json = "1.0.68" | ||
find_cuda_helper = { version = "0.2", path = "../find_cuda_helper" } | ||
|
||
[features] | ||
default = [] | ||
cooperative_groups = ["cc", "cust"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#include "cooperative_groups.h" | ||
#include "cg_bridge.cuh" | ||
namespace cg = cooperative_groups; | ||
|
||
// Obtains the calling thread's grid group via cg::this_grid() and returns it
// as an opaque GridGroup handle.
//
// The handle points to a GridGroupWrapper allocated with device-side `new`;
// GridGroup_destroy() is the matching function that `delete`s it.
__device__ GridGroup this_grid()
{
    const cg::grid_group grid = cg::this_grid();
    return new GridGroupWrapper { grid };
}
|
||
// Frees the GridGroupWrapper behind the opaque handle `gg`.
// The handle must not be used after this call.
__device__ void GridGroup_destroy(GridGroup gg)
{
    delete static_cast<GridGroupWrapper*>(gg);
}
|
||
// Returns the result of grid_group::is_valid() for the group wrapped by `gg`.
__device__ bool GridGroup_is_valid(GridGroup gg)
{
    auto* wrapper = static_cast<GridGroupWrapper*>(gg);
    const bool valid = wrapper->gg.is_valid();
    return valid;
}
|
||
// Calls grid_group::sync() on the group wrapped by `gg`.
// NOTE(review): grid-wide sync is only meaningful for cooperative launches —
// confirm callers launch the kernel cooperatively.
__device__ void GridGroup_sync(GridGroup gg)
{
    auto* wrapper = static_cast<GridGroupWrapper*>(gg);
    wrapper->gg.sync();
}
|
||
// Returns the result of grid_group::size() for the group wrapped by `gg`.
__device__ unsigned long long GridGroup_size(GridGroup gg)
{
    auto* wrapper = static_cast<GridGroupWrapper*>(gg);
    const unsigned long long count = wrapper->gg.size();
    return count;
}
|
||
// Returns the result of grid_group::thread_rank() for the group wrapped by `gg`.
__device__ unsigned long long GridGroup_thread_rank(GridGroup gg)
{
    auto* wrapper = static_cast<GridGroupWrapper*>(gg);
    const unsigned long long rank = wrapper->gg.thread_rank();
    return rank;
}
|
||
// Returns the result of grid_group::num_threads() for the group wrapped by `gg`.
__device__ unsigned long long GridGroup_num_threads(GridGroup gg)
{
    auto* wrapper = static_cast<GridGroupWrapper*>(gg);
    const unsigned long long threads = wrapper->gg.num_threads();
    return threads;
}
|
||
// Returns the result of grid_group::num_blocks() for the group wrapped by `gg`.
__device__ unsigned long long GridGroup_num_blocks(GridGroup gg)
{
    auto* wrapper = static_cast<GridGroupWrapper*>(gg);
    const unsigned long long blocks = wrapper->gg.num_blocks();
    return blocks;
}
|
||
// Returns the result of grid_group::block_rank() for the group wrapped by `gg`.
__device__ unsigned long long GridGroup_block_rank(GridGroup gg)
{
    auto* wrapper = static_cast<GridGroupWrapper*>(gg);
    const unsigned long long rank = wrapper->gg.block_rank();
    return rank;
}
|
||
// Empty host entry point; it performs no work.
// NOTE(review): presumably present only so the translation unit can be built
// as a standalone program — confirm whether it is still needed.
__host__ int main()
{
    return 0;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#pragma once | ||
#include "cooperative_groups.h" | ||
namespace cg = cooperative_groups; | ||
|
||
// Holder for a cooperative-groups grid_group so that it can be passed around
// behind an opaque pointer.
struct GridGroupWrapper {
    cg::grid_group gg;
};
|
||
// C-linkage surface of the Cooperative Groups bridge. Every function takes or
// returns an opaque GridGroup handle (a void*).
extern "C" {

// Opaque handle to a grid group.
typedef void* GridGroup;

// Handle lifecycle.
__device__ GridGroup this_grid();
__device__ void GridGroup_destroy(GridGroup gg);

// Queries and synchronization on the wrapped grid group.
__device__ bool GridGroup_is_valid(GridGroup gg);
__device__ void GridGroup_sync(GridGroup gg);
__device__ unsigned long long GridGroup_size(GridGroup gg);
__device__ unsigned long long GridGroup_thread_rank(GridGroup gg);
// extern "C" dim3 GridGroup_group_dim(); // TODO: impl these.
__device__ unsigned long long GridGroup_num_threads(GridGroup gg);
// extern "C" dim3 GridGroup_dim_blocks(); // TODO: impl these.
__device__ unsigned long long GridGroup_num_blocks(GridGroup gg);
// extern "C" dim3 GridGroup_block_index(); // TODO: impl these.
__device__ unsigned long long GridGroup_block_rank(GridGroup gg);

} // extern "C"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
//! Cooperative Groups compilation and linking. | ||
use std::path::{Path, PathBuf}; | ||
|
||
use anyhow::Context; | ||
|
||
use crate::{CudaBuilderError, CudaBuilderResult}; | ||
|
||
/// An artifact which may be linked together with the Cooperative Groups API bridge PTX code.
///
/// Each variant carries the filesystem path of the artifact; the variant itself
/// states which format the file at that path is in.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LinkableArtifact {
    /// A PTX artifact.
    Ptx(PathBuf),
    /// A cubin artifact.
    Cubin(PathBuf),
    /// A fatbin artifact.
    Fatbin(PathBuf),
}
|
||
impl LinkableArtifact { | ||
/// Add this artifact to the given linker. | ||
fn link_artifact(&self, linker: &mut cust::link::Linker) -> CudaBuilderResult<()> { | ||
match &self { | ||
LinkableArtifact::Ptx(path) => { | ||
let mut data = std::fs::read_to_string(&path).with_context(|| { | ||
format!("error reading PTX file for linking, file={:?}", path) | ||
})?; | ||
if !data.ends_with('\0') { | ||
// If the PTX is not null-terminated, then linking will fail. Only required for PTX. | ||
data.push('\0'); | ||
} | ||
linker | ||
.add_ptx(&data) | ||
.with_context(|| format!("error linking PTX file={:?}", path))?; | ||
} | ||
LinkableArtifact::Cubin(path) => { | ||
let data = std::fs::read(&path).with_context(|| { | ||
format!("error reading cubin file for linking, file={:?}", path) | ||
})?; | ||
linker | ||
.add_cubin(&data) | ||
.with_context(|| format!("error linking cubin file={:?}", path))?; | ||
} | ||
LinkableArtifact::Fatbin(path) => { | ||
let data = std::fs::read(&path).with_context(|| { | ||
format!("error reading fatbin file for linking, file={:?}", path) | ||
})?; | ||
linker | ||
.add_fatbin(&data) | ||
.with_context(|| format!("error linking fatbin file={:?}", path))?; | ||
} | ||
} | ||
Ok(()) | ||
} | ||
} | ||
|
||
/// A builder which will compile the Cooperative Groups API bridging code, and will then link it | ||
/// together with any other artifacts provided to this builder. | ||
/// | ||
/// The result of this process will be a `cubin` file containing the linked Cooperative Groups | ||
/// PTX code along with any other linked artifacts provided to this builder. The output `cubin` | ||
/// may then be loaded via `cust::module::Module::from_cubin(..)` and used as normal. | ||
#[derive(Default)] | ||
pub struct CooperativeGroups { | ||
/// Artifacts to be linked together with the Cooperative Groups bridge code. | ||
artifacts: Vec<LinkableArtifact>, | ||
/// Flags to pass to nvcc for Cooperative Groups API bridge compilation. | ||
nvcc_flags: Vec<String>, | ||
} | ||
|
||
impl CooperativeGroups { | ||
/// Construct a new instance. | ||
pub fn new() -> Self { | ||
Self::default() | ||
} | ||
|
||
/// Add the artifact at the given path for linking. | ||
/// | ||
/// This only applies to linking with the Cooperative Groups API bridge code. Typically, | ||
/// this will be the PTX of your main program which has already been built via `CudaBuilder`. | ||
pub fn link(mut self, artifact: LinkableArtifact) -> Self { | ||
self.artifacts.push(artifact); | ||
self | ||
} | ||
|
||
/// Add a flag to be passed along to `nvcc` during compilation of the Cooperative Groups API bridge code. | ||
/// | ||
/// This provides maximum flexibility for code generation. If needed, multiple architectures | ||
/// may be generated by adding the appropriate flags to the `nvcc` call. | ||
/// | ||
/// By default, `nvcc` will generate code for `sm_52`. Override by specifying any of `--gpu-architecture`, | ||
/// `--gpu-code`, or `--generate-code` flags. | ||
/// | ||
/// Regardless of the flags added via this method, this builder will always added the following flags: | ||
/// - `-I<cudaRoot>/include`: ensuring `cooperative_groups.h` can be found. | ||
/// - `-Icg`: ensuring the bridging header can be found. | ||
/// - `--ptx`: forces the compiled output to be in PTX form. | ||
/// - `--device-c`: to compile the bridging code as relocatable device code. | ||
/// - `src/cg_bridge.cu` will be added as the code to be compiled, which generates the | ||
/// Cooperative Groups API bridge. | ||
/// | ||
/// Docs: https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#command-option-description | ||
pub fn nvcc_flag(mut self, val: impl AsRef<str>) -> Self { | ||
self.nvcc_flags.push(val.as_ref().to_string()); | ||
self | ||
} | ||
|
||
/// Compile the Cooperative Groups API bridging code, and then link it together | ||
/// with any other artifacts provided to this builder. | ||
/// | ||
/// - `cg_out` specifies the output location for the Cooperative Groups API bridge PTX. | ||
/// - `cubin_out` specifies the output location for the fully linked `cubin`. | ||
/// | ||
/// ## Errors | ||
/// - At least one artifact must be provided to this builder for linking. | ||
/// - Any errors which take place from the `nvcc` compilation of the Cooperative Groups briding | ||
/// code, or any errors which take place during module linking. | ||
pub fn compile( | ||
mut self, | ||
cg_out: impl AsRef<Path>, | ||
cubin_out: impl AsRef<Path>, | ||
) -> CudaBuilderResult<()> { | ||
// Perform some initial validation. | ||
if self.artifacts.is_empty() { | ||
return Err(anyhow::anyhow!("must provide at least 1 ptx/cubin/fatbin artifact to be linked with the Cooperative Groups API bridge code").into()); | ||
} | ||
|
||
// Find the cuda installation directory for compilation of CG API. | ||
let cuda_root = | ||
find_cuda_helper::find_cuda_root().ok_or(CudaBuilderError::CudaRootNotFound)?; | ||
let cuda_include = cuda_root.join("include"); | ||
let cg_src = std::path::Path::new(std::file!()) | ||
.parent() | ||
.context("error accessing parent dir cuda_builder/src")? | ||
.parent() | ||
.context("error accessing parent dir cuda_builder")? | ||
.join("cg") | ||
.canonicalize() | ||
.context("error taking canonical path to cooperative groups API bridge code")?; | ||
let cg_bridge_cu = cg_src.join("cg_bridge.cu"); | ||
|
||
// Build up the `nvcc` invocation and then build the bridging code. | ||
let mut nvcc = std::process::Command::new("nvcc"); | ||
nvcc.arg(format!("-I{:?}", &cuda_include).as_str()) | ||
.arg(format!("-I{:?}", &cg_src).as_str()) | ||
.arg("--ptx") | ||
.arg("-o") | ||
.arg(cg_out.as_ref().to_string_lossy().as_ref()) | ||
.arg("--device-c") | ||
.arg(cg_bridge_cu.to_string_lossy().as_ref()); | ||
for flag in self.nvcc_flags.iter() { | ||
nvcc.arg(flag.as_str()); | ||
} | ||
nvcc.status() | ||
.context("error calling nvcc for Cooperative Groups API bridge compilation")?; | ||
|
||
// Link together the briding code with any given PTX/cubin/fatbin artifacts. | ||
let _ctx = cust::quick_init().context("error building cuda context")?; | ||
let mut linker = cust::link::Linker::new().context("error building cust linker")?; | ||
self.artifacts | ||
.push(LinkableArtifact::Ptx(cg_out.as_ref().to_path_buf())); | ||
for artifact in self.artifacts.iter() { | ||
artifact.link_artifact(&mut linker)?; | ||
} | ||
let linked_cubin = linker | ||
.complete() | ||
.context("error linking artifacts with Cooperative Groups API bridge PTX")?; | ||
|
||
// Write finalized cubin. | ||
std::fs::write(&cubin_out, &linked_cubin) | ||
.with_context(|| format!("error writing linked cubin to {:?}", cubin_out.as_ref()))?; | ||
|
||
Ok(()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.