-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathheat_eq.py
64 lines (46 loc) · 1.4 KB
/
heat_eq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import jax.numpy as jnp
import numpy as np
import jax
from utils import calculate_cfl
from functools import partial
def get_integrator(model, dt):
    """Return a single-step classical fourth-order Runge-Kutta (RK4) integrator.

    The scheme is fixed-step RK4 for the autonomous ODE ``dT/dt = model(T)``.
    It is *not* the adaptive embedded RK4(5) ("RK45") method; there is no
    error estimate and no step-size control.

    Args:
        model: Right-hand side ``f(T)`` of the ODE; it takes only the state.
        dt: Fixed time step used for every call of the returned function.

    Returns:
        A function mapping the state ``T`` at time ``t`` to the RK4
        approximation of the state at ``t + dt``.
    """
    def rk4_step(T):
        # Four stage evaluations of the classic RK4 tableau.
        k1 = model(T)
        k2 = model(T + dt * k1 / 2)
        k3 = model(T + dt * k2 / 2)
        k4 = model(T + dt * k3)
        # Weighted average of the stages: (k1 + 2 k2 + 2 k3 + k4) / 6.
        return T + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

    return rk4_step
def get_heatmodel(diffusivity, sim_parameters):
    """Build the right-hand side of the heat equation with variable diffusivity.

    The returned function evaluates ``d/dx (diffusivity * dT/dx)``. It uses a
    fourth-order accurate central difference for each first derivative.

    The grid is periodic: ``jnp.roll`` wraps values from one end of the array
    to the other. ``jnp.roll`` is called without ``axis``, so it rolls the
    *flattened* array. The model therefore assumes ``T`` is 1-D; a
    multi-dimensional input would be treated as one long ring.

    Args:
        diffusivity: Diffusion coefficient, multiplied elementwise with the
            temperature gradient. Either a scalar or an array broadcastable
            to the temperature field.
        sim_parameters: Mapping that must contain ``"dx"``, the grid spacing.

    Returns:
        A function ``T -> dT/dt`` suitable for passing to ``get_integrator``.
    """
    dx = sim_parameters["dx"]

    def gradient(T):
        # Fourth-order accurate five-point central difference:
        #   (-T[i+2] + 8 T[i+1] - 8 T[i-1] + T[i-2]) / (12 dx)
        # jnp.roll(T, -k)[i] == T[i+k] (periodic wrap-around).
        f = -jnp.roll(T, -2)
        f += 8 * jnp.roll(T, -1)
        f += -8 * jnp.roll(T, 1)
        f += jnp.roll(T, 2)
        return f / (12 * dx)

    def laplacian(T):
        # Fourth-order accurate five-point second derivative:
        #   (-T[i+2] + 16 T[i+1] - 30 T[i] + 16 T[i-1] - T[i-2]) / (12 dx^2)
        # Currently unused by the returned model.
        f = -jnp.roll(T, -2)
        f += 16 * jnp.roll(T, -1)
        f += -30 * T
        f += 16 * jnp.roll(T, 1)
        f += -jnp.roll(T, 2)
        # Scale the stencil. The previous `f += f/(12*dx**2)` added the scaled
        # stencil to the unscaled one, returning f * (1 + 1/(12 dx^2)).
        return f / (12 * dx**2)

    def heat_varying_diffusivity(x):
        # Conservative form d/dx(kappa * dT/dx). The first-derivative stencil
        # is applied twice rather than expanding the product rule.
        return gradient(diffusivity * gradient(x))

    return heat_varying_diffusivity
def get_rollout_fn(diffusivity, sim_parameters):
    """Build a function that integrates the heat equation for ``Nt`` steps.

    Args:
        diffusivity: Diffusion coefficient forwarded to ``get_heatmodel``.
        sim_parameters: Mapping providing ``"Nt"`` (number of time steps) and
            ``"dt"`` (time step). The same mapping is also passed to
            ``get_heatmodel``.

    Returns:
        ``rollout_fn(T_0)``. It returns the initial state followed by every
        integrated state, stacked with ``jnp.vstack``. For a 1-D ``T_0`` of
        length ``N`` the result has shape ``(Nt + 1, N)``.
    """
    num_steps = sim_parameters["Nt"]
    time_step = sim_parameters["dt"]
    step = get_integrator(get_heatmodel(diffusivity, sim_parameters), time_step)

    def advance(state, _unused):
        # Carry the new state forward and also record it as this step's output.
        new_state = step(state)
        return new_state, new_state

    def rollout_fn(T_0):
        _, history = jax.lax.scan(advance, T_0, xs=None, length=num_steps)
        # Prepend the initial condition so row 0 is T_0.
        return jnp.vstack([T_0, history])

    return rollout_fn