Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Global CUDA buffer #1

Draft
wants to merge 1 commit into
base: gswirski/cuda
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions prover/src/gpu/cuda/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! This module contains GPU acceleration logic for Nvidia CUDA devices.
use std::marker::PhantomData;
use std::{cell::RefCell, marker::PhantomData, mem::MaybeUninit};

use air::{AuxRandElements, PartitionOptions};
use miden_gpu::{
Expand Down Expand Up @@ -32,33 +32,40 @@ const DIGEST_SIZE: usize = Rpo256::DIGEST_RANGE.end - Rpo256::DIGEST_RANGE.start
// ================================================================================================

/// Wraps an [ExecutionProver] and provides GPU acceleration for building trace commitments.
pub(crate) struct CudaExecutionProver<H, D, R>
pub(crate) struct CudaExecutionProver<'g, H, D, R>
where
H: Hasher<Digest = D> + ElementHasher<BaseField = R::BaseField>,
D: Digest + From<[Felt; DIGEST_SIZE]> + Into<[Felt; DIGEST_SIZE]>,
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
{
// The three buffers below are borrowed for `'g` and may be uninitialized. Each one is moved
// out of its `RefCell` with `take()`, which leaves an empty slice behind, so a buffer can be
// handed out only once per prover instance.

/// Buffer passed as the first argument to `CudaTraceLde::new` when the trace LDE is built.
main: RefCell<&'g mut [MaybeUninit<Felt>]>,
/// Buffer passed as the second argument to `CudaTraceLde::new` when the trace LDE is built.
aux: RefCell<&'g mut [MaybeUninit<Felt>]>,
/// Buffer passed as the first argument to `CudaConstraintCommitment::new`.
ce: RefCell<&'g mut [MaybeUninit<Felt>]>,

/// The wrapped prover; NOTE(review): presumably all non-GPU work is delegated to it — confirm.
pub execution_prover: ExecutionProver<H, R>,
/// Hash function selector forwarded to `CudaTraceLde::new`.
pub hash_fn: HashFn,
// Ties the otherwise unused digest type parameter `D` to the struct.
phantom_data: PhantomData<D>,
}

impl<H, D, R> CudaExecutionProver<H, D, R>
impl<'g, H, D, R> CudaExecutionProver<'g, H, D, R>
where
H: Hasher<Digest = D> + ElementHasher<BaseField = R::BaseField>,
D: Digest + From<[Felt; DIGEST_SIZE]> + Into<[Felt; DIGEST_SIZE]>,
R: RandomCoin<BaseField = Felt, Hasher = H> + Send,
{
/// Creates a prover that wraps `execution_prover` and stores `hash_fn` together with the
/// caller-provided `main`, `aux` and `ce` buffers.
///
/// The buffers are borrowed mutably for `'g`, so the storage they point into must outlive the
/// returned prover.
pub fn new(execution_prover: ExecutionProver<H, R>, hash_fn: HashFn) -> Self {
pub fn new(execution_prover: ExecutionProver<H, R>, hash_fn: HashFn, main: &'g mut [MaybeUninit<Felt>], aux: &'g mut [MaybeUninit<Felt>], ce: &'g mut [MaybeUninit<Felt>]) -> Self {
CudaExecutionProver {
// Each buffer goes into a `RefCell` so it can later be moved out with `take()`.
// NOTE(review): presumably because the consuming methods only get `&self` — confirm.
main: RefCell::new(main),
aux: RefCell::new(aux),
ce: RefCell::new(ce),
execution_prover,
hash_fn,
phantom_data: PhantomData,
}
}
}

impl<H, D, R> Prover for CudaExecutionProver<H, D, R>
impl<'g, H, D, R> Prover for CudaExecutionProver<'g, H, D, R>
where
H: Hasher<Digest = D> + ElementHasher<BaseField = R::BaseField> + Send + Sync,
D: Digest + From<[Felt; DIGEST_SIZE]> + Into<[Felt; DIGEST_SIZE]>,
Expand All @@ -67,11 +74,11 @@ where
type BaseField = Felt;
type Air = ProcessorAir;
type Trace = ExecutionTrace;
type VC = MerkleTree<Self::HashFn>;
type VC = MerkleTree<'g, Self::HashFn>;
type HashFn = H;
type RandomCoin = R;
type TraceLde<E: FieldElement<BaseField = Felt>> = CudaTraceLde<E, H>;
type ConstraintCommitment<E: FieldElement<BaseField = Felt>> = CudaConstraintCommitment<E, H>;
type TraceLde<E: FieldElement<BaseField = Felt>> = CudaTraceLde<'g, E, H>;
type ConstraintCommitment<E: FieldElement<BaseField = Felt>> = CudaConstraintCommitment<'g, E, H>;
type ConstraintEvaluator<'a, E: FieldElement<BaseField = Felt>> =
DefaultConstraintEvaluator<'a, ProcessorAir, E>;

Expand All @@ -90,7 +97,7 @@ where
domain: &StarkDomain<Felt>,
partition_options: PartitionOptions,
) -> (Self::TraceLde<E>, TracePolyTable<E>) {
CudaTraceLde::new(trace_info, main_trace, domain, partition_options, self.hash_fn)
CudaTraceLde::new(self.main.take(), self.aux.take(), trace_info, main_trace, domain, partition_options, self.hash_fn)
}

fn new_evaluator<'a, E: FieldElement<BaseField = Felt>>(
Expand Down Expand Up @@ -125,6 +132,7 @@ where
E: FieldElement<BaseField = Self::BaseField>,
{
CudaConstraintCommitment::new(
self.ce.take(),
composition_poly_trace,
num_constraint_composition_columns,
domain,
Expand Down
46 changes: 42 additions & 4 deletions prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ extern crate std;

use core::marker::PhantomData;

use air::{AuxRandElements, PartitionOptions, ProcessorAir, PublicInputs};
use air::{trace::{AUX_TRACE_WIDTH, TRACE_WIDTH}, AuxRandElements, PartitionOptions, ProcessorAir, PublicInputs};
#[cfg(all(target_arch = "x86_64", feature = "cuda"))]
use miden_gpu::cuda::util::{struct_size, CudaStorageOwned};
#[cfg(any(
all(feature = "metal", target_arch = "aarch64", target_os = "macos"),
all(feature = "cuda", target_arch = "x86_64")
Expand All @@ -20,7 +22,7 @@ use processor::{
RpxRandomCoin, WinterRandomCoin,
},
math::{Felt, FieldElement},
ExecutionTrace, Program,
ExecutionTrace, Program, QuadExtension,
};
use tracing::instrument;
use winter_maybe_async::{maybe_async, maybe_await};
Expand Down Expand Up @@ -48,6 +50,37 @@ pub use winter_prover::{crypto::MerkleTree as MerkleTreeVC, Proof};
// PROVER
// ================================================================================================

#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
#[instrument("allocate_memory", skip_all)]
/// Computes the sizes of the main-trace, auxiliary-trace and constraint-evaluation buffers for
/// `trace` under `options` and builds a [CudaStorageOwned] from them.
///
/// The main buffer is always sized for base-field elements. The auxiliary and constraint
/// buffers are sized for the element type selected by the field extension in `options`.
fn allocate_memory(trace: &ExecutionTrace, options: &ProvingOptions) -> CudaStorageOwned {
    use winter_prover::{math::fields::CubeExtension, Air};

    let winter_options: WinterProofOptions = options.clone().into();
    let field_extension = winter_options.field_extension();
    let blowup_factor = winter_options.blowup_factor();
    let partition_options = winter_options.partition_options();
    let trace_len = trace.get_trace_len();

    // Size of a `columns`-wide buffer whose elements live in the configured extension field.
    let extension_size = |columns: usize| match field_extension {
        FieldExtension::None => {
            struct_size::<Felt>(columns, trace_len, blowup_factor, partition_options)
        },
        FieldExtension::Quadratic => {
            struct_size::<QuadExtension<Felt>>(columns, trace_len, blowup_factor, partition_options)
        },
        FieldExtension::Cubic => {
            struct_size::<CubeExtension<Felt>>(columns, trace_len, blowup_factor, partition_options)
        },
    };

    let main_size = struct_size::<Felt>(TRACE_WIDTH, trace_len, blowup_factor, partition_options);
    let aux_size = extension_size(AUX_TRACE_WIDTH);

    // An AIR built with default public inputs is used only to read the number of constraint
    // composition columns.
    let placeholder_inputs =
        PublicInputs::new(Default::default(), Default::default(), Default::default());
    let air = ProcessorAir::new(trace.info().clone(), placeholder_inputs, winter_options);
    let ce_size = extension_size(air.context().num_constraint_composition_columns());

    CudaStorageOwned::new(main_size, aux_size, ce_size)
}

/// Executes and proves the specified `program` and returns the result together with a STARK-based
/// proof of the program's execution.
///
Expand Down Expand Up @@ -84,6 +117,11 @@ pub fn prove(
let stack_outputs = trace.stack_outputs().clone();
let hash_fn = options.hash_fn();

#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
let mut storage = allocate_memory(&trace, &options);
#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
let (main, aux, ce) = storage.borrow_mut();

// generate STARK proof
let proof = match hash_fn {
HashFunction::Blake3_192 => {
Expand Down Expand Up @@ -111,7 +149,7 @@ pub fn prove(
#[cfg(all(feature = "metal", target_arch = "aarch64", target_os = "macos"))]
let prover = gpu::metal::MetalExecutionProver::new(prover, HashFn::Rpo256);
#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpo256);
let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpo256, main, aux, ce);
maybe_await!(prover.prove(trace))
},
HashFunction::Rpx256 => {
Expand All @@ -123,7 +161,7 @@ pub fn prove(
#[cfg(all(feature = "metal", target_arch = "aarch64", target_os = "macos"))]
let prover = gpu::metal::MetalExecutionProver::new(prover, HashFn::Rpx256);
#[cfg(all(feature = "cuda", target_arch = "x86_64"))]
let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpx256);
let prover = gpu::cuda::CudaExecutionProver::new(prover, HashFn::Rpx256, main, aux, ce);
maybe_await!(prover.prove(trace))
},
}
Expand Down