diff --git a/zenu-cuda/src/cudnn/batch_norm.rs b/zenu-cuda/src/cudnn/batch_norm.rs new file mode 100644 index 00000000..01a96328 --- /dev/null +++ b/zenu-cuda/src/cudnn/batch_norm.rs @@ -0,0 +1,151 @@ +use crate::ZENU_CUDA_STATE; + +use zenu_cudnn_sys::*; + +use super::{error::ZenuCudnnError, tensor_descriptor_4d, TensorFormat}; + +pub struct BatchNorm2d { + input: cudnnTensorDescriptor_t, + output: cudnnTensorDescriptor_t, + scale_bias_mean_var: cudnnTensorDescriptor_t, + mode: cudnnBatchNormMode_t, +} + +pub struct BatchNorm2dBuilder { + input: Option, + output: Option, + scale_bias_mean_var: Option, + mode: Option, +} + +impl BatchNorm2d { + pub fn forward_train( + &self, + alpha: T, + beta: T, + x: *const std::ffi::c_void, + y: *mut std::ffi::c_void, + scale: *const std::ffi::c_void, + bias: *const std::ffi::c_void, + estimated_mean: *mut std::ffi::c_void, + estimated_variance: *mut std::ffi::c_void, + expotential_average_factor: f64, + result_save_mean: *mut std::ffi::c_void, + result_save_inv_variance: *mut std::ffi::c_void, + ) -> Result<(), ZenuCudnnError> { + let cudnn_handle = ZENU_CUDA_STATE.lock().unwrap().get_cudnn().as_ptr(); + let status = unsafe { + cudnnBatchNormalizationForwardTraining( + cudnn_handle, + self.mode, + &alpha as *const T as *const std::ffi::c_void, + &beta as *const T as *const std::ffi::c_void, + self.input, + x, + self.output, + y, + self.scale_bias_mean_var, + scale as *const T as *const std::ffi::c_void, + bias, + expotential_average_factor, + estimated_mean, + estimated_variance, + 1e-10, + result_save_mean, + result_save_inv_variance, + ) + }; + if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS { + return Err(ZenuCudnnError::from(status)); + } + Ok(()) + } +} + +impl Drop for BatchNorm2d { + fn drop(&mut self) { + unsafe { + cudnnDestroyTensorDescriptor(self.input); + cudnnDestroyTensorDescriptor(self.output); + cudnnDestroyTensorDescriptor(self.scale_bias_mean_var); + } + } +} + +impl BatchNorm2dBuilder { + pub fn new() -> Self { + Self { + input: None, + output: None, + scale_bias_mean_var: None, + mode: None, + } + } + + pub fn input( + self, + n: i32, + c: i32, + h: i32, + w: i32, + format: TensorFormat, + ) -> Result { + let input = tensor_descriptor_4d::(n, c, h, w, format)?; + Ok(Self { + input: Some(input), + ..self + }) + } + + pub fn output( + self, + n: i32, + c: i32, + h: i32, + w: i32, + format: TensorFormat, + ) -> Result { + let output = tensor_descriptor_4d::(n, c, h, w, format)?; + Ok(Self { + output: Some(output), + ..self + }) + } + + pub fn scale_bias_mean_var( + self, + n: i32, + c: i32, + h: i32, + w: i32, + format: TensorFormat, + ) -> Result { + let scale_bias_mean_var = tensor_descriptor_4d::(n, c, h, w, format)?; + Ok(Self { + scale_bias_mean_var: Some(scale_bias_mean_var), + ..self + }) + } + + pub fn mode(self, mode: cudnnBatchNormMode_t) -> Self { + Self { + mode: Some(mode), + ..self + } + } + + pub fn build(self) -> Result { + let input = self.input.expect("input is required"); + let output = self.output.expect("output is required"); + let scale_bias_mean_var = self + .scale_bias_mean_var + .expect("scale_bias_mean_var is required"); + let mode = self.mode.expect("mode is required"); + Ok(BatchNorm2d { + input, + output, + scale_bias_mean_var, + mode, + }) + } +} diff --git a/zenu-cuda/src/cudnn/conv.rs b/zenu-cuda/src/cudnn/conv.rs new file mode 100644 index 00000000..94df62a2 --- /dev/null +++ b/zenu-cuda/src/cudnn/conv.rs @@ -0,0 +1,1052 @@ +use super::{error::ZenuCudnnError, tensor_descriptor_4d, zenu_cudnn_data_type, TensorFormat}; + +use crate::ZENU_CUDA_STATE; + +use std::cell::UnsafeCell; + +use zenu_cudnn_sys::*; + +fn filter_descriptor( + k: i32, + c: i32, + h: i32, + w: i32, + format: TensorFormat, +) -> Result { + let data_type = zenu_cudnn_data_type::(); + let format = format.into(); + let mut filter: cudnnFilterDescriptor_t = std::ptr::null_mut(); + unsafe { + let status = cudnnCreateFilterDescriptor(&mut filter as *mut cudnnFilterDescriptor_t); + if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS { + return Err(ZenuCudnnError::from(status)); + } + let status = cudnnSetFilter4dDescriptor(filter, data_type, format, k, c, h, w); + if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS { + return Err(ZenuCudnnError::from(status)); + } + } + Ok(filter) +} + +fn convolution_descriptor( + pad_h: i32, + pad_w: i32, + stride_h: i32, + stride_w: i32, + dilation_h: i32, + dilation_w: i32, +) -> Result { + let mut conv: cudnnConvolutionDescriptor_t = std::ptr::null_mut(); + unsafe { + let status = + cudnnCreateConvolutionDescriptor(&mut conv as *mut cudnnConvolutionDescriptor_t); + if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS { + return Err(ZenuCudnnError::from(status)); + } + let status = cudnnSetConvolution2dDescriptor( + conv, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION, + zenu_cudnn_data_type::(), + ); + if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS { + return Err(ZenuCudnnError::from(status)); + } + } + Ok(conv) +} + +fn convolution_algorithm( + input: cudnnTensorDescriptor_t, + filter: cudnnFilterDescriptor_t, + conv: cudnnConvolutionDescriptor_t, + output: cudnnTensorDescriptor_t, + requested_algo_count: usize, +) -> Result { + let state = ZENU_CUDA_STATE.lock().unwrap(); + let handle = state.get_cudnn(); + let mut returned_algo_count = 0; + unsafe { + let mut algorithm: Vec = + Vec::with_capacity(requested_algo_count); + for _ in 0..requested_algo_count { + algorithm.push(cudnnConvolutionFwdAlgoPerf_t::default()); + } + + // enable tensor core + cudnnSetConvolutionMathType(conv, cudnnMathType_t::CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION); + + let state = cudnnGetConvolutionForwardAlgorithm_v7( + handle.as_ptr(), + input, + filter, + conv, + output, + requested_algo_count as i32, + &mut returned_algo_count as *mut i32, + algorithm.as_mut_ptr() as *mut cudnnConvolutionFwdAlgoPerf_t, + ); + if state != cudnnStatus_t::CUDNN_STATUS_SUCCESS { + return Err(ZenuCudnnError::from(state)); + } + + Ok(algorithm[0].algo) + } +} + +fn convolution_backward_data_algorithm( + input: cudnnTensorDescriptor_t, + filter: cudnnFilterDescriptor_t, + conv: cudnnConvolutionDescriptor_t, + output: cudnnTensorDescriptor_t, + requested_algo_count: usize, +) -> Result { + let state = ZENU_CUDA_STATE.lock().unwrap(); + let handle = state.get_cudnn(); + let mut returned_algo_count = 0; + unsafe { + let mut algorithm: Vec = Vec::with_capacity(1); + for _ in 0..requested_algo_count { + algorithm.push(cudnnConvolutionBwdDataAlgoPerf_t::default()); + } + + // enable tensor core + cudnnSetConvolutionMathType(conv, cudnnMathType_t::CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION); + + let state = cudnnGetConvolutionBackwardDataAlgorithm_v7( + handle.as_ptr(), + filter, + input, + conv, + output, + requested_algo_count as i32, + &mut returned_algo_count as *mut i32, + algorithm.as_mut_ptr() as *mut cudnnConvolutionBwdDataAlgoPerf_t, + ); + if state != cudnnStatus_t::CUDNN_STATUS_SUCCESS { + return Err(ZenuCudnnError::from(state)); + } + + Ok(algorithm[0].algo) + } +} + +fn convolution_backward_filter_algorithm( + input: cudnnTensorDescriptor_t, + filter: cudnnFilterDescriptor_t, + conv: cudnnConvolutionDescriptor_t, + output: cudnnTensorDescriptor_t, + requested_algo_count: usize, +) -> Result { + let state = ZENU_CUDA_STATE.lock().unwrap(); + let handle = state.get_cudnn(); + let mut returned_algo_count = 0; + unsafe { + let mut algorithm: Vec = + Vec::with_capacity(requested_algo_count); + for _ in 0..requested_algo_count { + algorithm.push(cudnnConvolutionBwdFilterAlgoPerf_t::default()); + } + + // enable tensor core + cudnnSetConvolutionMathType(conv, cudnnMathType_t::CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION); + + let state = cudnnGetConvolutionBackwardFilterAlgorithm_v7( + handle.as_ptr(), + input, + output, + conv, + filter, + requested_algo_count as i32, + &mut returned_algo_count as *mut i32, + algorithm.as_mut_ptr() as *mut cudnnConvolutionBwdFilterAlgoPerf_t, + ); + if state != cudnnStatus_t::CUDNN_STATUS_SUCCESS { + return Err(ZenuCudnnError::from(state)); + } + + Ok(algorithm[0].algo) + } +} + +fn convolution_workspace( + input: cudnnTensorDescriptor_t, + filter: cudnnFilterDescriptor_t, + conv: cudnnConvolutionDescriptor_t, + output: cudnnTensorDescriptor_t, + algorithm: cudnnConvolutionFwdAlgo_t, +) -> Result { + let state = ZENU_CUDA_STATE.lock().unwrap(); + let handle = state.get_cudnn(); + let mut workspace_size = 0; + unsafe { + let status = cudnnGetConvolutionForwardWorkspaceSize( + handle.as_ptr(), + input, + filter, + conv, + output, + algorithm, + &mut workspace_size as *mut usize, + ); + if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS { + panic!("Failed to get convolution forward workspace size"); + } + Ok(Workspace::new(workspace_size)) + } +} + +fn convolution_backward_data_workspace( + input: cudnnTensorDescriptor_t, + filter: cudnnFilterDescriptor_t, + conv: cudnnConvolutionDescriptor_t, + output: cudnnTensorDescriptor_t, + algorithm: cudnnConvolutionBwdDataAlgo_t, +) -> Result { + let state = ZENU_CUDA_STATE.lock().unwrap(); + let handle = state.get_cudnn(); + let mut workspace_size = 0; + unsafe { + let status = cudnnGetConvolutionBackwardDataWorkspaceSize( + handle.as_ptr(), + filter, + output, + conv, + input, + algorithm, + &mut workspace_size as *mut usize, + ); + if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS { + panic!("Failed to get convolution backward data workspace size"); + } + Ok(Workspace::new(workspace_size)) + } +} + +fn convolution_backward_filter_workspace( + input: cudnnTensorDescriptor_t, + filter: cudnnFilterDescriptor_t, + conv: cudnnConvolutionDescriptor_t, + output: cudnnTensorDescriptor_t, + algorithm: cudnnConvolutionBwdFilterAlgo_t, +) -> Result { + let state = ZENU_CUDA_STATE.lock().unwrap(); + let handle = state.get_cudnn(); + let mut workspace_size = 0; + unsafe { + let status = cudnnGetConvolutionBackwardFilterWorkspaceSize( + handle.as_ptr(), + input, + output, + conv, + filter, + algorithm, + &mut workspace_size as *mut usize, + ); + if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS { + panic!("Failed to get convolution backward filter workspace size"); + } + Ok(Workspace::new(workspace_size)) + } +} + +#[derive(Debug)] +pub struct Workspace { + workspace: UnsafeCell>, + workspace_size: usize, +} + +impl Workspace { + pub fn new(workspace_size: usize) -> Self { + Self { + workspace: UnsafeCell::new(None), + workspace_size, + } + } + + pub fn workspace(&self) -> *mut libc::c_void { + let workspace = unsafe { &mut *self.workspace.get() }; + if workspace.is_none() { + let ptr = unsafe { + let mut ptr = std::ptr::null_mut(); + cudaMalloc(&mut ptr as *mut *mut libc::c_void, self.workspace_size); + ptr + }; + *workspace = Some(ptr); + } + workspace.unwrap() + } + + pub fn free_workspace(&self) { + let workspace = unsafe { &mut *self.workspace.get() }; + if let Some(ptr) = workspace.take() { + unsafe { + cudaFree(ptr); + } + *workspace = None; + } + } +} + +impl Drop for Workspace { + fn drop(&mut self) { + let workspace = unsafe { &mut *self.workspace.get() }; + if let Some(ptr) = workspace.take() { + unsafe { + cudaFree(ptr); + } + } + } +} + +#[derive(Debug)] +pub struct ConvDescriptor { + input: cudnnTensorDescriptor_t, + filter: cudnnFilterDescriptor_t, + conv: cudnnConvolutionDescriptor_t, + output: cudnnTensorDescriptor_t, + algorithm: cudnnConvolutionFwdAlgo_t, + workspace: Workspace, +} + +impl ConvDescriptor { + pub fn forward( + &self, + alpha: T, + input: *const T, + filter: *const T, + beta: T, + output: *mut T, + ) { + let state = ZENU_CUDA_STATE.lock().unwrap(); + let handle = state.get_cudnn(); + unsafe { + cudnnConvolutionForward( + handle.as_ptr(), + &alpha as *const T as *const libc::c_void, + self.input, + input as *const libc::c_void, + self.filter, + filter as *const libc::c_void, + self.conv, + self.algorithm, + self.workspace.workspace(), + self.workspace.workspace_size, + &beta as *const T as *const libc::c_void, + self.output, + output as *mut libc::c_void, + ); + } + } +} + +impl Drop for ConvDescriptor { + fn drop(&mut self) { + unsafe { + cudnnDestroyTensorDescriptor(self.input); + cudnnDestroyFilterDescriptor(self.filter); + cudnnDestroyConvolutionDescriptor(self.conv); + cudnnDestroyTensorDescriptor(self.output); + } + } +} + +#[derive(Debug, Default, PartialEq, Eq, Hash)] +pub struct ConvolutionBuilder { + input: Option, + filter: Option, + conv: Option, + output: Option, + algorithm: Option, +} + +impl ConvolutionBuilder { + pub fn input( + self, + n: i32, + c: i32, + h: i32, + w: i32, + format: TensorFormat, + ) -> Result { + let input = tensor_descriptor_4d::(n, c, h, w, format)?; + Ok(Self { + input: Some(input), + ..self + }) + } + + pub fn filter( + self, + k: i32, + c: i32, + h: i32, + w: i32, + format: TensorFormat, + ) -> Result { + let filter = filter_descriptor::(k, c, h, w, format)?; + Ok(Self { + filter: Some(filter), + ..self + }) + } + + pub fn conv( + self, + pad_h: i32, + pad_w: i32, + stride_h: i32, + stride_w: i32, + dilation_h: i32, + dilation_w: i32, + ) -> Result { + let conv = + convolution_descriptor(pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w)?; + Ok(Self { + conv: Some(conv), + ..self + }) + } + + pub fn output( + self, + n: i32, + c: i32, + h: i32, + w: i32, + format: TensorFormat, + ) -> Result { + let output = tensor_descriptor_4d::(n, c, h, w, format)?; + Ok(Self { + output: Some(output), + ..self + }) + } + + pub fn algorithm(self, requested_algo_count: usize) -> Result { + let input = self.input.unwrap(); + let filter = self.filter.unwrap(); + let conv = self.conv.unwrap(); + let output = self.output.unwrap(); + let algorithm = convolution_algorithm(input, filter, conv, output, requested_algo_count)?; + Ok(Self { + algorithm: Some(algorithm), + ..self + }) + } + + pub fn build(self) -> Result { + let input = self.input.unwrap(); + let filter = self.filter.unwrap(); + let conv = self.conv.unwrap(); + let output = self.output.unwrap(); + let algorithm = self.algorithm.unwrap(); + let workspace = convolution_workspace(input, filter, conv, output, algorithm)?; + Ok(ConvDescriptor { + input, + filter, + conv, + output, + algorithm, + workspace, + }) + } +} + +pub struct ConvolutionBackwardData { + input: cudnnTensorDescriptor_t, + filter: cudnnFilterDescriptor_t, + conv: cudnnConvolutionDescriptor_t, + output: cudnnTensorDescriptor_t, + algorithm: cudnnConvolutionBwdDataAlgo_t, + workspace: Workspace, +} + +impl ConvolutionBackwardData { + pub fn backward_data( + &self, + alpha: T, + filter: *const T, + output: *const T, + beta: T, + input: *mut T, + ) { + let state = ZENU_CUDA_STATE.lock().unwrap(); + let handle = state.get_cudnn(); + unsafe { + cudnnConvolutionBackwardData( + handle.as_ptr(), + &alpha as *const T as *const libc::c_void, + self.filter, + filter as *const libc::c_void, + self.output, + output as *const libc::c_void, + self.conv, + self.algorithm, + self.workspace.workspace(), + self.workspace.workspace_size, + &beta as *const T as *const libc::c_void, + self.input, + input as *mut libc::c_void, + ); + } + } +} + +impl Drop for ConvolutionBackwardData { + fn drop(&mut self) { + unsafe { + cudnnDestroyTensorDescriptor(self.input); + cudnnDestroyFilterDescriptor(self.filter); + cudnnDestroyConvolutionDescriptor(self.conv); + cudnnDestroyTensorDescriptor(self.output); + } + } +} + +#[derive(Debug, Default, PartialEq, Eq, Hash)] +pub struct ConvolutionBackwardDataBuilder { + input: Option, + filter: Option, + conv: Option, + output: Option, + algorithm: Option, +} + +impl ConvolutionBackwardDataBuilder { + pub fn input( + self, + n: i32, + c: i32, + h: i32, + w: i32, + format: TensorFormat, + ) -> Result { + let input = tensor_descriptor_4d::(n, c, h, w, format)?; + Ok(Self { + input: Some(input), + ..self + }) + } + + pub fn filter( + self, + k: i32, + c: i32, + h: i32, + w: i32, + format: TensorFormat, + ) -> Result { + let filter = filter_descriptor::(k, c, h, w, format)?; + Ok(Self { + filter: Some(filter), + ..self + }) + } + + pub fn conv( + self, + pad_h: i32, + pad_w: i32, + stride_h: i32, + stride_w: i32, + dilation_h: i32, + dilation_w: i32, + ) -> Result { + let conv = + convolution_descriptor(pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w)?; + Ok(Self { + conv: Some(conv), + ..self + }) + } + + pub fn output( + self, + n: i32, + c: i32, + h: i32, + w: i32, + format: TensorFormat, + ) -> Result { + let output = tensor_descriptor_4d::(n, c, h, w, format)?; + Ok(Self { + output: Some(output), + ..self + }) + } + + pub fn algorithm(self, requested_algo_count: usize) -> Result { + let input = self.input.unwrap(); + let filter = self.filter.unwrap(); + let conv = self.conv.unwrap(); + let output = self.output.unwrap(); + let algorithm = + convolution_backward_data_algorithm(input, filter, conv, output, requested_algo_count)?; + Ok(Self { + algorithm: Some(algorithm), + ..self + }) + } + + pub fn build(self) -> Result { + let input = self.input.unwrap(); + let filter = self.filter.unwrap(); + let conv = self.conv.unwrap(); + let output = self.output.unwrap(); + let algorithm = self.algorithm.unwrap(); + let workspace = + convolution_backward_data_workspace(input, filter, conv, output, algorithm)?; + Ok(ConvolutionBackwardData { + input, + filter, + conv, + output, + algorithm, + workspace, + }) + } +} + +pub struct ConvolutionBackwardFilter { + input: cudnnTensorDescriptor_t, + filter: cudnnFilterDescriptor_t, + conv: cudnnConvolutionDescriptor_t, + output: cudnnTensorDescriptor_t, + algorithm: cudnnConvolutionBwdFilterAlgo_t, + workspace: Workspace, +} + +impl ConvolutionBackwardFilter { + pub fn backward_filter( + &self, + alpha: T, + input: *const T, + d_output: *const T, + beta: T, + filter: *mut T, + ) { + let state = ZENU_CUDA_STATE.lock().unwrap(); + let handle = state.get_cudnn(); + unsafe { + cudnnConvolutionBackwardFilter( + handle.as_ptr(), + &alpha as *const T as *const libc::c_void, + self.input, + input as *const libc::c_void, + self.output, + d_output as *const libc::c_void, + self.conv, + self.algorithm, + self.workspace.workspace(), + self.workspace.workspace_size, + &beta as *const T as *const libc::c_void, + self.filter, + filter as *mut libc::c_void, + ); + } + } +} + +impl Drop for ConvolutionBackwardFilter { + fn drop(&mut self) { + unsafe { + cudnnDestroyTensorDescriptor(self.input); + cudnnDestroyFilterDescriptor(self.filter); + cudnnDestroyConvolutionDescriptor(self.conv); + cudnnDestroyTensorDescriptor(self.output); + } + } +} + +#[derive(Debug, Default, PartialEq, Eq, Hash)] +pub struct ConvolutionBackwardFilterBuilder { + input: Option, + filter: Option, + conv: Option, + output: Option, + algorithm: Option, +} + +impl ConvolutionBackwardFilterBuilder { + pub fn input( + self, + n: i32, + c: i32, + h: i32, + w: i32, + format: TensorFormat, + ) -> Result { + let input = tensor_descriptor_4d::(n, c, h, w, format)?; + Ok(Self { + input: Some(input), + ..self + }) + } + + pub fn filter( + self, + k: i32, + c: i32, + h: i32, + w: i32, + format: TensorFormat, + ) -> Result { + let filter = filter_descriptor::(k, c, h, w, format)?; + Ok(Self { + filter: Some(filter), + ..self + }) + } + + pub fn conv( + self, + pad_h: i32, + pad_w: i32, + stride_h: i32, + stride_w: i32, + dilation_h: i32, + dilation_w: i32, + ) -> Result { + let conv = + convolution_descriptor(pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w)?; + Ok(Self { + conv: Some(conv), + ..self + }) + } + + pub fn output( + self, + n: i32, + c: i32, + h: i32, + w: i32, + format: TensorFormat, + ) -> Result { + let output = tensor_descriptor_4d::(n, c, h, w, format)?; + Ok(Self { + output: Some(output), + ..self + }) + } + + pub fn algorithm(self, requested_algo_count: usize) -> Result { + let input = self.input.unwrap(); + let filter = self.filter.unwrap(); + let conv = self.conv.unwrap(); + let output = self.output.unwrap(); + let algorithm = convolution_backward_filter_algorithm( + input, + filter, + conv, + output, + requested_algo_count, + )?; + Ok(Self { + algorithm: Some(algorithm), + ..self + }) + } + + pub fn build(self) -> Result { + let input = self.input.unwrap(); + let filter = self.filter.unwrap(); + let conv = self.conv.unwrap(); + let output = self.output.unwrap(); + let algorithm = self.algorithm.unwrap(); + let workspace = + convolution_backward_filter_workspace(input, filter, conv, output, algorithm)?; + Ok(ConvolutionBackwardFilter { + input, + filter, + conv, + output, + algorithm, + workspace, + }) + } +} + +#[cfg(test)] +mod cudnn { + use crate::runtime::{cuda_copy, cuda_malloc, ZenuCudaMemCopyKind}; + + use super::*; + + #[test] + fn test_convolution() { + let n = 1; + let c = 3; + let h = 5; + let w = 5; + let k = 3; + let kh = 3; + let kw = 3; + let pad_h = 1; + let pad_w = 1; + let stride_h = 1; + let stride_w = 1; + + // 畳み込み後の出力テンソルのサイズ + let out_h = (h + 2 * pad_h - kh) / stride_h + 1; + let out_w = (w + 2 * pad_w - kw) / stride_w + 1; + + let conv = ConvolutionBuilder::default() + .input::(n, c, h, w, TensorFormat::NCHW) + .unwrap() + .filter::(k, c, kh, kw, TensorFormat::NCHW) + .unwrap() + .conv(pad_h, pad_w, stride_h, stride_w, 1, 1) + .unwrap() + .output::(n, k, out_h, out_w, TensorFormat::NCHW) // ここで出力テンソルのサイズを変更 + .unwrap() + .algorithm(1) + .unwrap() + .build() + .unwrap(); + + // create input tensor + let mut input_cpu = Vec::new(); + for idx in 0..n * c * h * w { + input_cpu.push(idx as f32); + } + let input_gpu = cuda_malloc::((n * c * h * w) as usize).unwrap(); + cuda_copy( + input_gpu, + input_cpu.as_ptr(), + (n * c * h * w) as usize, + ZenuCudaMemCopyKind::HostToDevice, + ) + .unwrap(); + + // create filter tensor + let mut filter_cpu = Vec::new(); + for idx in 0..k * c * kh * kw { + filter_cpu.push(idx as f32); + } + let filter_gpu = cuda_malloc::((k * c * kh * kw) as usize).unwrap(); + cuda_copy( + filter_gpu, + filter_cpu.as_ptr(), + (k * c * kh * kw) as usize, + ZenuCudaMemCopyKind::HostToDevice, + ) + .unwrap(); + + // create output tensor + let output_gpu = cuda_malloc::((n * k * out_h * out_w) as usize).unwrap(); + + // execute convolution + conv.forward(1.0, input_gpu, filter_gpu, 0.0, output_gpu); + + // copy output tensor to cpu + let mut output_cpu = Vec::new(); + for _ in 0..n * k * out_h * out_w { + output_cpu.push(0.0); + } + cuda_copy( + output_cpu.as_mut_ptr(), + output_gpu, + (n * k * out_h * out_w) as usize, + ZenuCudaMemCopyKind::DeviceToHost, + ) + .unwrap(); + + // check output tensor + let ans = vec![ + 6888, 10218, 10479, 10740, 7056, 10296, 15219, 15570, 15921, 10422, 11511, 16974, + 17325, 17676, 11547, 12726, 18729, 19080, 19431, 12672, 8040, 11784, 11991, 12198, + 7920, 15960, 24069, 24816, 25563, 17100, 25119, 37818, 38898, 39978, 26703, 28764, + 43218, 44298, 45378, 30258, 32409, 48618, 49698, 50778, 33813, 21972, 32925, 33618, + 34311, 22824, 25032, 37920, 39153, 40386, 27144, 39942, 60417, 62226, 64035, 42984, + 46017, 69462, 71271, 73080, 48969, 52092, 78507, 80316, 82125, 54954, 35904, 54066, + 55245, 56424, 37728, + ]; + let ans = ans.iter().map(|&x| x as f32).collect::>(); + assert_eq!(output_cpu, ans); + } + + #[test] + fn bkwd_data() { + let n = 1; + let c = 3; + let h = 5; + let w = 5; + let k = 3; + let kh = 3; + let kw = 3; + let pad_h = 1; + let pad_w = 1; + let stride_h = 1; + let stride_w = 1; + + // 畳み込み後の出力テンソルのサイズ + let out_h = (h + 2 * pad_h - kh) / stride_h + 1; + let out_w = (w + 2 * pad_w - kw) / stride_w + 1; + + let conv = ConvolutionBackwardDataBuilder::default() + .input::(n, c, out_h, out_w, TensorFormat::NCHW) + .unwrap() + .filter::(k, c, kh, kw, TensorFormat::NCHW) + .unwrap() + .conv(pad_h, pad_w, stride_h, stride_w, 1, 1) + .unwrap() + .output::(n, k, h, w, TensorFormat::NCHW) // ここで出力テンソルのサイズを変更 + .unwrap() + .algorithm(5) + .unwrap() + .build() + .unwrap(); + + let mut input_cpu = Vec::new(); + for idx in 0..n * c * out_h * out_w { + input_cpu.push(idx as f32); + } + + let mut filter_cpu = Vec::new(); + for idx in 0..k * c * kh * kw { + filter_cpu.push(idx as f32); + } + + let input_gpu = cuda_malloc::((n * c * out_h * out_w) as usize).unwrap(); + let filter_gpu = cuda_malloc::((k * c * kh * kw) as usize).unwrap(); + let output_gpu = cuda_malloc::((n * k * h * w) as usize).unwrap(); + + cuda_copy( + input_gpu, + input_cpu.as_ptr(), + (n * c * out_h * out_w) as usize, + ZenuCudaMemCopyKind::HostToDevice, + ) + .unwrap(); + cuda_copy( + filter_gpu, + filter_cpu.as_ptr(), + (k * c * kh * kw) as usize, + ZenuCudaMemCopyKind::HostToDevice, + ) + .unwrap(); + + conv.backward_data(1.0, filter_gpu, input_gpu, 0.0, output_gpu); + + let mut output_cpu = Vec::new(); + for _ in 0..n * k * h * w { + output_cpu.push(0.0); + } + cuda_copy( + output_cpu.as_mut_ptr(), + output_gpu, + (n * k * h * w) as usize, + ZenuCudaMemCopyKind::DeviceToHost, + ) + .unwrap(); + let ans = vec![ + 15096.0, 23154.0, 23685.0, 24216.0, 16512.0, 24660.0, 37809.0, 38646.0, 39483.0, + 26910.0, 27405.0, 41994.0, 42831.0, 43668.0, 29745.0, 30150.0, 46179.0, 47016.0, + 47853.0, 32580.0, 21864.0, 33468.0, 34053.0, 34638.0, 23568.0, 18120.0, 27771.0, + 28464.0, 29157.0, 19860.0, 29601.0, 45342.0, 46422.0, 47502.0, 32337.0, 33156.0, + 50742.0, 51822.0, 52902.0, 35982.0, 36711.0, 56142.0, 57222.0, 58302.0, 39627.0, + 26508.0, 40515.0, 41262.0, 42009.0, 28536.0, 21144.0, 32388.0, 33243.0, 34098.0, + 23208.0, 34542.0, 52875.0, 54198.0, 55521.0, 37764.0, 38907.0, 59490.0, 60813.0, + 62136.0, 42219.0, 43272.0, 66105.0, 67428.0, 68751.0, 46674.0, 31152.0, 47562.0, + 48471.0, 49380.0, 33504.0, + ]; + assert_eq!(output_cpu, ans); + } + + #[test] + fn bkwd_filter() { + let n = 1; + let c = 3; + let h = 5; + let w = 5; + let k = 3; + let kh = 3; + let kw = 3; + let pad_h = 1; + let pad_w = 1; + let stride_h = 1; + let stride_w = 1; + + // 畳み込み後の出力テンソルのサイズ + let out_h = (h + 2 * pad_h - kh) / stride_h + 1; + let out_w = (w + 2 * pad_w - kw) / stride_w + 1; + + let conv = ConvolutionBackwardFilterBuilder::default() + .input::(n, c, h, w, TensorFormat::NCHW) + .unwrap() + .filter::(k, c, kh, kw, TensorFormat::NCHW) + .unwrap() + .conv(pad_h, pad_w, stride_h, stride_w, 1, 1) + .unwrap() + .output::(n, k, out_h, out_w, TensorFormat::NCHW) + .unwrap() + .algorithm(1) + .unwrap() + .build() + .unwrap(); + + let mut input_cpu = Vec::new(); + for idx in 0..n * c * h * w { + input_cpu.push(idx as f32); + } + + let mut d_output_cpu = Vec::new(); + for idx in 0..n * k * out_h * out_w { + d_output_cpu.push((idx % 10) as f32); + } + + let input_gpu = cuda_malloc::((n * c * h * w) as usize).unwrap(); + let filter_gpu = cuda_malloc::((k * c * kh * kw) as usize).unwrap(); + let output_gpu = cuda_malloc::((n * k * out_h * out_w) as usize).unwrap(); + + cuda_copy( + input_gpu, + input_cpu.as_ptr(), + (n * c * h * w) as usize, + ZenuCudaMemCopyKind::HostToDevice, + ) + .unwrap(); + cuda_copy( + output_gpu, + d_output_cpu.as_ptr(), + (n * k * out_h * out_w) as usize, + ZenuCudaMemCopyKind::HostToDevice, + ) + .unwrap(); + + conv.backward_filter(1.0, input_gpu, output_gpu, 0.0, filter_gpu); + + let mut filter_cpu = Vec::new(); + for _ in 0..k * c * kh * kw { + filter_cpu.push(0.0); + } + cuda_copy( + filter_cpu.as_mut_ptr(), + filter_gpu, + (k * c * kh * kw) as usize, + ZenuCudaMemCopyKind::DeviceToHost, + ) + .unwrap(); + + let ans = vec![ + 640.0, 770.0, 560.0, 1060.0, 1250.0, 900.0, 1240.0, 1470.0, 1080.0, 2640.0, 3020.0, + 2160.0, 3310.0, 3750.0, 2650.0, 3240.0, 3720.0, 2680.0, 4640.0, 5270.0, 3760.0, 5560.0, + 6250.0, 4400.0, 5240.0, 5970.0, 4280.0, 840.0, 1020.0, 760.0, 1290.0, 1550.0, 1150.0, + 1040.0, 1220.0, 880.0, 2840.0, 3270.0, 2360.0, 4040.0, 4675.0, 3400.0, 3040.0, 3470.0, + 2480.0, 4840.0, 5520.0, 3960.0, 6790.0, 7800.0, 5650.0, 5040.0, 5720.0, 4080.0, 640.0, + 770.0, 560.0, 1060.0, 1250.0, 900.0, 1240.0, 1470.0, 1080.0, 2640.0, 3020.0, 2160.0, + 3310.0, 3750.0, 2650.0, 3240.0, 3720.0, 2680.0, 4640.0, 5270.0, 3760.0, 5560.0, 6250.0, + 4400.0, 5240.0, 5970.0, 4280.0, + ]; + + assert_eq!(filter_cpu, ans); + } +} diff --git a/zenu-cuda/src/cudnn/mod.rs b/zenu-cuda/src/cudnn/mod.rs index a1d3d57a..f837af65 100644 --- a/zenu-cuda/src/cudnn/mod.rs +++ b/zenu-cuda/src/cudnn/mod.rs @@ -2,13 +2,13 @@ use std::any::TypeId; use zenu_cudnn_sys::*; -use crate::ZENU_CUDA_STATE; - use self::error::ZenuCudnnError; +pub mod batch_norm; +pub mod conv; pub mod error; -pub fn zenu_cudnn_data_type() -> cudnnDataType_t { +pub(crate) fn zenu_cudnn_data_type() -> cudnnDataType_t { if TypeId::of::() == TypeId::of::() { cudnnDataType_t::CUDNN_DATA_FLOAT } else if TypeId::of::() == TypeId::of::() { @@ -38,7 +38,7 @@ impl From for cudnnTensorFormat_t { } } -fn tensor_descriptor( +pub(crate) fn tensor_descriptor_4d( n: i32, c: i32, h: i32, @@ -60,1039 +60,3 @@ fn tensor_descriptor( } Ok(tensor) } - -fn filter_descriptor( - k: i32, - c: i32, - h: i32, - w: i32, - format: TensorFormat, -) -> Result { - let data_type = zenu_cudnn_data_type::(); - let format = format.into(); - let mut filter: cudnnFilterDescriptor_t = std::ptr::null_mut(); - unsafe { - let status = cudnnCreateFilterDescriptor(&mut filter as *mut cudnnFilterDescriptor_t); - if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS { - return Err(ZenuCudnnError::from(status)); - } - let status = cudnnSetFilter4dDescriptor(filter, data_type, format, k, c, h, w); - if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS { - return Err(ZenuCudnnError::from(status)); - } - } - Ok(filter) -} - -fn convolution_descriptor( - pad_h: i32, - pad_w: i32, - stride_h: i32, - stride_w: i32, - dilation_h: i32, - dilation_w: i32, -) -> Result { - let mut conv: cudnnConvolutionDescriptor_t = std::ptr::null_mut(); - unsafe { - let status = - cudnnCreateConvolutionDescriptor(&mut conv as *mut cudnnConvolutionDescriptor_t); - if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS { - return Err(ZenuCudnnError::from(status)); - } - let status = cudnnSetConvolution2dDescriptor( - conv, - pad_h, - pad_w, - stride_h, - stride_w, - dilation_h, - dilation_w, - cudnnConvolutionMode_t::CUDNN_CROSS_CORRELATION, - zenu_cudnn_data_type::(), - ); - if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS { - return Err(ZenuCudnnError::from(status)); - } - } - Ok(conv) -} - -fn convolution_algorithm( - input: cudnnTensorDescriptor_t, - filter: cudnnFilterDescriptor_t, - conv: cudnnConvolutionDescriptor_t, - output: cudnnTensorDescriptor_t, - requested_algo_count: usize, -) -> Result { - let state = ZENU_CUDA_STATE.lock().unwrap(); - let handle = state.get_cudnn(); - let mut returned_algo_count = 0; - unsafe { - let mut algorithm: Vec = - Vec::with_capacity(requested_algo_count); - for _ in 0..requested_algo_count { - algorithm.push(cudnnConvolutionFwdAlgoPerf_t::default()); - } - - // enable tensor core - cudnnSetConvolutionMathType(conv, cudnnMathType_t::CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION); - - let state = cudnnGetConvolutionForwardAlgorithm_v7( - handle.as_ptr(), - input, - filter, - conv, - output, - requested_algo_count as i32, - &mut returned_algo_count as *mut i32, - algorithm.as_mut_ptr() as *mut cudnnConvolutionFwdAlgoPerf_t, - ); - if state != cudnnStatus_t::CUDNN_STATUS_SUCCESS { - return Err(ZenuCudnnError::from(state)); - } - - Ok(algorithm[0].algo) - } -} - -fn convolution_backward_data_algorithm( - input: cudnnTensorDescriptor_t, - filter: cudnnFilterDescriptor_t, - conv: cudnnConvolutionDescriptor_t, - output: cudnnTensorDescriptor_t, - requested_algo_count: usize, -) -> Result { - let state = ZENU_CUDA_STATE.lock().unwrap(); - let handle = state.get_cudnn(); - let mut returned_algo_count = 0; - unsafe { - let mut algorithm: Vec = Vec::with_capacity(1); - for _ in 0..requested_algo_count { - algorithm.push(cudnnConvolutionBwdDataAlgoPerf_t::default()); - } - - // enable tensor core - cudnnSetConvolutionMathType(conv, cudnnMathType_t::CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION); - - let state = cudnnGetConvolutionBackwardDataAlgorithm_v7( - handle.as_ptr(), - filter, - input, - conv, - output, - requested_algo_count as i32, - &mut returned_algo_count as *mut i32, - algorithm.as_mut_ptr() as *mut cudnnConvolutionBwdDataAlgoPerf_t, - ); - if state != cudnnStatus_t::CUDNN_STATUS_SUCCESS { - return Err(ZenuCudnnError::from(state)); - } - - Ok(algorithm[0].algo) - } -} - -fn convolution_backward_filter_algorithm( - input: cudnnTensorDescriptor_t, - filter: cudnnFilterDescriptor_t, - conv: cudnnConvolutionDescriptor_t, - output: cudnnTensorDescriptor_t, - requested_algo_count: usize, -) -> Result { - let state = ZENU_CUDA_STATE.lock().unwrap(); - let handle = state.get_cudnn(); - let mut returned_algo_count = 0; - unsafe { - let mut algorithm: Vec = - Vec::with_capacity(requested_algo_count); - for _ in 0..requested_algo_count { - algorithm.push(cudnnConvolutionBwdFilterAlgoPerf_t::default()); - } - - // enable tensor core - cudnnSetConvolutionMathType(conv, cudnnMathType_t::CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION); - - let state = cudnnGetConvolutionBackwardFilterAlgorithm_v7( - handle.as_ptr(), - input, - output, - conv, - filter, - requested_algo_count as i32, - &mut returned_algo_count as *mut i32, - algorithm.as_mut_ptr() as *mut cudnnConvolutionBwdFilterAlgoPerf_t, - ); - if state != cudnnStatus_t::CUDNN_STATUS_SUCCESS { - return Err(ZenuCudnnError::from(state)); - } - - Ok(algorithm[0].algo) - } -} - -fn convolution_workspace( - input: cudnnTensorDescriptor_t, - filter: cudnnFilterDescriptor_t, - conv: cudnnConvolutionDescriptor_t, - output: cudnnTensorDescriptor_t, - algorithm: cudnnConvolutionFwdAlgo_t, -) -> Result { - let state = ZENU_CUDA_STATE.lock().unwrap(); - let handle = state.get_cudnn(); - let mut workspace_size = 0; - unsafe { - let status = cudnnGetConvolutionForwardWorkspaceSize( - handle.as_ptr(), - input, - filter, - conv, - output, - algorithm, - &mut workspace_size as *mut usize, - ); - if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS { - panic!("Failed to get convolution forward workspace size"); - } - let mut workspace = std::ptr::null_mut(); - let status = cudaMalloc(&mut workspace as *mut *mut libc::c_void, workspace_size); - if status != cudaError_t::cudaSuccess { - panic!("Failed to allocate convolution forward workspace"); - } - Ok(Workspace { - workspace, - workspace_size, - }) - } -} - -fn convolution_backward_data_workspace( - input: cudnnTensorDescriptor_t, - filter: cudnnFilterDescriptor_t, - conv: cudnnConvolutionDescriptor_t, - output: cudnnTensorDescriptor_t, - algorithm: cudnnConvolutionBwdDataAlgo_t, -) -> Result { - let state = ZENU_CUDA_STATE.lock().unwrap(); - let handle = state.get_cudnn(); - let mut workspace_size = 0; - unsafe { - let status = cudnnGetConvolutionBackwardDataWorkspaceSize( - handle.as_ptr(), - filter, - output, - conv, - input, - algorithm, - &mut workspace_size as *mut usize, - ); - if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS { - panic!("Failed to get convolution backward data workspace size"); - } - let mut workspace = std::ptr::null_mut(); - let status = cudaMalloc(&mut workspace as *mut *mut libc::c_void, workspace_size); - if status != cudaError_t::cudaSuccess { - panic!("Failed to allocate convolution backward data workspace"); - } - Ok(Workspace { - workspace, - workspace_size, - }) - } -} - -fn convolution_backward_filter_workspace( - input: cudnnTensorDescriptor_t, - filter: cudnnFilterDescriptor_t, - conv: cudnnConvolutionDescriptor_t, - output: cudnnTensorDescriptor_t, - algorithm: cudnnConvolutionBwdFilterAlgo_t, -) -> Result { - let state = ZENU_CUDA_STATE.lock().unwrap(); - let handle = state.get_cudnn(); - let mut workspace_size = 0; - unsafe { - let status = cudnnGetConvolutionBackwardFilterWorkspaceSize( - handle.as_ptr(), - input, - output, - conv, - filter, - algorithm, - &mut workspace_size as *mut usize, - ); - if status != cudnnStatus_t::CUDNN_STATUS_SUCCESS { - panic!("Failed to get convolution backward filter workspace size"); - } - let mut workspace = std::ptr::null_mut(); - let status = cudaMalloc(&mut workspace as *mut *mut libc::c_void, workspace_size); - if status != cudaError_t::cudaSuccess { - panic!("Failed to allocate convolution backward filter workspace"); - } - Ok(Workspace { - workspace, - workspace_size, - }) - } -} - -#[derive(Debug)] -pub struct Workspace { - workspace: *mut libc::c_void, - workspace_size: usize, -} - -#[derive(Debug)] -pub struct ConvDescriptor { - input: cudnnTensorDescriptor_t, - filter: cudnnFilterDescriptor_t, - conv: cudnnConvolutionDescriptor_t, - output: cudnnTensorDescriptor_t, - algorithm: cudnnConvolutionFwdAlgo_t, - workspace: Workspace, -} - -impl ConvDescriptor { - pub fn forward( - &self, - alpha: T, - input: *const T, - filter: *const T, - beta: T, - output: *mut T, - ) { - let state = ZENU_CUDA_STATE.lock().unwrap(); - let handle = state.get_cudnn(); - unsafe { - cudnnConvolutionForward( - handle.as_ptr(), - &alpha as *const T as *const libc::c_void, - self.input, - input as *const libc::c_void, - self.filter, - filter as *const libc::c_void, - self.conv, - self.algorithm, - self.workspace.workspace, - self.workspace.workspace_size, - &beta as *const T as *const libc::c_void, - self.output, - output as *mut libc::c_void, - ); - } - } -} - -impl Drop for ConvDescriptor { - fn drop(&mut self) { - unsafe { - cudnnDestroyTensorDescriptor(self.input); - cudnnDestroyFilterDescriptor(self.filter); - cudnnDestroyConvolutionDescriptor(self.conv); - cudnnDestroyTensorDescriptor(self.output); - cudaFree(self.workspace.workspace); - } - } -} - -#[derive(Debug, Default, PartialEq, Eq, Hash)] -pub struct ConvolutionBuilder { - input: Option, - filter: Option, - conv: Option, - output: Option, - algorithm: Option, -} - -impl ConvolutionBuilder { - pub fn input( - self, - n: i32, - c: i32, - h: i32, - w: i32, - format: TensorFormat, - ) -> Result { - let input = tensor_descriptor::(n, c, h, w, format)?; - Ok(Self { - input: Some(input), - ..self - }) - } - - pub fn filter( - self, - k: i32, - c: i32, - h: i32, - w: i32, - format: TensorFormat, - ) -> Result { - let filter = filter_descriptor::(k, c, h, w, format)?; - Ok(Self { - filter: Some(filter), - ..self - }) - } - - pub fn conv( - self, - pad_h: i32, - pad_w: i32, - stride_h: i32, - stride_w: i32, - dilation_h: i32, - dilation_w: i32, - ) -> Result { - let conv = - convolution_descriptor(pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w)?; - Ok(Self { - conv: Some(conv), - ..self - }) - } - - pub fn output( - self, - n: i32, - c: i32, - h: i32, - w: i32, - format: TensorFormat, - ) -> Result { - let output = tensor_descriptor::(n, c, h, w, format)?; - Ok(Self { - output: Some(output), - ..self - }) - } - - pub fn algorithm(self, requested_algo_count: usize) -> Result { - let input = self.input.unwrap(); - let filter = self.filter.unwrap(); - let conv = self.conv.unwrap(); - let output = self.output.unwrap(); - let algorithm = convolution_algorithm(input, filter, conv, output, requested_algo_count)?; - Ok(Self { - algorithm: Some(algorithm), - ..self - }) - } - - pub fn build(self) -> Result { - let input = self.input.unwrap(); - let filter = self.filter.unwrap(); - let conv = self.conv.unwrap(); - let output = self.output.unwrap(); - let algorithm = self.algorithm.unwrap(); - let workspace = convolution_workspace(input, filter, conv, output, algorithm)?; - Ok(ConvDescriptor { - input, - filter, - conv, - output, - algorithm, - workspace, - }) - } -} - -pub struct ConvolutionBackwardData { - input: cudnnTensorDescriptor_t, - filter: cudnnFilterDescriptor_t, - conv: cudnnConvolutionDescriptor_t, - output: cudnnTensorDescriptor_t, - algorithm: cudnnConvolutionBwdDataAlgo_t, - workspace: Workspace, -} - -impl ConvolutionBackwardData { - pub fn backward_data( - &self, - alpha: T, - filter: *const T, - output: *const T, - beta: T, - input: *mut T, - ) { - let state = ZENU_CUDA_STATE.lock().unwrap(); - let handle = state.get_cudnn(); - unsafe { - cudnnConvolutionBackwardData( - handle.as_ptr(), - &alpha as *const T as *const libc::c_void, - self.filter, - filter as *const libc::c_void, - self.output, - output as *const libc::c_void, - self.conv, - self.algorithm, - self.workspace.workspace, - self.workspace.workspace_size, - &beta as *const T as *const libc::c_void, - self.input, - input as *mut libc::c_void, - ); - } - } -} - -impl Drop for ConvolutionBackwardData { - fn drop(&mut self) { - unsafe { - cudnnDestroyTensorDescriptor(self.input); - cudnnDestroyFilterDescriptor(self.filter); - cudnnDestroyConvolutionDescriptor(self.conv); - cudnnDestroyTensorDescriptor(self.output); - cudaFree(self.workspace.workspace); - } - } -} - -#[derive(Debug, Default, PartialEq, Eq, Hash)] -pub struct ConvolutionBackwardDataBuilder { - input: Option, - filter: Option, - conv: Option, - output: Option, - algorithm: Option, -} - -impl ConvolutionBackwardDataBuilder { - pub fn input( - self, - n: i32, - c: i32, - h: i32, - w: i32, - format: TensorFormat, - ) -> Result { - let input = tensor_descriptor::(n, c, h, w, format)?; - Ok(Self { - input: Some(input), - ..self - }) - } - - pub fn filter( - self, - k: i32, - c: i32, - h: i32, - w: i32, - format: TensorFormat, - ) -> Result { - let filter = filter_descriptor::(k, c, h, w, format)?; - Ok(Self { - filter: Some(filter), - ..self - }) - } - - pub fn conv( - self, - pad_h: i32, - pad_w: i32, - stride_h: i32, - stride_w: i32, - dilation_h: i32, - dilation_w: i32, - ) -> Result { - let conv = - convolution_descriptor(pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w)?; - Ok(Self { - conv: Some(conv), - ..self - }) - } - - pub fn output( - self, - n: i32, - c: i32, - h: i32, - w: i32, - format: TensorFormat, - ) -> Result { - let output = tensor_descriptor::(n, c, h, w, format)?; - Ok(Self { - output: Some(output), - ..self - }) - } - - pub fn algorithm(self, requested_algo_count: usize) -> Result { - let input = self.input.unwrap(); - let filter = self.filter.unwrap(); - let conv = self.conv.unwrap(); - let output = self.output.unwrap(); - let algorithm = - convolution_backward_data_algorithm(input, filter, conv, output, requested_algo_count)?; - Ok(Self { - algorithm: Some(algorithm), - ..self - }) - } - - pub fn build(self) -> Result { - let input = self.input.unwrap(); - let filter = self.filter.unwrap(); - let conv = self.conv.unwrap(); - let output = self.output.unwrap(); - let algorithm = self.algorithm.unwrap(); - let workspace = - convolution_backward_data_workspace(input, filter, conv, output, algorithm)?; - Ok(ConvolutionBackwardData { - input, - filter, - conv, - output, - algorithm, - workspace, - }) - } -} - -pub struct ConvolutionBackwardFilter { - input: cudnnTensorDescriptor_t, - filter: cudnnFilterDescriptor_t, - conv: cudnnConvolutionDescriptor_t, - output: cudnnTensorDescriptor_t, - algorithm: cudnnConvolutionBwdFilterAlgo_t, - workspace: Workspace, -} - -impl ConvolutionBackwardFilter { - pub fn backward_filter( - &self, - alpha: T, - input: *const T, - d_output: *const T, - beta: T, - filter: *mut T, - ) { - let state = ZENU_CUDA_STATE.lock().unwrap(); - let handle = state.get_cudnn(); - unsafe { - cudnnConvolutionBackwardFilter( - handle.as_ptr(), - &alpha as *const T as *const libc::c_void, - self.input, - input as *const libc::c_void, - self.output, - d_output as *const libc::c_void, - self.conv, - self.algorithm, - self.workspace.workspace, - self.workspace.workspace_size, - &beta as *const T as *const libc::c_void, - self.filter, - filter as *mut libc::c_void, - ); - } - } -} - -impl Drop for ConvolutionBackwardFilter { - fn drop(&mut self) { - unsafe { - cudnnDestroyTensorDescriptor(self.input); - cudnnDestroyFilterDescriptor(self.filter); - cudnnDestroyConvolutionDescriptor(self.conv); - cudnnDestroyTensorDescriptor(self.output); - cudaFree(self.workspace.workspace); - } - } -} - -#[derive(Debug, Default, PartialEq, Eq, Hash)] -pub struct ConvolutionBackwardFilterBuilder { - input: Option, - filter: Option, - conv: Option, - output: Option, - algorithm: Option, -} - -impl ConvolutionBackwardFilterBuilder { - pub fn input( - self, - n: i32, - c: i32, - h: i32, - w: i32, - format: TensorFormat, - ) -> Result { - let input = tensor_descriptor::(n, c, h, w, format)?; - Ok(Self { - input: Some(input), - ..self - }) - } - - pub fn filter( - self, - k: i32, - c: i32, - h: i32, - w: i32, - format: TensorFormat, - ) -> Result { - let filter = filter_descriptor::(k, c, h, w, format)?; - Ok(Self { - filter: Some(filter), - ..self - }) - } - - pub fn conv( - self, - pad_h: i32, - pad_w: i32, - stride_h: i32, - stride_w: i32, - dilation_h: i32, - dilation_w: i32, - ) -> Result { - let conv = - convolution_descriptor(pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w)?; - Ok(Self { - conv: Some(conv), - ..self - }) - } - - pub fn output( - self, - n: i32, - c: i32, - h: i32, - w: i32, - format: TensorFormat, - ) -> Result { - let output = tensor_descriptor::(n, c, h, w, format)?; - Ok(Self { - output: Some(output), - ..self - }) - } - - pub fn algorithm(self, requested_algo_count: usize) -> Result { - let input = self.input.unwrap(); - let filter = self.filter.unwrap(); - let conv = self.conv.unwrap(); - let output = self.output.unwrap(); - let algorithm = convolution_backward_filter_algorithm( - input, - filter, - conv, - output, - requested_algo_count, - )?; - Ok(Self { - algorithm: Some(algorithm), - ..self - }) - } - - pub fn build(self) -> Result { - let input = self.input.unwrap(); - let filter = self.filter.unwrap(); - let conv = self.conv.unwrap(); - let output = self.output.unwrap(); - let algorithm = self.algorithm.unwrap(); - let workspace = - convolution_backward_filter_workspace(input, filter, conv, output, algorithm)?; - Ok(ConvolutionBackwardFilter { - input, - filter, - conv, - output, - algorithm, - workspace, - }) - } -} - -#[cfg(test)] -mod cudnn { - use crate::runtime::{cuda_copy, cuda_malloc, ZenuCudaMemCopyKind}; - - use super::*; - - #[test] - fn test_convolution() { - let n = 1; - let c = 3; - let h = 5; - let w = 5; - let k = 3; - let kh = 3; - let kw = 3; - let pad_h = 1; - let pad_w = 1; - let stride_h = 1; - let stride_w = 1; - - // 畳み込み後の出力テンソルのサイズ - let out_h = (h + 2 * pad_h - kh) / stride_h + 1; - let out_w = (w + 2 * pad_w - kw) / stride_w + 1; - - let conv = ConvolutionBuilder::default() - .input::(n, c, h, w, TensorFormat::NCHW) - .unwrap() - .filter::(k, c, kh, kw, TensorFormat::NCHW) - .unwrap() - .conv(pad_h, pad_w, stride_h, stride_w, 1, 1) - .unwrap() - .output::(n, k, out_h, out_w, TensorFormat::NCHW) // ここで出力テンソルのサイズを変更 - .unwrap() - .algorithm(1) - .unwrap() - .build() - .unwrap(); - - // create input tensor - let mut input_cpu = Vec::new(); - for idx in 0..n * c * h * w { - input_cpu.push(idx as f32); - } - let input_gpu = cuda_malloc::((n * c * h * w) as usize).unwrap(); - cuda_copy( - input_gpu, - input_cpu.as_ptr(), - (n * c * h * w) as usize, - ZenuCudaMemCopyKind::HostToDevice, - ) - .unwrap(); - - // create filter tensor - let mut filter_cpu = Vec::new(); - for idx in 0..k * c * kh * kw { - filter_cpu.push(idx as f32); - } - let filter_gpu = cuda_malloc::((k * c * kh * kw) as usize).unwrap(); - cuda_copy( - filter_gpu, - filter_cpu.as_ptr(), - (k * c * kh * kw) as usize, - ZenuCudaMemCopyKind::HostToDevice, - ) - .unwrap(); - - // create output tensor - let output_gpu = cuda_malloc::((n * k * out_h * out_w) as usize).unwrap(); - - // execute convolution - conv.forward(1.0, input_gpu, filter_gpu, 0.0, output_gpu); - - // copy output tensor to cpu - let mut output_cpu = Vec::new(); - for _ in 0..n * k * out_h * out_w { - output_cpu.push(0.0); - } - cuda_copy( - output_cpu.as_mut_ptr(), - output_gpu, - (n * k * out_h * out_w) as usize, - ZenuCudaMemCopyKind::DeviceToHost, - ) - .unwrap(); - - // check output tensor - let ans = vec![ - 6888, 10218, 10479, 10740, 7056, 10296, 15219, 15570, 15921, 10422, 11511, 16974, - 17325, 17676, 11547, 12726, 18729, 19080, 19431, 12672, 8040, 11784, 11991, 12198, - 7920, 15960, 24069, 24816, 25563, 17100, 25119, 37818, 38898, 39978, 26703, 28764, - 43218, 44298, 45378, 30258, 32409, 48618, 49698, 50778, 33813, 21972, 32925, 33618, - 34311, 22824, 25032, 37920, 39153, 40386, 27144, 39942, 60417, 62226, 64035, 42984, - 46017, 69462, 71271, 73080, 48969, 52092, 78507, 80316, 82125, 54954, 35904, 54066, - 55245, 56424, 37728, - ]; - let ans = ans.iter().map(|&x| x as f32).collect::>(); - assert_eq!(output_cpu, ans); - } - - #[test] - fn bkwd_data() { - let n = 1; - let c = 3; - let h = 5; - let w = 5; - let k = 3; - let kh = 3; - let kw = 3; - let pad_h = 1; - let pad_w = 1; - let stride_h = 1; - let stride_w = 1; - - // 畳み込み後の出力テンソルのサイズ - let out_h = (h + 2 * pad_h - kh) / stride_h + 1; - let out_w = (w + 2 * pad_w - kw) / stride_w + 1; - - let conv = ConvolutionBackwardDataBuilder::default() - .input::(n, c, out_h, out_w, TensorFormat::NCHW) - .unwrap() - .filter::(k, c, kh, kw, TensorFormat::NCHW) - .unwrap() - .conv(pad_h, pad_w, stride_h, stride_w, 1, 1) - .unwrap() - .output::(n, k, h, w, TensorFormat::NCHW) // ここで出力テンソルのサイズを変更 - .unwrap() - .algorithm(5) - .unwrap() - .build() - .unwrap(); - - let mut input_cpu = Vec::new(); - for idx in 0..n * c * out_h * out_w { - input_cpu.push(idx as f32); - } - - let mut filter_cpu = Vec::new(); - for idx in 0..k * c * kh * kw { - filter_cpu.push(idx as f32); - } - - let input_gpu = cuda_malloc::((n * c * out_h * out_w) as usize).unwrap(); - let filter_gpu = cuda_malloc::((k * c * kh * kw) as usize).unwrap(); - let output_gpu = cuda_malloc::((n * k * h * w) as usize).unwrap(); - - cuda_copy( - input_gpu, - input_cpu.as_ptr(), - (n * c * out_h * out_w) as usize, - ZenuCudaMemCopyKind::HostToDevice, - ) - .unwrap(); - cuda_copy( - filter_gpu, - filter_cpu.as_ptr(), - (k * c * kh * kw) as usize, - ZenuCudaMemCopyKind::HostToDevice, - ) - .unwrap(); - - conv.backward_data(1.0, filter_gpu, input_gpu, 0.0, output_gpu); - - let mut output_cpu = Vec::new(); - for _ in 0..n * k * h * w { - output_cpu.push(0.0); - } - cuda_copy( - output_cpu.as_mut_ptr(), - output_gpu, - (n * k * h * w) as usize, - ZenuCudaMemCopyKind::DeviceToHost, - ) - .unwrap(); - println!("{:?}", output_cpu); - let ans = vec![ - 15096.0, 23154.0, 23685.0, 24216.0, 16512.0, 24660.0, 37809.0, 38646.0, 39483.0, - 26910.0, 27405.0, 41994.0, 42831.0, 43668.0, 29745.0, 30150.0, 46179.0, 47016.0, - 47853.0, 32580.0, 21864.0, 33468.0, 34053.0, 34638.0, 23568.0, 18120.0, 27771.0, - 28464.0, 29157.0, 19860.0, 29601.0, 45342.0, 46422.0, 47502.0, 32337.0, 33156.0, - 50742.0, 51822.0, 52902.0, 35982.0, 36711.0, 56142.0, 57222.0, 58302.0, 39627.0, - 26508.0, 40515.0, 41262.0, 42009.0, 28536.0, 21144.0, 32388.0, 33243.0, 34098.0, - 23208.0, 34542.0, 52875.0, 54198.0, 55521.0, 37764.0, 38907.0, 59490.0, 60813.0, - 62136.0, 42219.0, 43272.0, 66105.0, 67428.0, 68751.0, 46674.0, 31152.0, 47562.0, - 48471.0, 49380.0, 33504.0, - ]; - assert_eq!(output_cpu, ans); - } - - #[test] - fn bkwd_filter() { - let n = 1; - let c = 3; - let h = 5; - let w = 5; - let k = 3; - let kh = 3; - let kw = 3; - let pad_h = 1; - let pad_w = 1; - let stride_h = 1; - let stride_w = 1; - - // 畳み込み後の出力テンソルのサイズ - let out_h = (h + 2 * pad_h - kh) / stride_h + 1; - let out_w = (w + 2 * pad_w - kw) / stride_w + 1; - - let conv = ConvolutionBackwardFilterBuilder::default() - .input::(n, c, h, w, TensorFormat::NCHW) - .unwrap() - .filter::(k, c, kh, kw, TensorFormat::NCHW) - .unwrap() - .conv(pad_h, pad_w, stride_h, stride_w, 1, 1) - .unwrap() - .output::(n, k, out_h, out_w, TensorFormat::NCHW) - .unwrap() - .algorithm(1) - .unwrap() - .build() - .unwrap(); - - let mut input_cpu = Vec::new(); - for idx in 0..n * c * h * w { - input_cpu.push(idx as f32); - } - - let mut d_output_cpu = Vec::new(); - for idx in 0..n * k * out_h * out_w { - d_output_cpu.push((idx % 10) as f32); - } - - let input_gpu = cuda_malloc::((n * c * h * w) as usize).unwrap(); - let filter_gpu = cuda_malloc::((k * c * kh * kw) as usize).unwrap(); - let output_gpu = cuda_malloc::((n * k * out_h * out_w) as usize).unwrap(); - - cuda_copy( - input_gpu, - input_cpu.as_ptr(), - (n * c * h * w) as usize, - ZenuCudaMemCopyKind::HostToDevice, - ) - .unwrap(); - cuda_copy( - output_gpu, - d_output_cpu.as_ptr(), - (n * k * out_h * out_w) as usize, - ZenuCudaMemCopyKind::HostToDevice, - ) - .unwrap(); - - conv.backward_filter(1.0, input_gpu, output_gpu, 0.0, filter_gpu); - - let mut filter_cpu = Vec::new(); - for _ in 0..k * c * kh * kw { - filter_cpu.push(0.0); - } - cuda_copy( - filter_cpu.as_mut_ptr(), - filter_gpu, - (k * c * kh * kw) as usize, - ZenuCudaMemCopyKind::DeviceToHost, - ) - .unwrap(); - - println!("{:?}", filter_cpu); - - // Expected results (you need to calculate them or use another framework) - // let expected = vec![ - // 3480.0, 3945.0, 4410.0, 5055.0, 5685.0, 6315.0, 7140.0, 7995.0, 8850.0, - // ]; - let ans = vec![ - 640.0, 770.0, 560.0, 1060.0, 1250.0, 900.0, 1240.0, 1470.0, 1080.0, 2640.0, 3020.0, - 2160.0, 3310.0, 3750.0, 2650.0, 3240.0, 3720.0, 2680.0, 4640.0, 5270.0, 3760.0, 5560.0, - 6250.0, 4400.0, 5240.0, 5970.0, 4280.0, 840.0, 1020.0, 760.0, 1290.0, 1550.0, 1150.0, - 1040.0, 1220.0, 880.0, 2840.0, 3270.0, 2360.0, 4040.0, 4675.0, 3400.0, 3040.0, 3470.0, - 2480.0, 4840.0, 5520.0, 3960.0, 6790.0, 7800.0, 5650.0, 5040.0, 5720.0, 4080.0, 640.0, - 770.0, 560.0, 1060.0, 1250.0, 900.0, 1240.0, 1470.0, 1080.0, 2640.0, 3020.0, 2160.0, - 3310.0, 3750.0, 2650.0, 3240.0, 3720.0, 2680.0, 4640.0, 5270.0, 3760.0, 5560.0, 6250.0, - 4400.0, 5240.0, 5970.0, 4280.0, - ]; - - assert_eq!(filter_cpu, ans); - } -}