Skip to content

Commit

Permalink
use an uninitialized global buffer for CUDA
Browse files Browse the repository at this point in the history
  • Loading branch information
gswirski committed Jan 7, 2025
1 parent 401f1a3 commit 98b3671
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 13 deletions.
26 changes: 17 additions & 9 deletions prover/src/gpu/cuda/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! This module contains GPU acceleration logic for Nvidia CUDA devices.
use std::marker::PhantomData;
use std::{cell::RefCell, marker::PhantomData, mem::MaybeUninit};

use air::{AuxRandElements, PartitionOptions};
use miden_gpu::{
Expand Down Expand Up @@ -32,33 +32,40 @@ const DIGEST_SIZE: usize = Rpo256::DIGEST_RANGE.end - Rpo256::DIGEST_RANGE.start
// ================================================================================================

/// Wraps an [ExecutionProver] and provides GPU acceleration for building trace commitments.
///
/// Besides the wrapped prover, the struct holds three caller-provided buffers of
/// `MaybeUninit<Felt>` borrowed for the lifetime `'g`. Each buffer sits in a `RefCell` so it can
/// be moved out through a shared reference with `RefCell::take`, which leaves the default (empty)
/// slice in its place.
pub(crate) struct CudaExecutionProver<H, D, R>
pub(crate) struct CudaExecutionProver<'g, H, D, R>
where
H: Hasher<Digest = D> + ElementHasher<BaseField = R::BaseField>,
D: Digest + From<[Felt; DIGEST_SIZE]> + Into<[Felt; DIGEST_SIZE]>,
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
{
// Taken and passed as the first buffer argument to `CudaTraceLde::new`.
main: RefCell<&'g mut [MaybeUninit<Felt>]>,
// Taken and passed as the second buffer argument to `CudaTraceLde::new`.
aux: RefCell<&'g mut [MaybeUninit<Felt>]>,
// Taken and passed as the buffer argument to `CudaConstraintCommitment::new`.
ce: RefCell<&'g mut [MaybeUninit<Felt>]>,

// The wrapped execution prover.
pub execution_prover: ExecutionProver<H, R>,
// Hash function selector; forwarded to `CudaTraceLde::new`.
pub hash_fn: HashFn,
// Ties the digest type `D` to the struct; no field stores a `D` directly.
phantom_data: PhantomData<D>,
}

impl<H, D, R> CudaExecutionProver<H, D, R>
impl<'g, H, D, R> CudaExecutionProver<'g, H, D, R>
where
H: Hasher<Digest = D> + ElementHasher<BaseField = R::BaseField>,
D: Digest + From<[Felt; DIGEST_SIZE]> + Into<[Felt; DIGEST_SIZE]>,
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
{
/// Builds a CUDA prover around `execution_prover` that uses `hash_fn`.
///
/// `main`, `aux` and `ce` are mutable, possibly uninitialized `Felt` buffers borrowed for
/// `'g`. Each one is wrapped in a `RefCell` so it can later be moved out through `&self`.
// NOTE(review): nothing here checks the buffer lengths; presumably the caller sizes them to
// match the trace and proof options — confirm against the call site.
pub fn new(execution_prover: ExecutionProver<H, R>, hash_fn: HashFn) -> Self {
pub fn new(execution_prover: ExecutionProver<H, R>, hash_fn: HashFn, main: &'g mut [MaybeUninit<Felt>], aux: &'g mut [MaybeUninit<Felt>], ce: &'g mut [MaybeUninit<Felt>]) -> Self {
CudaExecutionProver {
main: RefCell::new(main),
aux: RefCell::new(aux),
ce: RefCell::new(ce),
execution_prover,
hash_fn,
phantom_data: PhantomData,
}
}
}

impl<H, D, R> Prover for CudaExecutionProver<H, D, R>
impl<'g, H, D, R> Prover for CudaExecutionProver<'g, H, D, R>
where
H: Hasher<Digest = D> + ElementHasher<BaseField = R::BaseField> + Send + Sync,
D: Digest + From<[Felt; DIGEST_SIZE]> + Into<[Felt; DIGEST_SIZE]>,
Expand All @@ -67,11 +74,11 @@ where
type BaseField = Felt;
type Air = ProcessorAir;
type Trace = ExecutionTrace;
type VC = MerkleTree<Self::HashFn>;
type VC = MerkleTree<'g, Self::HashFn>;
type HashFn = H;
type RandomCoin = R;
type TraceLde<E: FieldElement<BaseField = Felt>> = CudaTraceLde<E, H>;
type ConstraintCommitment<E: FieldElement<BaseField = Felt>> = CudaConstraintCommitment<E, H>;
type TraceLde<E: FieldElement<BaseField = Felt>> = CudaTraceLde<'g, E, H>;
type ConstraintCommitment<E: FieldElement<BaseField = Felt>> = CudaConstraintCommitment<'g, E, H>;
type ConstraintEvaluator<'a, E: FieldElement<BaseField = Felt>> =
DefaultConstraintEvaluator<'a, ProcessorAir, E>;

Expand All @@ -90,7 +97,7 @@ where
domain: &StarkDomain<Felt>,
partition_options: PartitionOptions,
) -> (Self::TraceLde<E>, TracePolyTable<E>) {
CudaTraceLde::new(trace_info, main_trace, domain, partition_options, self.hash_fn)
CudaTraceLde::new(self.main.take(), self.aux.take(), trace_info, main_trace, domain, partition_options, self.hash_fn)
}

fn new_evaluator<'a, E: FieldElement<BaseField = Felt>>(
Expand Down Expand Up @@ -125,6 +132,7 @@ where
E: FieldElement<BaseField = Self::BaseField>,
{
CudaConstraintCommitment::new(
self.ce.take(),
composition_poly_trace,
num_constraint_composition_columns,
domain,
Expand Down
46 changes: 42 additions & 4 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ extern crate std;

use core::marker::PhantomData;

use air::{AuxRandElements, PartitionOptions, ProcessorAir, PublicInputs};
use air::{trace::{AUX_TRACE_WIDTH, TRACE_WIDTH}, AuxRandElements, PartitionOptions, ProcessorAir, PublicInputs};
#[cfg(all(target_arch = "x86_64", feature = "cuda"))]
use miden_gpu::cuda::util::{struct_size, CudaStorageOwned};
#[cfg(any(
all(feature = "metal", target_arch = "aarch64", target_os = "macos"),
all(feature = "cuda", target_arch = "x86_64")
Expand All @@ -20,7 +22,7 @@ use processor::{
RpxRandomCoin, WinterRandomCoin,
},
math::{Felt, FieldElement},
ExecutionTrace, Program,
ExecutionTrace, Program, QuadExtension,
};
use tracing::instrument;
use winter_maybe_async::{maybe_async, maybe_await};
Expand Down Expand Up @@ -48,6 +50,37 @@ pub use winter_prover::{crypto::MerkleTree as MerkleTreeVC, Proof};
// PROVER
// ================================================================================================

#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
#[instrument("allocate_memory", skip_all)]
/// Creates the [CudaStorageOwned] backing the main-trace, auxiliary-trace and
/// constraint-evaluation buffers for proving `trace` under `options`.
///
/// Every size comes from `struct_size`, given a column count, the trace length, the blowup
/// factor and the partition options. The main segment is always sized over `Felt`; the auxiliary
/// and constraint-evaluation segments are sized over the field selected by the proof options'
/// field extension.
fn allocate_memory(trace: &ExecutionTrace, options: &ProvingOptions) -> CudaStorageOwned {
    use winter_prover::{math::fields::CubeExtension, Air};

    let trace_len = trace.get_trace_len();
    let winter_options: WinterProofOptions = options.clone().into();
    let field_extension = winter_options.field_extension();
    let blowup_factor = winter_options.blowup_factor();
    let partition_options = winter_options.partition_options();

    // Size of a segment with `columns` columns whose elements live in the configured extension
    // field.
    let extension_segment_size = |columns: usize| match field_extension {
        FieldExtension::None => {
            struct_size::<Felt>(columns, trace_len, blowup_factor, partition_options)
        },
        FieldExtension::Quadratic => struct_size::<QuadExtension<Felt>>(
            columns,
            trace_len,
            blowup_factor,
            partition_options,
        ),
        FieldExtension::Cubic => struct_size::<CubeExtension<Felt>>(
            columns,
            trace_len,
            blowup_factor,
            partition_options,
        ),
    };

    let main_size = struct_size::<Felt>(TRACE_WIDTH, trace_len, blowup_factor, partition_options);
    let aux_size = extension_segment_size(AUX_TRACE_WIDTH);

    // The AIR is built with default public inputs solely to read the number of constraint
    // composition columns from its context.
    let placeholder_inputs =
        PublicInputs::new(Default::default(), Default::default(), Default::default());
    let air = ProcessorAir::new(trace.info().clone(), placeholder_inputs, winter_options);
    let ce_size = extension_segment_size(air.context().num_constraint_composition_columns());

    CudaStorageOwned::new(main_size, aux_size, ce_size)
}

/// Executes and proves the specified `program` and returns the result together with a STARK-based
/// proof of the program's execution.
///
Expand Down Expand Up @@ -84,6 +117,11 @@ pub fn prove(
let stack_outputs = trace.stack_outputs().clone();
let hash_fn = options.hash_fn();

#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
let mut storage = allocate_memory(&trace, &options);
#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
let (main, aux, ce) = storage.borrow_mut();

// generate STARK proof
let proof = match hash_fn {
HashFunction::Blake3_192 => {
Expand Down Expand Up @@ -111,7 +149,7 @@ pub fn prove(
#[cfg(all(feature = "metal", target_arch = "aarch64", target_os = "macos"))]
let prover = gpu::metal::MetalExecutionProver::new(prover, HashFn::Rpo256);
#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpo256);
let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpo256, main, aux, ce);
maybe_await!(prover.prove(trace))
},
HashFunction::Rpx256 => {
Expand All @@ -123,7 +161,7 @@ pub fn prove(
#[cfg(all(feature = "metal", target_arch = "aarch64", target_os = "macos"))]
let prover = gpu::metal::MetalExecutionProver::new(prover, HashFn::Rpx256);
#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpx256);
let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpx256, main, aux, ce);
maybe_await!(prover.prove(trace))
},
}
Expand Down

0 comments on commit 98b3671

Please sign in to comment.