
Work out good approach for automatic Hessian sparsity detection #13

Closed

ElOceanografo opened this issue Jun 1, 2023 · 14 comments

@ElOceanografo
Owner

SparsityDetection.jl is no longer maintained, and the sparsity detection functionality in Symbolics.jl still seems somewhat brittle. Currently this package just uses ForwardDiff, but that is a suboptimal solution (very slow in high dimensions).

@ElOceanografo
Owner Author

Should be solved by some combination of SparseConnectivityTracer.jl and the solution to gdalle/DifferentiationInterface.jl#263

@gdalle

gdalle commented May 30, 2024

Now that SparseConnectivityTracer supports linear algebra, can you see if this is enough for your use case?
Note that to support linear algebra, you will have to use local sparsity detection, which means the sparsity pattern cannot be reused between calls because it depends on the value of x. If you need a global sparsity pattern, do tell us which linear-algebraic functions you use, and we can add specific overloads for them.
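
As a minimal sketch (assuming SparseConnectivityTracer's `TracerLocalSparsityDetector` and a made-up test function), local detection through the ADTypes interface looks like this:

using ADTypes: hessian_sparsity
using SparseConnectivityTracer

f(x) = sum(abs2, diff(x))  # made-up test function with a tridiagonal Hessian

# Local detection traces f at this particular x, so the resulting
# pattern is only guaranteed to be valid at (or near) that point.
detector = TracerLocalSparsityDetector()
pattern = hessian_sparsity(f, rand(10), detector)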

@ElOceanografo
Owner Author

That's great! The main functions that are both important to have and difficult to handle until now are log-determinants (of both dense and sparse matrices). I need them to calculate log-likelihoods of Gaussian Markov random fields (i.e., multivariate normals parameterized by a precision matrix instead of a covariance matrix). If I can trace through the dense and sparse versions of the MWE function in gdalle/DifferentiationInterface.jl#263, I will be pretty happy:

using LinearAlgebra, SparseArrays

const y = randn(10)

# Dense version: diagm builds a dense diagonal matrix
function f_dense(u)
    Q = diagm(exp.(u))
    return logdet(Q) - y' * Q * y
end

# Sparse version: spdiagm builds a SparseMatrixCSC
function f_sparse(u)
    Q = spdiagm(exp.(u))
    return logdet(Q) - y' * Q * y
end

In most problems, the structure of Q will be constant, so in theory the sparsity pattern should not change between calls.

Currently, SCT can handle the dense one with local tracing, but gets a stack overflow error with the sparse one.

@gdalle

gdalle commented May 31, 2024

I opened an issue to keep track. It's weird that it doesn't happen with `det`, only `logdet`.

@gdalle

gdalle commented May 31, 2024

Can you use this code in the meantime? I still haven't decided whether it belongs in DI itself or just in the docs as an example, given how short and easy it is.

using ADTypes
using DifferentiationInterface
using SparseArrays

# Compute a dense Jacobian/Hessian with any backend, then threshold
# its entries to recover a sparsity pattern.
struct DenseSparsityDetector{B} <: ADTypes.AbstractSparsityDetector
    backend::B
    atol::Float64
end

function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector)
    J = jacobian(f, detector.backend, x)
    return sparse(abs.(J) .> detector.atol)
end

# In-place variant for functions of the form f!(y, x)
function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector)
    J = jacobian(f!, y, detector.backend, x)
    return sparse(abs.(J) .> detector.atol)
end

function ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector)
    H = hessian(f, detector.backend, x)
    return sparse(abs.(H) .> detector.atol)
end
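
As a hypothetical usage of the detector defined above (assuming ForwardDiff is loaded so `AutoForwardDiff()` works as the inner backend):

using ADTypes: AutoForwardDiff
using ForwardDiff  # loads DI's ForwardDiff extension

f(x) = sum(abs2, diff(x))  # made-up test function

detector = DenseSparsityDetector(AutoForwardDiff(), 1e-8)
pattern = ADTypes.hessian_sparsity(f, rand(10), detector)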

@ElOceanografo
Owner Author

ElOceanografo commented Jun 1, 2024

Yes, that solution is probably good enough to close this issue, and I already have it as a patch on a development branch here. My $0.02 is that it would be great to have it in DI unless there's a compelling reason not to.

@gdalle

gdalle commented Jun 4, 2024

DenseSparsityDetector is part of the newly released DifferentiationInterface v0.5.3. Can you take it out for a spin?

@ElOceanografo
Owner Author

Working here, will close this issue once I merge that branch. Thanks for the quick fix!

@gdalle

gdalle commented Jun 5, 2024

Awesome! I have a few remarks on the PR, will put them here:

- `hess_adtype = nothing` : Specifies how to calculate the Hessian of the marginalized
variables. If not specified, defaults to a sparse second-order method using finite
differences over the AD type given in the `method` (`AutoForwardDiff()` is the default).
Other backends can be set by loading the appropriate AD package and using the ADTypes.jl
interface.

Why pick finite differences over forward as the default second order method? Forward over reverse would be much faster for large problems

- `sparsity_detector = DenseSparsityDetector(method.adtype, atol=cbrt(eps))` : How to
perform the sparsity detection. Detecting sparsity takes some time and may not be worth it
for small problems, but for larger problems it can be very worthwhile. The default
`DenseSparsityDetector` is the most robust, but if it's too slow, or if you're running out of
memory on a larger problem, try the tracing-based detectors from SparseConnectivityTracer.jl.
- `coloring_algorithm = GreedyColoringAlgorithm()` : How to determine the matrix "colors"
to compress the sparse Hessian.

Why separate the sparsity detector, coloring algorithm and backend? With the ADTypes AutoSparse struct, the user can provide all of them at once, with less bookkeeping for you.
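
For concreteness, a sketch of the bundled approach (assuming `GreedyColoringAlgorithm` comes from SparseMatrixColorings.jl; for a logdet-style function you would swap in the local tracer, per the discussion above):

using ADTypes: AutoSparse, AutoForwardDiff
using SparseConnectivityTracer, SparseMatrixColorings
using ForwardDiff

# One struct bundles the dense backend, the sparsity detector
# and the coloring algorithm
backend = AutoSparse(
    AutoForwardDiff();
    sparsity_detector=TracerSparsityDetector(),
    coloring_algorithm=GreedyColoringAlgorithm(),
)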

Also keep in mind that the DenseSparsityDetector is local by nature. To avoid accidental cancellations and incorrect sparsity patterns, you should make sure that your function has no value-dependent control flow, and pick a random point x for evaluating it.
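
To illustrate the control-flow caveat with a made-up example:

# At any x with x[1] <= 0 the branch hides the x[1] * x[2] coupling,
# so a local detector evaluated there would miss the (1, 2) Hessian entry.
f_branchy(x) = x[1] > 0 ? x[1] * x[2] : x[1]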

extras = prepare_hessian(w -> f(w, p2), hess_adtype, w)
H = hessian(w -> f(w, p2), hess_adtype, w, extras)

I'm not entirely sure that preparation is correct when you have different function objects such as two anonymous functions. I think it's better to define g = Base.Fix2(f, p2) for instance, and use it several times.
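
A hypothetical rewrite of the snippet along those lines (same DI calling convention as quoted above):

g = Base.Fix2(f, p2)  # g(w) == f(w, p2), a single stable function object
extras = prepare_hessian(g, hess_adtype, w)
H = hessian(g, hess_adtype, w, extras)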

@gdalle

gdalle commented Jun 5, 2024

Also, I don't think your test failures on nightly are related to DI. I think it's Reexport.jl interacting badly with the new `public` keyword.

@ElOceanografo
Owner Author

Most of these decisions were made to err on the side of reliability (at least for now). If I'm showing this package to someone who is currently using R/TMB, I'd rather they encounter something that runs slower than they'd like than something that breaks in a situation they're used to having work. In my anecdotal experience, it's not uncommon to try some model that should work but breaks because of an AD corner case, hence my conservatism. What I really need to do at this point is work up a more comprehensive set of realistic problems and test them systematically, both for performance comparisons and to find bugs.

Why pick finite differences over forward as the default second order method? Forward over reverse would be much faster for large problems

I made FiniteDiff the default because it is guaranteed to work with any inner backend, and is often not actually much slower than ForwardDiff if the Hessian is sparse enough.
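
For comparison, a sketch of opting into forward-over-reverse (assuming DI's `SecondOrder` wrapper with the outer backend as the first argument, and Zygote loaded for the inner reverse pass):

using DifferentiationInterface: SecondOrder
using ADTypes: AutoForwardDiff, AutoZygote
using ForwardDiff, Zygote

# Forward-mode outer pass over a reverse-mode inner gradient
hess_adtype = SecondOrder(AutoForwardDiff(), AutoZygote())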

Why separate the sparsity detector, coloring algorithm and backend? With the ADTypes AutoSparse struct, the user can provide all of them at once, with less bookkeeping for you.

True, not sure why I made it this way now. Will probably change it.

Also keep in mind that the DenseSparsityDetector is local by nature. To avoid accidental cancellations and incorrect sparsity patterns, you should make sure that your function has no value-dependent control flow, and pick a random point x for evaluating it.

Good to know.

I'm not entirely sure that preparation is correct when you have different function objects such as two anonymous functions. I think it's better to define g = Base.Fix2(f, p2) for instance, and use it several times.

Also good to know, and might explain some odd timings I've seen. Will try it out!

@ElOceanografo
Owner Author

Closed via #25

@gdalle

gdalle commented Jun 5, 2024

Also keep in mind that the DenseSparsityDetector is local by nature. To avoid accidental cancellations and incorrect sparsity patterns, you should make sure that your function has no value-dependent control flow, and pick a random point x for evaluating it.

Good to know.

The docstring for DenseSparsityDetector explains this in more detail:

https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface/stable/api/#DifferentiationInterface.DenseSparsityDetector
