Skip to content

Commit

Permalink
use a global buffer for all things cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
gswirski committed Nov 25, 2024
1 parent effb048 commit 849569d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 15 deletions.
27 changes: 17 additions & 10 deletions prover/src/gpu/cuda/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! This module contains GPU acceleration logic for Nvidia CUDA devices.
use std::marker::PhantomData;
use std::{cell::RefCell, marker::PhantomData};

use air::{AuxRandElements, PartitionOptions};
use miden_gpu::{cuda::{constraint::CudaConstraintCommitment, merkle::MerkleTree, trace_lde::CudaTraceLde}, HashFn};
Expand All @@ -27,33 +27,40 @@ const DIGEST_SIZE: usize = Rpo256::DIGEST_RANGE.end - Rpo256::DIGEST_RANGE.start
// ================================================================================================

/// Wraps an [ExecutionProver] and provides GPU acceleration for building trace commitments.
///
/// The `'g` lifetime is the borrow of three caller-provided `Felt` buffers (`main`, `aux`, `ce`).
/// Each buffer lives in a `RefCell` so that it can be moved out through a shared `&self` with
/// `RefCell::take`, which leaves an empty slice (the `Default` of `&mut [Felt]`) in its place.
pub(crate) struct CudaExecutionProver<H, D, R>
pub(crate) struct CudaExecutionProver<'g, H, D, R>
where
H: Hasher<Digest = D> + ElementHasher<BaseField = R::BaseField>,
D: Digest + From<[Felt; DIGEST_SIZE]> + Into<[Felt; DIGEST_SIZE]>,
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
{
/// Buffer handed to `CudaTraceLde::new` as its first argument.
main: RefCell<&'g mut [Felt]>,
/// Buffer handed to `CudaTraceLde::new` as its second argument.
aux: RefCell<&'g mut [Felt]>,
/// Buffer handed to `CudaConstraintCommitment::new` as its first argument.
ce: RefCell<&'g mut [Felt]>,

/// The wrapped prover.
pub execution_prover: ExecutionProver<H, R>,
/// Hash function selector forwarded to the CUDA trace LDE and constraint commitment builders.
pub hash_fn: HashFn,
// Marker for the digest type `D`, which appears only in the trait bounds.
phantom_data: PhantomData<D>,
}

impl<H, D, R> CudaExecutionProver<H, D, R>
impl<'g, H, D, R> CudaExecutionProver<'g, H, D, R>
where
H: Hasher<Digest = D> + ElementHasher<BaseField = R::BaseField>,
D: Digest + From<[Felt; DIGEST_SIZE]> + Into<[Felt; DIGEST_SIZE]>,
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
{
/// Builds a CUDA prover around `execution_prover` that uses `hash_fn`.
///
/// `main`, `aux` and `ce` are mutable `Felt` slices borrowed for `'g`; each one is stored in
/// its own `RefCell` so it can later be taken out through `&self`.
pub fn new(execution_prover: ExecutionProver<H, R>, hash_fn: HashFn) -> Self {
pub fn new(execution_prover: ExecutionProver<H, R>, hash_fn: HashFn, main: &'g mut [Felt], aux: &'g mut [Felt], ce: &'g mut [Felt]) -> Self {
CudaExecutionProver {
main: RefCell::new(main),
aux: RefCell::new(aux),
ce: RefCell::new(ce),
execution_prover,
hash_fn,
phantom_data: PhantomData,
}
}
}

impl<H, D, R> Prover for CudaExecutionProver<H, D, R>
impl<'g, H, D, R> Prover for CudaExecutionProver<'g, H, D, R>
where
H: Hasher<Digest = D> + ElementHasher<BaseField = R::BaseField> + Send + Sync,
D: Digest + From<[Felt; DIGEST_SIZE]> + Into<[Felt; DIGEST_SIZE]>,
Expand All @@ -62,11 +69,11 @@ where
type BaseField = Felt;
type Air = ProcessorAir;
type Trace = ExecutionTrace;
type VC = MerkleTree<Self::HashFn>;
type VC = MerkleTree<'g, Self::HashFn>;
type HashFn = H;
type RandomCoin = R;
type TraceLde<E: FieldElement<BaseField = Felt>> = CudaTraceLde<E, H>;
type ConstraintCommitment<E: FieldElement<BaseField = Felt>> = CudaConstraintCommitment<E, H>;
type TraceLde<E: FieldElement<BaseField = Felt>> = CudaTraceLde<'g, E, H>;
type ConstraintCommitment<E: FieldElement<BaseField = Felt>> = CudaConstraintCommitment<'g, E, H>;
type ConstraintEvaluator<'a, E: FieldElement<BaseField = Felt>> =
DefaultConstraintEvaluator<'a, ProcessorAir, E>;

Expand All @@ -85,7 +92,7 @@ where
domain: &StarkDomain<Felt>,
partition_options: PartitionOptions,
) -> (Self::TraceLde<E>, TracePolyTable<E>) {
CudaTraceLde::new(trace_info, main_trace, domain, partition_options, self.hash_fn)
CudaTraceLde::new(self.main.take(), self.aux.take(), trace_info, main_trace, domain, partition_options, self.hash_fn)
}

