Hi, and congrats on the amazing package!
I have developed an HMM library in Julia called HiddenMarkovModels.jl, and I am currently benchmarking it against the Python alternatives (see here for the feature comparison). I want to benchmark fairly but I'm a JAX newbie, so I was wondering if someone might advise me on possible suboptimalities in my dynamax code?
My test case is an HMM with scalar Gaussian emissions and 100 sequences of length 200 each. I'm interested in small-ish models, which is also why I run everything on the CPU.
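For concreteness, the test case can be reproduced in pure JAX along these lines (a sketch with made-up parameter values: the number of states, transition matrix, and emission means here are illustrative assumptions, not the benchmark's actual settings):

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Illustrative parameters: scalar Gaussian emissions,
# 100 sequences of length 200 (as in the benchmark).
num_states, num_seqs, seq_len = 3, 100, 200

init_probs = jnp.full(num_states, 1.0 / num_states)
# Sticky transition matrix: 0.8 on the diagonal, 0.1 elsewhere (rows sum to 1)
trans = (jnp.full((num_states, num_states), 0.1)
         .at[jnp.arange(num_states), jnp.arange(num_states)].set(0.8))
means = jnp.arange(num_states, dtype=jnp.float32)
stds = jnp.ones(num_states)

def sample_sequence(key):
    """Sample one (states, emissions) pair of length seq_len."""
    def step(state, key):
        k1, k2 = jr.split(key)
        next_state = jr.choice(k1, num_states, p=trans[state])
        emission = means[state] + stds[state] * jr.normal(k2)
        return next_state, (state, emission)

    k0, krest = jr.split(key)
    z0 = jr.choice(k0, num_states, p=init_probs)
    keys = jr.split(krest, seq_len)
    _, (states, emissions) = jax.lax.scan(step, z0, keys)
    return states, emissions

# vmap over independent sequences
keys = jr.split(jr.PRNGKey(0), num_seqs)
states, emissions = jax.vmap(sample_sequence)(keys)
print(emissions.shape)  # (100, 200)
```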
When I time the forward, forward-backward and Viterbi algorithms, dynamax is among the fastest packages. However, I observe a significant slowdown in the EM algorithm, so perhaps something is wrong there (see plots below).
The three inference algorithms have been `jit`-ed and `vmap`-ed for multiple sequences, but I don't know if I can do the same with EM learning. Any suggestions are welcome!

The benchmark is run from Julia with PythonCall.jl, so don't freak out at the weird syntax.
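To illustrate the pattern, here is a pure-JAX sketch of a log-space forward pass combined with `jit` and `vmap` across sequences (an illustration of the transformation pattern, not dynamax's internal implementation; all names are mine):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

@jax.jit
def forward_loglik(log_init, log_trans, log_liks):
    """log p(emissions) for one sequence; log_liks has shape (T, K)."""
    def step(log_alpha, log_lik_t):
        # Sum over the previous state in log space, then add the
        # emission log-likelihood for the current time step.
        log_alpha = logsumexp(log_alpha[:, None] + log_trans, axis=0) + log_lik_t
        return log_alpha, None

    log_alpha0 = log_init + log_liks[0]
    log_alpha, _ = jax.lax.scan(step, log_alpha0, log_liks[1:])
    return logsumexp(log_alpha)

# vmap over the leading (sequence) axis of the emission log-likelihoods;
# parameters are shared across sequences (in_axes=None).
batched_loglik = jax.jit(jax.vmap(forward_loglik, in_axes=(None, None, 0)))
```

The open question is whether the E- and M-steps of EM can be batched the same way, or whether the library's `fit_em` already handles multiple sequences internally.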
Here are the main bits: