dynamax is weirdly slow in my HMM benchmark #359

Open

gdalle opened this issue Feb 22, 2024 · 0 comments

gdalle commented Feb 22, 2024

Hi, and congrats on the amazing package!

I have developed an HMM library in Julia called HiddenMarkovModels.jl, and I am currently benchmarking it against the Python alternatives (see here for the feature comparison). I want to benchmark fairly, but I'm a JAX newbie, so could someone advise me on possible suboptimalities in my dynamax code?

My test case is an HMM with scalar Gaussian emissions and 100 sequences of length 200 each. I'm interested in small-ish models, which is also why I run everything on the CPU.
When I time the forward, forward-backward and Viterbi algorithms, dynamax is among the fastest packages. However, I observe a significant slowdown in the EM algorithm, so perhaps something is wrong there (see the plots below).
The three inference algorithms have been jit-ed and vmap-ed for multiple sequences (see the sketch below), but I don't know if I can do the same with EM learning. Any suggestions are welcome!
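
For context, the inference setup looks roughly like this (a simplified sketch rather than the exact benchmark code; the number of states and the random data are placeholders):

```python
import jax.random as jr
from jax import jit, vmap
from dynamax.hidden_markov_model import GaussianHMM

# placeholder sizes; emission_dim = 1 for scalar Gaussian emissions
num_states, emission_dim = 5, 1
hmm = GaussianHMM(num_states, emission_dim)
params, props = hmm.initialize(jr.PRNGKey(0))

# 100 sequences of length 200 (random placeholders for the real data)
emissions = jr.normal(jr.PRNGKey(1), (100, 200, emission_dim))

# compile each algorithm once and map it over the batch of sequences
forward_fn = jit(vmap(hmm.filter, in_axes=(None, 0)))
forward_backward_fn = jit(vmap(hmm.smoother, in_axes=(None, 0)))
viterbi_fn = jit(vmap(hmm.most_likely_states, in_axes=(None, 0)))

posteriors = forward_fn(params, emissions)
```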

[Plot: forward algorithm benchmark (forward-1)]

[Plot: Baum-Welch benchmark (baum_welch-1)]

The benchmark is run from Julia with PythonCall.jl, so don't freak out at the weird syntax.
Here are the main bits:
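
In plain Python (the PythonCall.jl version is equivalent), the EM part amounts to a single fit_em call on the whole batch of sequences; the sizes and iteration count below are placeholders:

```python
import jax.random as jr
from dynamax.hidden_markov_model import GaussianHMM

num_states, emission_dim = 5, 1          # placeholder sizes
hmm = GaussianHMM(num_states, emission_dim)
params, props = hmm.initialize(jr.PRNGKey(0))
emissions = jr.normal(jr.PRNGKey(1), (100, 200, emission_dim))

# fit_em consumes the batch of sequences at once and returns the fitted
# parameters plus the marginal log likelihood after each iteration
fitted_params, log_probs = hmm.fit_em(params, props, emissions, num_iters=100)
```

As far as I can tell, fit_em already loops over the iterations internally, so I'm not sure where an extra jit or vmap would go.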
