From ce86cc5d5ca85ca76f3cee11e4e98e945e0cf22b Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Tue, 20 Feb 2018 10:51:30 -0500 Subject: [PATCH 1/2] cholesky --- src/linalg/dense.jl | 18 ++++++++++++++++++ test/primitives.jl | 2 ++ 2 files changed, 20 insertions(+) diff --git a/src/linalg/dense.jl b/src/linalg/dense.jl index cd6f4f3..a863034 100644 --- a/src/linalg/dense.jl +++ b/src/linalg/dense.jl @@ -37,3 +37,21 @@ dense2arg = Dict{Symbol,Any}( # cond # sylvester # lyap + +@primitive chol(x),dy,y chol_back(y, dy) + +chol_ϕ(A) = tril(A) - 0.5diagm(diag(A)) + +# ref: Iain Murray's https://arxiv.org/pdf/1602.07527.pdf +# difference with the paper with respect to the paper: +# julia L is upper triangular and we do not need the +# final simmetrization of S +function chol_back(L, dL) + dL = triu(dL) + iL = inv(L) + S = iL * chol_ϕ(L*dL') * iL' + # S + S' - diagm(diag(S)) + S +end + + diff --git a/test/primitives.jl b/test/primitives.jl index 0f74da4..38171e8 100644 --- a/test/primitives.jl +++ b/test/primitives.jl @@ -1,6 +1,8 @@ include("header.jl") @testset "primitives" begin + @test gradcheck(x->chol(x'x), rand(3,3)) + for t in AutoGrad.alltests() #@show t @test gradcheck(eval(AutoGrad,t[1]), t[2:end]...) From dbef75680df916787ba56b88ba5cc96cf4d04742 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Tue, 20 Feb 2018 17:17:08 -0500 Subject: [PATCH 2/2] add lq and qr --- src/linalg/dense.jl | 41 +++++++++++++++++++++++++++++++++-------- test/primitives.jl | 10 ++++++++++ 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/src/linalg/dense.jl b/src/linalg/dense.jl index a863034..4f6dc30 100644 --- a/src/linalg/dense.jl +++ b/src/linalg/dense.jl @@ -40,18 +40,43 @@ dense2arg = Dict{Symbol,Any}( @primitive chol(x),dy,y chol_back(y, dy) -chol_ϕ(A) = tril(A) - 0.5diagm(diag(A)) - -# ref: Iain Murray's https://arxiv.org/pdf/1602.07527.pdf -# difference with the paper with respect to the paper: -# julia L is upper triangular and we do not need the -# final simmetrization of S +# ref: formulua in https://arxiv.org/pdf/1602.07527.pdf +# as described in https://arxiv.org/pdf/1710.08717.pdf +# In julia L is upper triangular function chol_back(L, dL) dL = triu(dL) iL = inv(L) - S = iL * chol_ϕ(L*dL') * iL' - # S + S' - diagm(diag(S)) + S = iL * Symmetric(L*dL',:L) * iL' + S/2 +end + +@primitive lq(x),dy,y lq_back(y, dy) + + +# ref: https://arxiv.org/pdf/1710.08717.pdf +function lq_back(y, dy) + L, Q = y + dL, dQ = dy + dL == nothing && (dL = zeros(L)) + dQ == nothing && (dQ = zeros(Q)) + dL = tril(dL) + M = Symmetric(L'dL - dQ*Q', :L) + S = inv(L)' *(dQ + M*Q) S end +@primitive qr(x),dy,y qr_back(y, dy) + +function qr_back(y, dy) + Q, R = y + dQ, dR = dy + dR == nothing && (dR = zeros(R)) + dQ == nothing && (dQ = zeros(Q)) + dR = triu(dR) + M = Symmetric(R*dR' - dQ'*Q, :L) + S = (dQ + Q*M)*inv(R)' + S +end + + diff --git a/test/primitives.jl b/test/primitives.jl index 38171e8..3214efb 100644 --- a/test/primitives.jl +++ b/test/primitives.jl @@ -7,6 +7,16 @@ include("header.jl") #@show t @test gradcheck(eval(AutoGrad,t[1]), t[2:end]...) end + + @test gradcheck(x->chol(x'x), rand(3,3)) + + @test gradcheck(x->qr(x)[1], rand(3,3)) + @test gradcheck(x->qr(x)[2], rand(3,3)) + @test gradcheck(x->(y=qr(x); sum(y[1]+y[2])), rand(3,3)) + + @test gradcheck(x->lq(x)[1], rand(3,3)) + @test gradcheck(x->lq(x)[2], rand(3,3)) + @test gradcheck(x->(y=lq(x); sum(y[1]+y[2])), rand(3,3)) end nothing