diff --git a/prover/src/gpu/cuda/mod.rs b/prover/src/gpu/cuda/mod.rs index 8bdee9d8b..7e0b20b5f 100644 --- a/prover/src/gpu/cuda/mod.rs +++ b/prover/src/gpu/cuda/mod.rs @@ -1,6 +1,6 @@ //! This module contains GPU acceleration logic for Nvidia CUDA devices. -use std::marker::PhantomData; +use std::{cell::RefCell, marker::PhantomData, mem::MaybeUninit}; use air::{AuxRandElements, PartitionOptions}; use miden_gpu::{ @@ -32,25 +32,32 @@ const DIGEST_SIZE: usize = Rpo256::DIGEST_RANGE.end - Rpo256::DIGEST_RANGE.start // ================================================================================================ /// Wraps an [ExecutionProver] and provides GPU acceleration for building trace commitments. -pub(crate) struct CudaExecutionProver +pub(crate) struct CudaExecutionProver<'g, H, D, R> where H: Hasher + ElementHasher, D: Digest + From<[Felt; DIGEST_SIZE]> + Into<[Felt; DIGEST_SIZE]>, R: RandomCoin + Send, { + main: RefCell<&'g mut [MaybeUninit]>, + aux: RefCell<&'g mut [MaybeUninit]>, + ce: RefCell<&'g mut [MaybeUninit]>, + pub execution_prover: ExecutionProver, pub hash_fn: HashFn, phantom_data: PhantomData, } -impl CudaExecutionProver +impl<'g, H, D, R> CudaExecutionProver<'g, H, D, R> where H: Hasher + ElementHasher, D: Digest + From<[Felt; DIGEST_SIZE]> + Into<[Felt; DIGEST_SIZE]>, R: RandomCoin + Send, { - pub fn new(execution_prover: ExecutionProver, hash_fn: HashFn) -> Self { + pub fn new(execution_prover: ExecutionProver, hash_fn: HashFn, main: &'g mut [MaybeUninit], aux: &'g mut [MaybeUninit], ce: &'g mut [MaybeUninit]) -> Self { CudaExecutionProver { + main: RefCell::new(main), + aux: RefCell::new(aux), + ce: RefCell::new(ce), execution_prover, hash_fn, phantom_data: PhantomData, @@ -58,7 +65,7 @@ where } } -impl Prover for CudaExecutionProver +impl<'g, H, D, R> Prover for CudaExecutionProver<'g, H, D, R> where H: Hasher + ElementHasher + Send + Sync, D: Digest + From<[Felt; DIGEST_SIZE]> + Into<[Felt; DIGEST_SIZE]>, @@ -67,11 +74,11 @@ where type BaseField = Felt; type Air = ProcessorAir; type Trace = ExecutionTrace; - type VC = MerkleTree; + type VC = MerkleTree<'g, Self::HashFn>; type HashFn = H; type RandomCoin = R; - type TraceLde> = CudaTraceLde; - type ConstraintCommitment> = CudaConstraintCommitment; + type TraceLde> = CudaTraceLde<'g, E, H>; + type ConstraintCommitment> = CudaConstraintCommitment<'g, E, H>; type ConstraintEvaluator<'a, E: FieldElement> = DefaultConstraintEvaluator<'a, ProcessorAir, E>; @@ -90,7 +97,7 @@ where domain: &StarkDomain, partition_options: PartitionOptions, ) -> (Self::TraceLde, TracePolyTable) { - CudaTraceLde::new(trace_info, main_trace, domain, partition_options, self.hash_fn) + CudaTraceLde::new(self.main.take(), self.aux.take(), trace_info, main_trace, domain, partition_options, self.hash_fn) } fn new_evaluator<'a, E: FieldElement>( @@ -125,6 +132,7 @@ where E: FieldElement, { CudaConstraintCommitment::new( + self.ce.take(), composition_poly_trace, num_constraint_composition_columns, domain, diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 5520b1d94..efb8f1f97 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -8,7 +8,9 @@ extern crate std; use core::marker::PhantomData; -use air::{AuxRandElements, PartitionOptions, ProcessorAir, PublicInputs}; +use air::{trace::{AUX_TRACE_WIDTH, TRACE_WIDTH}, AuxRandElements, PartitionOptions, ProcessorAir, PublicInputs}; +#[cfg(all(target_arch = "x86_64", feature = "cuda"))] +use miden_gpu::cuda::util::{struct_size, CudaStorageOwned}; #[cfg(any( all(feature = "metal", target_arch = "aarch64", target_os = "macos"), all(feature = "cuda", target_arch = "x86_64") @@ -20,7 +22,7 @@ use processor::{ RpxRandomCoin, WinterRandomCoin, }, math::{Felt, FieldElement}, - ExecutionTrace, Program, + ExecutionTrace, Program, QuadExtension, }; use tracing::instrument; use winter_maybe_async::{maybe_async, maybe_await}; @@ -48,6 +50,37 @@ pub use winter_prover::{crypto::MerkleTree as MerkleTreeVC, Proof}; // PROVER // ================================================================================================ +#[cfg(all(feature = "cuda", target_arch = "x86_64"))] +#[instrument("allocate_memory", skip_all)] +fn allocate_memory(trace: &ExecutionTrace, options: &ProvingOptions) -> CudaStorageOwned { + use winter_prover::{math::fields::CubeExtension, Air}; + + let main_columns = TRACE_WIDTH; + let aux_columns = AUX_TRACE_WIDTH; + let rows = trace.get_trace_len(); + let options: WinterProofOptions = options.clone().into(); + let extension = options.field_extension(); + let blowup = options.blowup_factor(); + let partitions = options.partition_options(); + + let main = struct_size::(main_columns, rows, blowup, partitions); + let aux = match extension { + FieldExtension::None => struct_size::(aux_columns, rows, blowup, partitions), + FieldExtension::Quadratic => struct_size::>(aux_columns, rows, blowup, partitions), + FieldExtension::Cubic => struct_size::>(aux_columns, rows, blowup, partitions), + }; + + let air = ProcessorAir::new(trace.info().clone(), PublicInputs::new(Default::default(), Default::default(), Default::default()), options); + let ce_columns = air.context().num_constraint_composition_columns(); + let ce = match extension { + FieldExtension::None => struct_size::(ce_columns, rows, blowup, partitions), + FieldExtension::Quadratic => struct_size::>(ce_columns, rows, blowup, partitions), + FieldExtension::Cubic => struct_size::>(ce_columns, rows, blowup, partitions), + }; + + CudaStorageOwned::new(main, aux, ce) +} + /// Executes and proves the specified `program` and returns the result together with a STARK-based /// proof of the program's execution. /// @@ -84,6 +117,11 @@ pub fn prove( let stack_outputs = trace.stack_outputs().clone(); let hash_fn = options.hash_fn(); + #[cfg(all(feature = "cuda", target_arch = "x86_64"))] + let mut storage = allocate_memory(&trace, &options); + #[cfg(all(feature = "cuda", target_arch = "x86_64"))] + let (main, aux, ce) = storage.borrow_mut(); + // generate STARK proof let proof = match hash_fn { HashFunction::Blake3_192 => { @@ -111,7 +149,7 @@ pub fn prove( #[cfg(all(feature = "metal", target_arch = "aarch64", target_os = "macos"))] let prover = gpu::metal::MetalExecutionProver::new(prover, HashFn::Rpo256); #[cfg(all(feature = "cuda", target_arch = "x86_64"))] - let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpo256); + let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpo256, main, aux, ce); maybe_await!(prover.prove(trace)) }, HashFunction::Rpx256 => { @@ -123,7 +161,7 @@ pub fn prove( #[cfg(all(feature = "metal", target_arch = "aarch64", target_os = "macos"))] let prover = gpu::metal::MetalExecutionProver::new(prover, HashFn::Rpx256); #[cfg(all(feature = "cuda", target_arch = "x86_64"))] - let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpx256); + let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpx256, main, aux, ce); maybe_await!(prover.prove(trace)) }, }