diff --git a/src/nn/init.rs b/src/nn/init.rs index 89c43f9d..13b54431 100644 --- a/src/nn/init.rs +++ b/src/nn/init.rs @@ -79,13 +79,18 @@ pub enum Init { /// Uniform initialization between some lower and upper bounds. Uniform { lo: f64, up: f64 }, - /// Kaiming uniform initialization. + /// Kaiming initialization. /// See "Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification" - /// He, K. et al. (2015). This uses a uniform distribution. + /// He, K. et al. (2015). Kaiming { dist: NormalOrUniform, fan: FanInOut, non_linearity: NonLinearity }, /// Orthogonal initialization Orthogonal { gain: f64 }, + + /// Xavier (Glorot) initialization. + /// See "Understanding the difficulty of training deep feedforward neural networks" + /// Glorot, X. & Bengio, Y. (2010) + Xavier { dist: NormalOrUniform, non_linearity: NonLinearity }, } pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming { @@ -159,6 +164,22 @@ pub fn f_init(i: Init, dims: &[i64], device: Device, kind: Kind) -> Result { + let fan = FanInOut::FanIn.for_weight_dims(dims) + + FanInOut::FanOut.for_weight_dims(dims); + let gain = non_linearity.gain(); + match dist { + NormalOrUniform::Uniform => { + let bound = gain * (6.0 / fan as f64).sqrt(); + Tensor::f_zeros(dims, (kind, device))?.f_uniform_(-bound, bound) + } + NormalOrUniform::Normal => { + let std = gain * (2.0 / fan as f64).sqrt(); + let randn = Tensor::f_randn(dims, (kind, device))?; + Ok(randn * std) + } + } + } } } @@ -200,6 +221,21 @@ impl Init { .unwrap(); crate::no_grad(|| tensor.view_as(&q).copy_(&q)); } + Init::Xavier { dist, non_linearity } => { + let fan = FanInOut::FanIn.for_weight_dims(&tensor.size()) + + FanInOut::FanOut.for_weight_dims(&tensor.size()); + let gain = non_linearity.gain(); + match dist { + NormalOrUniform::Uniform => { + let bound = gain * (6.0 / fan as f64).sqrt(); + let _ = tensor.uniform_(-bound, bound); + } + NormalOrUniform::Normal => { + let std = gain * (2.0 / fan as f64).sqrt(); + tensor.copy_(&(tensor.randn_like() * std)); + } + } + } } } }