Replies: 4 comments 3 replies
-
I think one of the key questions is: is it true that e.g. batched models, which took a huge amount of work in #503, could've been achieved quite easily (e.g. via `jax.vmap`)?

The slowest part in the function evaluation are the interpolators; I put some initial studies here: https://github.com/lukasheinrich/pyhf-benchmarks/blob/master/colab/jax_benchmarks.ipynb

Generally I would expect an upside to appear in places where the JIT can e.g. avoid temporaries (such as those einsum results) by jitting them away. There is also a "gradual" backend idea, in which we define very specific "levels" at which a backend can appear.
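To make the `vmap` point concrete, a minimal sketch (the toy `logpdf` below is purely illustrative, not the pyhf one):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for model.logpdf(pars, data); the real pyhf evaluation is
# far more involved, this only illustrates the batching mechanics.
def logpdf(pars, data):
    lam = pars[0] * data  # pretend "expected rates"
    return jnp.sum(data * jnp.log(lam) - lam)  # Poisson-like, up to a constant

data = jnp.arange(1.0, 6.0)
batched_pars = jnp.linspace(0.5, 2.0, 100).reshape(100, 1)  # 100 parameter points

# vmap writes the batched model for us; no hand-rolled batching code needed.
batched_logpdf = jax.jit(jax.vmap(logpdf, in_axes=(0, None)))
print(batched_logpdf(batched_pars, data).shape)  # (100,)
```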
-
Checking a single interpolation code (code1) re: establishing whether or not we can gain significant upside: it seems like our backend-agnostic one is not too shabby.
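For reference, a minimal way to reproduce that kind of check (a sketch; `%timeit` assumes IPython, and the shapes just mirror the study below):

```python
import numpy as np
import pyhf

pyhf.set_backend('jax')

# histogram sets: (n_sys, n_samples, 3, n_bins); alphas: (n_sys, n_batch)
h = np.random.normal(size=(10, 10, 3, 1000))
a = np.random.normal(size=(10, 1))

interp = pyhf.interpolators.code1(h)  # construction precomputes the ratios
%timeit interp(a)                     # time only the evaluation
```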
-
Maybe a more general comment on "functional refactoring": I would say that pyhf currently is not too object-y. The interpolators are essentially just functors, while the Model just holds a few data items and provides the logpdf method.
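To illustrate that "functor" shape (a hypothetical simplification, not the actual pyhf source):

```python
import numpy as np

# Hypothetical simplification of the functor pattern: construction
# precomputes the up/down ratios, __call__ evaluates them at given
# nuisance-parameter values (multiplicative, code1-style).
class FunctorInterp:
    def __init__(self, histset):
        # histset: (n_sys, 3, n_bins) with [down, nominal, up] on axis 1
        self.delta_up = histset[:, 2] / histset[:, 1]
        self.delta_dn = histset[:, 0] / histset[:, 1]

    def __call__(self, alphas):
        a = alphas[:, None]  # (n_sys, 1) broadcasts over bins
        return np.where(a > 0, self.delta_up**a, self.delta_dn**(-a))

histset = np.abs(np.random.normal(size=(4, 3, 10))) + 0.1  # positive rates
interp = FunctorInterp(histset)
print(interp(np.random.normal(size=4)).shape)  # (4, 10)
```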
-
leaving this here:

```python
from numba import jit
import numpy as np
import pyhf
import jax
import jax.experimental.loops  # note: API from older JAX releases
import jax.numpy as jnp

pyhf.set_backend('jax')

@jit(nopython=True)
def numba_func(h, a):
    # h: (N, M, 3, O) with [down, nominal, up] along axis 2; a: (N, B)
    N = h.shape[0]
    M = h.shape[1]
    O = h.shape[3]
    B = a.shape[1]
    out = np.zeros((N, M, B, O))
    for i in np.arange(N):
        for j in np.arange(M):
            for k in np.arange(O):
                delta_up = h[i, j, 2, k] / h[i, j, 1, k]
                delta_down = h[i, j, 0, k] / h[i, j, 1, k]
                for l in np.arange(B):
                    alpha = a[i, l]
                    update = delta_up**alpha if alpha > 0 else delta_down ** (-alpha)
                    out[i, j, l, k] = update
    return out

def jax_func(h, a):
    N = h.shape[0]
    M = h.shape[1]
    O = h.shape[3]
    B = a.shape[1]
    # jax.experimental.loops.Scope and jax.ops.index_update are from older
    # JAX releases; newer JAX uses lax control flow and .at[...].set(...)
    with jax.experimental.loops.Scope() as s:
        s.out = jnp.zeros((N, M, B, O))
        for i in s.range(N):
            for j in s.range(M):
                for k in s.range(O):
                    delta_up = h[i, j, 2, k] / h[i, j, 1, k]
                    delta_down = h[i, j, 0, k] / h[i, j, 1, k]
                    for l in s.range(B):
                        alpha = a[i, l]
                        update = jnp.where(
                            alpha > 0,
                            jnp.power(delta_up, alpha),
                            jnp.power(delta_down, -alpha),
                        )
                        s.out = jax.ops.index_update(s.out, (i, j, l, k), update)
        return s.out

h = np.random.normal(size=(10, 10, 3, 1000))
a = np.random.normal(size=(10, 1))

numba_func(h, a)  # trigger the numba JIT compilation

jax_h = jnp.array(h)
jax_a = jnp.array(a)
jax_func = jax.jit(jax_func)
jax_func(jax_h, jax_a)  # trigger the jax JIT compilation
```

```python
%%timeit
numba_func(h, a)
```

```python
%%timeit
jax_func(jax_h, jax_a)
```

```python
%%timeit
pyhf.interpolators.code1(h)(a)
```
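For comparison with the loop versions above, the backend-agnostic style is essentially a fully vectorized kernel. A sketch of that (details differ from the real code1 implementation):

```python
import jax
import jax.numpy as jnp

@jax.jit
def vectorized_func(h, a):
    # Same kernel as above, but as pure array ops: the JIT can fuse these
    # and avoid materializing intermediate temporaries.
    # h: (N, M, 3, O) with [down, nominal, up] along axis 2; a: (N, B)
    delta_up = h[:, :, 2, :] / h[:, :, 1, :]   # (N, M, O)
    delta_dn = h[:, :, 0, :] / h[:, :, 1, :]   # (N, M, O)
    alpha = a[:, None, :, None]                # (N, 1, B, 1)
    up = jnp.power(delta_up[:, :, None, :], alpha)
    dn = jnp.power(delta_dn[:, :, None, :], -alpha)
    return jnp.where(alpha > 0, up, dn)        # (N, M, B, O)

vectorized_func(jax_h, jax_a)  # reuse the arrays above; compile before timing
```

Timing this against `numba_func` and `jax_func` on the same inputs shows whether the hand-written loops buy anything over letting XLA fuse the broadcasted kernel.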
-
This is me testing GH discussions :)

So, something I briefly talked about with @lukasheinrich is the idea of a jax-only pyhf implementation. I think this could have some advantages:

- `pyhf.optimize` could use `jax.vmap` and `jax.jit` without the multiple-backend problem
- `pyhf.Model.logpdf` could become something 'pure' by jax standards, with the signature `logpdf(model, pars, data)` (see the sketch below; but I'm pretty sure one could do that with a class method too)

Totally just an idea, but I could put some proper time into this if people would be willing to answer all my stupid questions about the current implementation :p
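A sketch of what that 'pure' signature could look like (everything here is hypothetical, a toy stand-in for the real model):

```python
from typing import NamedTuple
import jax
import jax.numpy as jnp

# Hypothetical minimal model: a NamedTuple is already a pytree, so a pure
# logpdf(model, pars, data) composes directly with jax.jit/grad/vmap.
class Model(NamedTuple):
    expected_rates: jnp.ndarray  # nominal expected counts per bin

def logpdf(model, pars, data):
    lam = pars[0] * model.expected_rates       # one signal-strength parameter
    return jnp.sum(data * jnp.log(lam) - lam)  # Poisson log-likelihood, up to a constant

model = Model(expected_rates=jnp.array([5.0, 10.0, 3.0]))
data = jnp.array([4.0, 12.0, 2.0])

# Gradients w.r.t. the parameters, with the model passed through as data:
grad_fn = jax.jit(jax.grad(logpdf, argnums=1))
print(grad_fn(model, jnp.array([1.0]), data))
```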