Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chol, qr, lq gradients #55

Merged
merged 2 commits into from
Mar 18, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions src/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,46 @@ dense2arg = Dict{Symbol,Any}(
# cond
# sylvester
# lyap

@primitive chol(x),dy,y chol_back(y, dy)

# ref: formulua in https://arxiv.org/pdf/1602.07527.pdf
# as described in https://arxiv.org/pdf/1710.08717.pdf
# In julia L is upper triangular
function chol_back(L, dL)
dL = triu(dL)
iL = inv(L)
S = iL * Symmetric(L*dL',:L) * iL'
S/2
end

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@CarloLucibello let's not break the AutoGrad's regular testing pipeline. Could you pls. add

  1. addtest(:chol, rand(3,3))
  2. I think we need to find a way to write similar tests for qr and lq and remove them from the test/primitives.jl

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to keep tests this way. It is kind of atypical in the julia ecosystem to add tests into the src/ folder, and I don't want to scratch my head to figure out how to add tests for qr in the non-standard way while the easy way is also the standard way

@primitive lq(x),dy,y lq_back(y, dy)


# ref: https://arxiv.org/pdf/1710.08717.pdf
function lq_back(y, dy)
L, Q = y
dL, dQ = dy
dL == nothing && (dL = zeros(L))
dQ == nothing && (dQ = zeros(Q))
dL = tril(dL)
M = Symmetric(L'dL - dQ*Q', :L)
S = inv(L)' *(dQ + M*Q)
S
end

@primitive qr(x),dy,y qr_back(y, dy)

function qr_back(y, dy)
Q, R = y
dQ, dR = dy
dR == nothing && (dR = zeros(R))
dQ == nothing && (dQ = zeros(Q))
dR = triu(dR)
M = Symmetric(R*dR' - dQ'*Q, :L)
S = (dQ + Q*M)*inv(R)'
S
end



12 changes: 12 additions & 0 deletions test/primitives.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
include("header.jl")

@testset "primitives" begin
@test gradcheck(x->chol(x'x), rand(3,3))

for t in AutoGrad.alltests()
#@show t
@test gradcheck(eval(AutoGrad,t[1]), t[2:end]...)
end

@test gradcheck(x->chol(x'x), rand(3,3))

@test gradcheck(x->qr(x)[1], rand(3,3))
@test gradcheck(x->qr(x)[2], rand(3,3))
@test gradcheck(x->(y=qr(x); sum(y[1]+y[2])), rand(3,3))

@test gradcheck(x->lq(x)[1], rand(3,3))
@test gradcheck(x->lq(x)[2], rand(3,3))
@test gradcheck(x->(y=lq(x); sum(y[1]+y[2])), rand(3,3))
end

nothing