forked from jstac/sandpit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lininterp.py
49 lines (31 loc) · 912 Bytes
/
lininterp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
from numba import jit
@jit(nopython=True)
def interp1d(grid, vals, x):
    """
    Linearly interpolate (grid, vals) to evaluate at x.

    The interval containing x is found by binary search, so the grid only
    needs to be sorted in ascending order (without repeated nodes); it does
    not have to be evenly spaced.  Points outside [grid[0], grid[-1]] are
    linearly extrapolated from the first or last interval.

    Parameters
    ----------
    grid : 1-D numpy array
        Interpolation nodes, sorted in ascending order with no duplicates.
    vals : 1-D numpy array
        Function values at the nodes; must have the same length as grid.
    x : float
        Point at which to evaluate the interpolant.

    Returns
    -------
    float
        The interpolated (or extrapolated) value at x.

    Raises
    ------
    ValueError
        If grid has fewer than two points, or len(vals) != len(grid).
    """
    G = len(grid)
    if G < 2:
        raise ValueError("grid must contain at least two points")
    if len(vals) != G:
        # In nopython mode indexing is unchecked, so a short vals array
        # would otherwise be read out of bounds silently.
        raise ValueError("grid and vals must have the same length")

    # Left index of the interval containing x: (number of nodes <= x) - 1,
    # clamped to [0, G - 2] so that q_0 + 1 is always a valid index.  The
    # clamp also makes out-of-range x reuse the end intervals, which yields
    # linear extrapolation.
    q_0 = np.searchsorted(grid, x, side='right') - 1
    q_0 = max(min(q_0, G - 2), 0)

    x_0 = grid[q_0]
    x_1 = grid[q_0 + 1]
    v_0 = vals[q_0]
    v_1 = vals[q_0 + 1]

    # Weight on the right-hand node: in [0, 1] inside the grid, outside
    # that range when extrapolating.
    λ = (x - x_0) / (x_1 - x_0)
    return (1 - λ) * v_0 + λ * v_1
@jit(nopython=True)
def interp1d_vectorized(grid, vals, x_vec):
    """
    Linearly interpolate (grid, vals) at every point of x_vec.

    Parameters
    ----------
    grid, vals : 1-D numpy arrays
        Interpolation nodes and the function values at those nodes;
        passed unchanged to interp1d for each evaluation point.
    x_vec : 1-D numpy array
        Evaluation points.  Any real dtype is accepted, including integer.

    Returns
    -------
    numpy array of float64, length len(x_vec)
        out[i] equals interp1d(grid, vals, x_vec[i]).
    """
    # Allocate a float64 result explicitly.  np.empty_like(x_vec) would
    # inherit x_vec's dtype, so integer (or float32) evaluation points
    # would truncate or lose precision in the interpolated values.
    out = np.empty(len(x_vec), dtype=np.float64)
    for i, x in enumerate(x_vec):
        out[i] = interp1d(grid, vals, x)
    return out