Skip to content

Commit

Permalink
reml for eigen
Browse files Browse the repository at this point in the history
  • Loading branch information
mmkim1210 committed Jun 20, 2024
1 parent edf6493 commit 0ce5b4e
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 60 deletions.
56 changes: 30 additions & 26 deletions src/MultiResponseVarianceComponentModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,20 +325,23 @@ struct MRTVCModel{T <: BlasReal} <: VCModel
Σcov :: Union{Nothing, Matrix{T}} # for fisher_Σ!
# original data for reml
Y_reml :: Union{Nothing, Matrix{T}}
Ỹ_reml :: Union{Nothing, Matrix{T}}
X_reml :: Union{Nothing, Matrix{T}}
X̃_reml :: Union{Nothing, Matrix{T}}
V_reml :: Union{Nothing, Vector{Matrix{T}}}
U_reml :: Union{Nothing, Matrix{T}}
D_reml :: Union{Nothing, Vector{T}}
logdetV2_reml :: Union{Nothing, T}
# fixed effects parameters for reml
B_reml :: Union{Nothing, Matrix{T}}
# working arrays for reml
Ω_reml :: Union{Nothing, Matrix{T}}
R_reml :: Union{Nothing, Matrix{T}}
storage_nd_nd_reml :: Union{Nothing, Matrix{T}}
storage_pd_pd_reml :: Union{Nothing, Matrix{T}}
storage_n_p_reml :: Union{Nothing, Matrix{T}}
ỸΦ_reml :: Union{Nothing, Matrix{T}}
R̃_reml :: Union{Nothing, Matrix{T}}
R̃Φ_reml :: Union{Nothing, Matrix{T}}
storage_nd_pd_reml :: Union{Nothing, Matrix{T}}
storage_nd_1_reml :: Union{Nothing, Vector{T}}
storage_nd_2_reml :: Union{Nothing, Vector{T}}
storage_n_d_reml :: Union{Nothing, Matrix{T}}
storage_p_d_reml :: Union{Nothing, Matrix{T}}
storage_pd_pd_reml :: Union{Nothing, Matrix{T}}
storage_pd_reml :: Union{Nothing, Vector{T}}
logl_reml :: Union{Nothing, Vector{T}}
Bcov_reml :: Union{Nothing, Matrix{T}}
Expand Down Expand Up @@ -376,25 +379,26 @@ function MRTVCModel(
n, p = size(Y, 1), 0
nd, pd = n * d, p * d
nd_reml, pd_reml = n_reml * d, p_reml * d
D_reml, U_reml = eigen(Symmetric(V_reml[1]), Symmetric(V_reml[2]))
logdetV2_reml = logdet(V_reml[2])
Ỹ_reml = transpose(U_reml) * Y_reml
X̃_reml = transpose(U_reml) * X_reml
B_reml = Matrix{T}(undef, p_reml, d)
Ω_reml = Matrix{T}(undef, nd_reml, nd_reml)
R_reml = Matrix{T}(undef, n_reml, d)
storage_nd_nd_reml = Matrix{T}(undef, nd_reml, nd_reml)
storage_pd_pd_reml = Matrix{T}(undef, pd_reml, pd_reml)
storage_n_p_reml = Matrix{T}(undef, n_reml, p_reml)
ỸΦ_reml = Matrix{T}(undef, n_reml, d)
R̃_reml = Matrix{T}(undef, n_reml, d)
R̃Φ_reml = Matrix{T}(undef, n_reml, d)
storage_nd_pd_reml = Matrix{T}(undef, nd_reml, pd_reml)
storage_nd_1_reml = Vector{T}(undef, nd_reml)
storage_nd_2_reml = Vector{T}(undef, nd_reml)
storage_n_d_reml = Matrix{T}(undef, n_reml, d)
storage_p_d_reml = Matrix{T}(undef, p_reml, d)
storage_pd_pd_reml = Matrix{T}(undef, pd_reml, pd_reml)
storage_pd_reml = Vector{T}(undef, pd_reml)
logl_reml = zeros(T, 1)
logl_reml = zeros(T, 1)
else
Y_reml = X_reml = V_reml = B_reml = Ω_reml = R_reml =
storage_nd_nd_reml = storage_pd_pd_reml =
storage_n_p_reml = storage_nd_1_reml =
storage_nd_2_reml = storage_n_d_reml =
storage_p_d_reml = storage_pd_reml =
logl_reml = Bcov_reml = nothing
Y_reml = Ỹ_reml = X_reml = X̃_reml = V_reml = U_reml = D_reml =
logdetV2_reml = B_reml = ỸΦ_reml = R̃_reml = R̃Φ_reml =
storage_nd_pd_reml = storage_nd_1_reml =
storage_nd_2_reml = storage_pd_pd_reml = storage_pd_reml =
logl_reml = Bcov_reml = nothing
end
if se
Bcov = Matrix{T}(undef, pd, pd)
Expand Down Expand Up @@ -439,11 +443,11 @@ function MRTVCModel(
storage_d_1, storage_d_2, storage_d_d_1, storage_d_d_2,
storage_p_p, storage_pd, storage_pd_pd,
storage_nd_1, storage_nd_2, storage_nd_pd, logl, Bcov, Σcov,
Y_reml, X_reml, V_reml, B_reml, Ω_reml, R_reml,
storage_nd_nd_reml, storage_pd_pd_reml, storage_n_p_reml,
storage_nd_1_reml, storage_nd_2_reml, storage_n_d_reml,
storage_p_d_reml, storage_pd_reml, logl_reml, Bcov_reml,
se, reml
Y_reml, Ỹ_reml, X_reml, X̃_reml, V_reml, U_reml, D_reml,
logdetV2_reml, B_reml, ỸΦ_reml, R̃_reml, R̃Φ_reml,
storage_nd_pd_reml, storage_nd_1_reml,
storage_nd_2_reml, storage_pd_pd_reml, storage_pd_reml,
logl_reml, Bcov_reml, se, reml
)
end

Expand Down
108 changes: 74 additions & 34 deletions src/eigen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,13 @@ function fit!(
break
end
end
# if model.reml
# update_B_reml!(model)
# update_res_reml!(model)
# mul!(model.R̃Φ_reml, model.R̃_reml, model.Φ)
# copyto!(model.logl_reml, loglikelihood_reml!(model))
# model.se ? fisher_B_reml!(model) : nothing
# end
if model.reml
update_B_reml!(model)
update_res_reml!(model)
mul!(model.R̃Φ_reml, model.R̃_reml, model.Φ)
copyto!(model.logl_reml, loglikelihood_reml!(model))
model.se ? fisher_B_reml!(model) : nothing
end
log && IterativeSolvers.shrink!(history)
history
end
Expand Down Expand Up @@ -183,7 +183,6 @@ function update_Φ!(
copy!(model.Λ, Λ)
copy!(model.Φ, Φ)
copyto!(model.logdetΣ2, logdet(model.Σ[2]))
mul!(model.ỸΦ, model.Ỹ, model.Φ)
end

function update_res!(
Expand All @@ -194,12 +193,20 @@ function update_res!(
model.
end

function update_res_reml!(
model :: MRTVCModel{T}
) where T <: BlasReal
# update R̃ = Ỹ - X̃B
BLAS.gemm!('N', 'N', -one(T), model.X̃_reml, model.B_reml, one(T), copyto!(model.R̃_reml, model.Ỹ_reml))
model.
end

function loglikelihood!(
model :: MRTVCModel{T}
) where T <: BlasReal
n, d = size(model.Ỹ, 1), size(model.Ỹ, 2)
# assemble pieces for log-likelihood
logl = n * d * log(2π) + n * model.logdetΣ2[1] + d * model.logdetV2[1]
logl = n * d * log(2π) + n * model.logdetΣ2[1] + d * model.logdetV2
@inbounds for j in 1:d
λj = model.Λ[j]
@simd for i in 1:n
Expand All @@ -210,9 +217,26 @@ function loglikelihood!(
logl /= -2
end

function loglikelihood_reml!(
model :: MRTVCModel{T}
) where T <: BlasReal
n, d = size(model.Ỹ_reml, 1), size(model.Ỹ_reml, 2)
# assemble pieces for log-likelihood
logl = n * d * log(2π) + n * model.logdetΣ2[1] + d * model.logdetV2_reml
@inbounds for j in 1:d
λj = model.Λ[j]
@simd for i in 1:n
tmp = model.D_reml[i] * λj + one(T)
logl += log(tmp) + inv(tmp) * model.R̃Φ_reml[i, j]^2
end
end
logl /= -2
end

function update_B!(
model :: MRTVCModel{T}
) where T <: BlasReal
mul!(model.ỸΦ, model.Ỹ, model.Φ)
# Gram matrix G = (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹(Φ'⊗X̃)
G = model.storage_pd_pd
fill!(model.storage_nd_pd, zero(T))
Expand All @@ -236,31 +260,32 @@ function update_B!(
model.B
end

# function update_B_reml!(
# model :: MRTVCModel{T}
# ) where T <: BlasReal
# # Gram matrix G = (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹(Φ'⊗X̃)
# G = model.storage_pd_pd
# fill!(model.storage_nd_pd_reml, zero(T))
# kron_axpy!(transpose(model.Φ), model.X̃_reml, model.storage_nd_pd_reml)
# fill!(model.storage_nd_1_reml, zero(T))
# kron_axpy!(model.Λ, model.D_reml, model.storage_nd_1_reml)
# @inbounds @simd for i in eachindex(model.storage_nd_1_reml)
# model.storage_nd_1_reml[i] = one(T) / sqrt(model.storage_nd_1_reml[i] + one(T))
# end
# lmul!(Diagonal(model.storage_nd_1_reml), model.storage_nd_pd_reml)
# mul!(G, transpose(model.storage_nd_pd_reml), model.storage_nd_pd_reml)
# # (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹vec(ỸΦ)
# copyto!(model.storage_nd_2_reml, model.ỸΦ_reml)
# model.storage_nd_2_reml .= model.storage_nd_1_reml .* model.storage_nd_2_reml
# mul!(model.storage_pd, transpose(model.storage_nd_pd_reml), model.storage_nd_2_reml)
# # Cholesky solve
# _, info = LAPACK.potrf!('U', G)
# info > 0 && throw("Gram matrix (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹(Φ'⊗X̃) is singular")
# LAPACK.potrs!('U', G, model.storage_pd)
# copyto!(model.B, model.storage_pd)
# model.B
# end
function update_B_reml!(
model :: MRTVCModel{T}
) where T <: BlasReal
mul!(model.ỸΦ_reml, model.Ỹ_reml, model.Φ)
# Gram matrix G = (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹(Φ'⊗X̃)
G = model.storage_pd_pd_reml
fill!(model.storage_nd_pd_reml, zero(T))
kron_axpy!(transpose(model.Φ), model.X̃_reml, model.storage_nd_pd_reml)
fill!(model.storage_nd_1_reml, zero(T))
kron_axpy!(model.Λ, model.D_reml, model.storage_nd_1_reml)
@inbounds @simd for i in eachindex(model.storage_nd_1_reml)
model.storage_nd_1_reml[i] = one(T) / sqrt(model.storage_nd_1_reml[i] + one(T))
end
lmul!(Diagonal(model.storage_nd_1_reml), model.storage_nd_pd_reml)
mul!(G, transpose(model.storage_nd_pd_reml), model.storage_nd_pd_reml)
# (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹vec(ỸΦ)
copyto!(model.storage_nd_2_reml, model.ỸΦ_reml)
model.storage_nd_2_reml .= model.storage_nd_1_reml .* model.storage_nd_2_reml
mul!(model.storage_pd_reml, transpose(model.storage_nd_pd_reml), model.storage_nd_2_reml)
# Cholesky solve
_, info = LAPACK.potrf!('U', G)
info > 0 && throw("Gram matrix (Φ'⊗X̃)'(Λ⊗D + Ind)⁻¹(Φ'⊗X̃) is singular")
LAPACK.potrs!('U', G, model.storage_pd_reml)
copyto!(model.B_reml, model.storage_pd_reml)
model.B_reml
end

function fisher_B!(
model :: MRTVCModel{T}
Expand All @@ -277,6 +302,21 @@ function fisher_B!(
copyto!(model.Bcov, pinv(model.storage_pd_pd))
end

function fisher_B_reml!(
model :: MRTVCModel{T}
) where T <: BlasReal
fill!(model.storage_nd_pd_reml, zero(T))
kron_axpy!(transpose(model.Φ), model.X̃_reml, model.storage_nd_pd_reml)
fill!(model.storage_nd_1_reml, zero(T))
kron_axpy!(model.Λ, model.D_reml, model.storage_nd_1_reml)
@inbounds @simd for i in eachindex(model.storage_nd_1_reml)
model.storage_nd_1_reml[i] = one(T) / sqrt(model.storage_nd_1_reml[i] + one(T))
end
lmul!(Diagonal(model.storage_nd_1_reml), model.storage_nd_pd_reml)
mul!(model.storage_pd_pd_reml, transpose(model.storage_nd_pd_reml), model.storage_nd_pd_reml)
copyto!(model.Bcov_reml, pinv(model.storage_pd_pd_reml))
end

function fisher_Σ!(
model :: MRTVCModel{T}
) where T <: BlasReal
Expand Down

0 comments on commit 0ce5b4e

Please sign in to comment.