fn new_evaluator<'a, E: FieldElement<BaseField = Felt>>(
Expand Down Expand Up @@ -119,6 +126,6 @@ where
where
E: FieldElement<BaseField = Self::BaseField>,
{
CudaConstraintCommitment::new(composition_poly_trace, num_constraint_composition_columns, domain, partition_options, self.hash_fn)
CudaConstraintCommitment::new(self.ce.take(), composition_poly_trace, num_constraint_composition_columns, domain, partition_options, self.hash_fn)
}
}
45 changes: 40 additions & 5 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ extern crate alloc;

use core::marker::PhantomData;

use air::{AuxRandElements, PartitionOptions, ProcessorAir, PublicInputs};
use air::{trace::{AUX_TRACE_WIDTH, TRACE_WIDTH}, AuxRandElements, PartitionOptions, ProcessorAir, PublicInputs};
use miden_gpu::cuda::util::{struct_size, CudaStorageOwned};
#[cfg(any(
all(feature = "metal", target_arch = "aarch64", target_os = "macos"),
all(feature = "cuda", target_arch = "x86_64")
Expand All @@ -17,12 +18,12 @@ use processor::{
RpxRandomCoin, WinterRandomCoin,
},
math::{Felt, FieldElement},
ExecutionTrace, Program,
ExecutionTrace, Program, QuadExtension,
};
use tracing::instrument;
use winter_maybe_async::{maybe_async, maybe_await};
use winter_prover::{
matrix::{ColMatrix, RowMatrix}, CompositionPoly, CompositionPolyTrace, ConstraintCompositionCoefficients, DefaultConstraintCommitment, DefaultConstraintEvaluator, DefaultTraceLde, ProofOptions as WinterProofOptions, Prover, StarkDomain, TraceInfo, TracePolyTable
math::fields::CubeExtension, matrix::ColMatrix, Air, CompositionPoly, CompositionPolyTrace, ConstraintCompositionCoefficients, DefaultConstraintCommitment, DefaultConstraintEvaluator, DefaultTraceLde, ProofOptions as WinterProofOptions, Prover, StarkDomain, TraceInfo, TracePolyTable
};
#[cfg(feature = "std")]
use {std::time::Instant, winter_prover::Trace};
Expand All @@ -42,6 +43,35 @@ pub use miden_gpu::cuda::get_num_of_gpus;
// PROVER
// ================================================================================================

/// Computes the storage sizes needed for the main trace, the auxiliary trace and the constraint
/// evaluations of `trace` under `options`, and allocates a [CudaStorageOwned] of those sizes.
///
/// The main segment is always sized for `Felt` elements; the auxiliary and constraint segments
/// are sized for the element type selected by the configured field extension.
#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
#[instrument("allocate_memory", skip_all)]
fn allocate_memory(trace: &ExecutionTrace, options: &ProvingOptions) -> CudaStorageOwned {
    let trace_len = trace.get_trace_len();
    let winter_options: WinterProofOptions = options.clone().into();
    let field_extension = winter_options.field_extension();
    let blowup_factor = winter_options.blowup_factor();
    let partition_options = winter_options.partition_options();

    // Size of a segment with `num_columns` columns whose elements live in the configured
    // extension field.
    let extension_segment_size = |num_columns: usize| match field_extension {
        FieldExtension::None => {
            struct_size::<Felt>(num_columns, trace_len, blowup_factor, partition_options)
        },
        FieldExtension::Quadratic => struct_size::<QuadExtension<Felt>>(
            num_columns,
            trace_len,
            blowup_factor,
            partition_options,
        ),
        FieldExtension::Cubic => struct_size::<CubeExtension<Felt>>(
            num_columns,
            trace_len,
            blowup_factor,
            partition_options,
        ),
    };

    let main_size = struct_size::<Felt>(TRACE_WIDTH, trace_len, blowup_factor, partition_options);
    let aux_size = extension_segment_size(AUX_TRACE_WIDTH);

    // The AIR is instantiated with default public inputs only to read the number of constraint
    // composition columns from its context.
    let air = ProcessorAir::new(
        trace.info().clone(),
        PublicInputs::new(Default::default(), Default::default(), Default::default()),
        winter_options,
    );
    let ce_size = extension_segment_size(air.context().num_constraint_composition_columns());

    CudaStorageOwned::new(main_size, aux_size, ce_size)
}

/// Executes and proves the specified `program` and returns the result together with a STARK-based
/// proof of the program's execution.
///
Expand Down Expand Up @@ -81,6 +111,11 @@ where
let stack_outputs = trace.stack_outputs().clone();
let hash_fn = options.hash_fn();

#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
let mut storage = allocate_memory(&trace, &options);
#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
let (main, aux, ce) = storage.borrow_mut();

// generate STARK proof
let proof = match hash_fn {
HashFunction::Blake3_192 => {
Expand Down Expand Up @@ -108,7 +143,7 @@ where
#[cfg(all(feature = "metal", target_arch = "aarch64", target_os = "macos"))]
let prover = gpu::metal::MetalExecutionProver::new(prover, HashFn::Rpo256);
#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpo256);
let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpo256, main, aux, ce);
maybe_await!(prover.prove(trace))
},
HashFunction::Rpx256 => {
Expand All @@ -120,7 +155,7 @@ where
#[cfg(all(feature = "metal", target_arch = "aarch64", target_os = "macos"))]
let prover = gpu::metal::MetalExecutionProver::new(prover, HashFn::Rpx256);
#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpx256);
let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpx256, main, aux, ce);
maybe_await!(prover.prove(trace))
},
}
Expand Down

0 comments on commit 849569d

Please sign in to comment.