From 0ce5b4e33ea629ce94459c4fac68c2a525184c8b Mon Sep 17 00:00:00 2001 From: minsoo Date: Thu, 20 Jun 2024 11:37:50 -0700 Subject: [PATCH] reml for eigen --- src/MultiResponseVarianceComponentModels.jl | 56 +++++----- src/eigen.jl | 108 ++++++++++++++------ 2 files changed, 104 insertions(+), 60 deletions(-) diff --git a/src/MultiResponseVarianceComponentModels.jl b/src/MultiResponseVarianceComponentModels.jl index 6aa5e97..44af935 100644 --- a/src/MultiResponseVarianceComponentModels.jl +++ b/src/MultiResponseVarianceComponentModels.jl @@ -325,20 +325,23 @@ struct MRTVCModel{T <: BlasReal} <: VCModel Σcov :: Union{Nothing, Matrix{T}} # for fisher_Σ! # original data for reml Y_reml :: Union{Nothing, Matrix{T}} + Ỹ_reml :: Union{Nothing, Matrix{T}} X_reml :: Union{Nothing, Matrix{T}} + X̃_reml :: Union{Nothing, Matrix{T}} V_reml :: Union{Nothing, Vector{Matrix{T}}} + U_reml :: Union{Nothing, Matrix{T}} + D_reml :: Union{Nothing, Vector{T}} + logdetV2_reml :: Union{Nothing, T} # fixed effects parameters for reml B_reml :: Union{Nothing, Matrix{T}} # working arrays for reml - Ω_reml :: Union{Nothing, Matrix{T}} - R_reml :: Union{Nothing, Matrix{T}} - storage_nd_nd_reml :: Union{Nothing, Matrix{T}} - storage_pd_pd_reml :: Union{Nothing, Matrix{T}} - storage_n_p_reml :: Union{Nothing, Matrix{T}} + ỸΦ_reml :: Union{Nothing, Matrix{T}} + R̃_reml :: Union{Nothing, Matrix{T}} + R̃Φ_reml :: Union{Nothing, Matrix{T}} + storage_nd_pd_reml :: Union{Nothing, Matrix{T}} storage_nd_1_reml :: Union{Nothing, Vector{T}} storage_nd_2_reml :: Union{Nothing, Vector{T}} - storage_n_d_reml :: Union{Nothing, Matrix{T}} - storage_p_d_reml :: Union{Nothing, Matrix{T}} + storage_pd_pd_reml :: Union{Nothing, Matrix{T}} storage_pd_reml :: Union{Nothing, Vector{T}} logl_reml :: Union{Nothing, Vector{T}} Bcov_reml :: Union{Nothing, Matrix{T}} @@ -376,25 +379,26 @@ function MRTVCModel( n, p = size(Y, 1), 0 nd, pd = n * d, p * d nd_reml, pd_reml = n_reml * d, p_reml * d + D_reml, U_reml = eigen(Symmetric(V_reml[1]), Symmetric(V_reml[2])) + logdetV2_reml = logdet(V_reml[2]) + Ỹ_reml = transpose(U_reml) * Y_reml + X̃_reml = transpose(U_reml) * X_reml B_reml = Matrix{T}(undef, p_reml, d) - Ω_reml = Matrix{T}(undef, nd_reml, nd_reml) - R_reml = Matrix{T}(undef, n_reml, d) - storage_nd_nd_reml = Matrix{T}(undef, nd_reml, nd_reml) - storage_pd_pd_reml = Matrix{T}(undef, pd_reml, pd_reml) - storage_n_p_reml = Matrix{T}(undef, n_reml, p_reml) + ỸΦ_reml = Matrix{T}(undef, n_reml, d) + R̃_reml = Matrix{T}(undef, n_reml, d) + R̃Φ_reml = Matrix{T}(undef, n_reml, d) + storage_nd_pd_reml = Matrix{T}(undef, nd_reml, pd_reml) storage_nd_1_reml = Vector{T}(undef, nd_reml) storage_nd_2_reml = Vector{T}(undef, nd_reml) - storage_n_d_reml = Matrix{T}(undef, n_reml, d) - storage_p_d_reml = Matrix{T}(undef, p_reml, d) + storage_pd_pd_reml = Matrix{T}(undef, pd_reml, pd_reml) storage_pd_reml = Vector{T}(undef, pd_reml) - logl_reml = zeros(T, 1) + logl_reml = zeros(T, 1) else - Y_reml = X_reml = V_reml = B_reml = Ω_reml = R_reml = - storage_nd_nd_reml = storage_pd_pd_reml = - storage_n_p_reml = storage_nd_1_reml = - storage_nd_2_reml = storage_n_d_reml = - storage_p_d_reml = storage_pd_reml = - logl_reml = Bcov_reml = nothing + Y_reml = Ỹ_reml = X_reml = X̃_reml = V_reml = U_reml = D_reml = + logdetV2_reml = B_reml = ỸΦ_reml = R̃_reml = R̃Φ_reml = + storage_nd_pd_reml = storage_nd_1_reml = + storage_nd_2_reml = storage_pd_pd_reml = storage_pd_reml = + logl_reml = Bcov_reml = nothing end if se Bcov = Matrix{T}(undef, pd, pd) @@ -439,11 +443,11 @@ function MRTVCModel( storage_d_1, storage_d_2, storage_d_d_1, storage_d_d_2, storage_p_p, storage_pd, storage_pd_pd, storage_nd_1, storage_nd_2, storage_nd_pd, logl, Bcov, Σcov, - Y_reml, X_reml, V_reml, B_reml, Ω_reml, R_reml, - storage_nd_nd_reml, storage_pd_pd_reml, storage_n_p_reml, - storage_nd_1_reml, storage_nd_2_reml, storage_n_d_reml, - storage_p_d_reml, storage_pd_reml, logl_reml, Bcov_reml, - se, reml + Y_reml, Ỹ_reml, X_reml, X̃_reml, V_reml, U_reml, D_reml, + logdetV2_reml, B_reml, ỸΦ_reml, R̃_reml, R̃Φ_reml, + storage_nd_pd_reml, storage_nd_1_reml, + storage_nd_2_reml, storage_pd_pd_reml, storage_pd_reml, + logl_reml, Bcov_reml, se, reml ) end diff --git a/src/eigen.jl b/src/eigen.jl index c42ec14..a95dac9 100644 --- a/src/eigen.jl +++ b/src/eigen.jl @@ -95,13 +95,13 @@ function fit!( break end end - # if model.reml - # update_B_reml!(model) - # update_res_reml!(model) - # mul!(model.R̃Φ_reml, model.R̃_reml, model.Φ) - # copyto!(model.logl_reml, loglikelihood_reml!(model)) - # model.se ? fisher_B_reml!(model) : nothing - # end + if model.reml + update_B_reml!(model) + update_res_reml!(model) + mul!(model.R̃Φ_reml, model.R̃_reml, model.Φ) + copyto!(model.logl_reml, loglikelihood_reml!(model)) + model.se ? fisher_B_reml!(model) : nothing + end log && IterativeSolvers.shrink!(history) history end @@ -183,7 +183,6 @@ function update_Φ!( copy!(model.Λ, Λ) copy!(model.Φ, Φ) copyto!(model.logdetΣ2, logdet(model.Σ[2])) - mul!(model.ỸΦ, model.Ỹ, model.Φ) end function update_res!( @@ -194,12 +193,20 @@ function update_res!( model.R̃ end +function update_res_reml!( + model :: MRTVCModel{T} + ) where T <: BlasReal + # update R̃ = Ỹ - X̃B + BLAS.gemm!('N', 'N', -one(T), model.X̃_reml, model.B_reml, one(T), copyto!(model.R̃_reml, model.Ỹ_reml)) + model.R̃ +end + function loglikelihood!( model :: MRTVCModel{T} ) where T <: BlasReal n, d = size(model.Ỹ, 1), size(model.Ỹ, 2) # assemble pieces for log-likelihood - logl = n * d * log(2π) + n * model.logdetΣ2[1] + d * model.logdetV2[1] + logl = n * d * log(2π) + n * model.logdetΣ2[1] + d * model.logdetV2 @inbounds for j in 1:d λj = model.Λ[j] @simd for i in 1:n @@ -210,9 +217,26 @@ function loglikelihood!( logl /= -2 end +function loglikelihood_reml!( + model :: MRTVCModel{T} + ) where T <: BlasReal + n, d = size(model.Ỹ_reml, 1), size(model.Ỹ_reml, 2) + # assemble pieces for log-likelihood + logl = n * d * log(2π) + n * model.logdetΣ2[1] + d * model.logdetV2_reml + @inbounds for j in 1:d + λj = model.Λ[j] + @simd for i in 1:n + tmp = model.D_reml[i] * λj + one(T) + logl += log(tmp) + inv(tmp) * model.R̃Φ_reml[i, j]^2 + end + end + logl /= -2 +end + function update_B!( model :: MRTVCModel{T} ) where T <: BlasReal + mul!(model.ỸΦ, model.Ỹ, model.Φ) # Gram matrix G = (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹(Φ'⊗X̃) G = model.storage_pd_pd fill!(model.storage_nd_pd, zero(T)) @@ -236,31 +260,32 @@ function update_B!( model.B end -# function update_B_reml!( -# model :: MRTVCModel{T} -# ) where T <: BlasReal -# # Gram matrix G = (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹(Φ'⊗X̃) -# G = model.storage_pd_pd -# fill!(model.storage_nd_pd_reml, zero(T)) -# kron_axpy!(transpose(model.Φ), model.X̃_reml, model.storage_nd_pd_reml) -# fill!(model.storage_nd_1_reml, zero(T)) -# kron_axpy!(model.Λ, model.D_reml, model.storage_nd_1_reml) -# @inbounds @simd for i in eachindex(model.storage_nd_1_reml) -# model.storage_nd_1_reml[i] = one(T) / sqrt(model.storage_nd_1_reml[i] + one(T)) -# end -# lmul!(Diagonal(model.storage_nd_1_reml), model.storage_nd_pd_reml) -# mul!(G, transpose(model.storage_nd_pd_reml), model.storage_nd_pd_reml) -# # (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹vec(ỸΦ) -# copyto!(model.storage_nd_2_reml, model.ỸΦ_reml) -# model.storage_nd_2_reml .= model.storage_nd_1_reml .* model.storage_nd_2_reml -# mul!(model.storage_pd, transpose(model.storage_nd_pd_reml), model.storage_nd_2_reml) -# # Cholesky solve -# _, info = LAPACK.potrf!('U', G) -# info > 0 && throw("Gram matrix (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹(Φ'⊗X̃) is singular") -# LAPACK.potrs!('U', G, model.storage_pd) -# copyto!(model.B, model.storage_pd) -# model.B -# end +function update_B_reml!( + model :: MRTVCModel{T} + ) where T <: BlasReal + mul!(model.ỸΦ_reml, model.Ỹ_reml, model.Φ) + # Gram matrix G = (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹(Φ'⊗X̃) + G = model.storage_pd_pd_reml + fill!(model.storage_nd_pd_reml, zero(T)) + kron_axpy!(transpose(model.Φ), model.X̃_reml, model.storage_nd_pd_reml) + fill!(model.storage_nd_1_reml, zero(T)) + kron_axpy!(model.Λ, model.D_reml, model.storage_nd_1_reml) + @inbounds @simd for i in eachindex(model.storage_nd_1_reml) + model.storage_nd_1_reml[i] = one(T) / sqrt(model.storage_nd_1_reml[i] + one(T)) + end + lmul!(Diagonal(model.storage_nd_1_reml), model.storage_nd_pd_reml) + mul!(G, transpose(model.storage_nd_pd_reml), model.storage_nd_pd_reml) + # (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹vec(ỸΦ) + copyto!(model.storage_nd_2_reml, model.ỸΦ_reml) + model.storage_nd_2_reml .= model.storage_nd_1_reml .* model.storage_nd_2_reml + mul!(model.storage_pd_reml, transpose(model.storage_nd_pd_reml), model.storage_nd_2_reml) + # Cholesky solve + _, info = LAPACK.potrf!('U', G) + info > 0 && throw("Gram matrix (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹(Φ'⊗X̃) is singular") + LAPACK.potrs!('U', G, model.storage_pd_reml) + copyto!(model.B_reml, model.storage_pd_reml) + model.B_reml +end function fisher_B!( model :: MRTVCModel{T} @@ -277,6 +302,21 @@ function fisher_B!( copyto!(model.Bcov, pinv(model.storage_pd_pd)) end +function fisher_B_reml!( + model :: MRTVCModel{T} + ) where T <: BlasReal + fill!(model.storage_nd_pd_reml, zero(T)) + kron_axpy!(transpose(model.Φ), model.X̃_reml, model.storage_nd_pd_reml) + fill!(model.storage_nd_1_reml, zero(T)) + kron_axpy!(model.Λ, model.D_reml, model.storage_nd_1_reml) + @inbounds @simd for i in eachindex(model.storage_nd_1_reml) + model.storage_nd_1_reml[i] = one(T) / sqrt(model.storage_nd_1_reml[i] + one(T)) + end + lmul!(Diagonal(model.storage_nd_1_reml), model.storage_nd_pd_reml) + mul!(model.storage_pd_pd_reml, transpose(model.storage_nd_pd_reml), model.storage_nd_pd_reml) + copyto!(model.Bcov_reml, pinv(model.storage_pd_pd_reml)) +end + function fisher_Σ!( model :: MRTVCModel{T} ) where T <: BlasReal