Skip to content

Commit

Permalink
Move eval_func_on_grid to utils.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed Nov 15, 2023
1 parent b8f27c1 commit 4fee8d5
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 53 deletions.
41 changes: 0 additions & 41 deletions examples/ex_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,47 +27,6 @@ def plot_corner(samples, labels=None):
fig = corner.corner(samples, labels=labels_corner)


def eval_func_on_grid(func, xmin, xmax, ymin, ymax, nx, ny):
"""
Evalute 2D function on a grid.
Args:
- func:
Function to evalate.
- xmin:
Minimum x value to consider in grid domain.
- xmax:
Maximum x value to consider in grid domain.
- ymin:
Minimum y value to consider in grid domain.
- ymax:
Maximum y value to consider in grid domain.
- nx:
Number of samples to include in grid in x direction.
- ny:
Number of samples to include in grid in y direction.
Returns:
- func_eval_grid:
Function values evaluated on the 2D grid.
- x_grid:
x values over the 2D grid.
- y_grid:
y values over the 2D grid.
"""

# Evaluate func over grid.
x = np.linspace(xmin, xmax, nx)
y = np.linspace(ymin, ymax, ny)
x_grid, y_grid = np.meshgrid(x, y)
func_eval_grid = np.zeros((nx, ny))
for i in range(nx):
for j in range(ny):
func_eval_grid[i, j] = func(np.array([x_grid[i, j], y_grid[i, j]]))

return func_eval_grid, x_grid, y_grid


def plot_surface(
func_eval_grid,
x_grid,
Expand Down
4 changes: 2 additions & 2 deletions examples/normal_gamma_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def run_example(
x_n=x_n,
prior_params=prior_params,
)
ln_posterior_grid, x_grid, y_grid = ex_utils.eval_func_on_grid(
ln_posterior_grid, x_grid, y_grid = hm.utils.eval_func_on_grid(
ln_posterior_func,
xmin=-0.6,
xmax=0.6,
Expand Down Expand Up @@ -471,7 +471,7 @@ def run_example(
)

# Evaluate model on grid.
model_grid, x_grid, y_grid = ex_utils.eval_func_on_grid(
model_grid, x_grid, y_grid = hm.utils.eval_func_on_grid(
model.predict, xmin=-0.6, xmax=0.6, ymin=0.4, ymax=1.8, nx=500, ny=500
)

Expand Down
6 changes: 3 additions & 3 deletions examples/radiata_pine_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def model_predict_x0x1(x_2d):
# print("x01x1: x = {}".format(x))
return model.predict(x)

model_grid, x_grid, y_grid = ex_utils.eval_func_on_grid(
model_grid, x_grid, y_grid = hm.utils.eval_func_on_grid(
model_predict_x0x1,
xmin=2900.0,
xmax=3100.0,
Expand Down Expand Up @@ -575,7 +575,7 @@ def model_predict_x1x2(x_2d):
# print("x1x2: x = {}".format(x))
return model.predict(x)

model_grid, x_grid, y_grid = ex_utils.eval_func_on_grid(
model_grid, x_grid, y_grid = hm.utils.eval_func_on_grid(
model_predict_x1x2,
xmin=185.0 - 30.0,
xmax=185.0 + 30.0,
Expand Down Expand Up @@ -617,7 +617,7 @@ def model_predict_x0x2(x_2d):
x = np.append(x, x_2d[1])
return model.predict(x)

model_grid, x_grid, y_grid = ex_utils.eval_func_on_grid(
model_grid, x_grid, y_grid = hm.utils.eval_func_on_grid(
model_predict_x0x2,
xmin=2900.0,
xmax=3100.0,
Expand Down
4 changes: 2 additions & 2 deletions examples/rastrigin.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def run_example(
if ndim == 2:
hm.logs.debug_log("Compute evidence by numerical integration...")
ln_posterior_func = partial(ln_posterior, ln_prior=ln_prior)
ln_posterior_grid, x_grid, y_grid = ex_utils.eval_func_on_grid(
ln_posterior_grid, x_grid, y_grid = hm.utils.eval_func_on_grid(
ln_posterior_func,
xmin=-6.0,
xmax=6.0,
Expand Down Expand Up @@ -377,7 +377,7 @@ def run_example(
)

# Evaluate model on grid.
model_grid, x_grid, y_grid = ex_utils.eval_func_on_grid(
model_grid, x_grid, y_grid = hm.utils.eval_func_on_grid(
model.predict,
xmin=-6.0,
xmax=6.0,
Expand Down
5 changes: 2 additions & 3 deletions examples/rosenbrock.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import numpy as np
import sys
import emcee
import time
import matplotlib.pyplot as plt
from functools import partial
import harmonic as hm
import ex_utils


def ln_prior_uniform(x, xmin=-10.0, xmax=10.0, ymin=-5.0, ymax=15.0):
Expand Down Expand Up @@ -144,6 +142,7 @@ def run_example(
a = 1.0
b = 100.0

# Beginning of path where plots will be saved
save_name_start = "examples/plots/" + flow_type

if flow_type == "RealNVP":
Expand Down Expand Up @@ -250,7 +249,7 @@ def run_example(
if ndim == 2:
hm.logs.debug_log("Compute evidence by numerical integration...")
ln_posterior_func = partial(ln_posterior, ln_prior=ln_prior, a=a, b=b)
ln_posterior_grid, x_grid, y_grid = ex_utils.eval_func_on_grid(
ln_posterior_grid, x_grid, y_grid = hm.utils.eval_func_on_grid(
ln_posterior_func,
xmin=-10.0,
xmax=10.0,
Expand Down
4 changes: 2 additions & 2 deletions examples/rosenbrock_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def run_example(
if ndim == 2:
hm.logs.debug_log("Compute evidence by numerical integration...")
ln_posterior_func = partial(ln_posterior, ln_prior=ln_prior, a=a, b=b)
ln_posterior_grid, x_grid, y_grid = ex_utils.eval_func_on_grid(
ln_posterior_grid, x_grid, y_grid = hm.utils.eval_func_on_grid(
ln_posterior_func,
xmin=-10.0,
xmax=10.0,
Expand Down Expand Up @@ -417,7 +417,7 @@ def run_example(
)

# Evaluate model on grid.
model_grid, x_grid, y_grid = ex_utils.eval_func_on_grid(
model_grid, x_grid, y_grid = hm.utils.eval_func_on_grid(
model.predict,
xmin=-10.0,
xmax=10.0,
Expand Down
154 changes: 154 additions & 0 deletions harmonic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,160 @@
from getdist import plots


def eval_func_on_grid(func, xmin, xmax, ymin, ymax, nx, ny):
"""
Evalute 2D function on a grid.
Args:
- func:
Function to evalate.
- xmin:
Minimum x value to consider in grid domain.
- xmax:
Maximum x value to consider in grid domain.
- ymin:
Minimum y value to consider in grid domain.
- ymax:
Maximum y value to consider in grid domain.
- nx:
Number of samples to include in grid in x direction.
- ny:
Number of samples to include in grid in y direction.
Returns:
- func_eval_grid:
Function values evaluated on the 2D grid.
- x_grid:
x values over the 2D grid.
- y_grid:
y values over the 2D grid.
"""

# Evaluate func over grid.
x = np.linspace(xmin, xmax, nx)
y = np.linspace(ymin, ymax, ny)
x_grid, y_grid = np.meshgrid(x, y)
func_eval_grid = np.zeros((nx, ny))
for i in range(nx):
for j in range(ny):
func_eval_grid[i, j] = func(np.array([x_grid[i, j], y_grid[i, j]]))

return func_eval_grid, x_grid, y_grid


def plot_surface(
func_eval_grid,
x_grid,
y_grid,
samples=None,
vals=None,
contour_z_offset=None,
contours=None,
alpha=0.3,
):
"""
Plot surface defined by 2D function on a grid.
Samples may also be optionally plotted.
Args:
- func_eval_grid:
Function evalated over 2D grid.
- x_grid:
x values over the 2D grid.
- y_grid:
y values over the 2D grid.
- samples:
2D array of shape (ndim, nsamples) containing samples.
- vals:
1D array of function values at sample locations. Both samples and
vals must be provided if they are to be plotted.
- contour_z_offset:
If not None then plot contour in plane specified by z offset.
- contours:
Values at which to draw contours (must be in increasing order).
- alpha:
Opacity of surface plot.
Returns:
- ax:
Plot axis.
"""

# Set up axis for surface plot.
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))

# Create an instance of a LightSource and use it to illuminate
# the surface.
light = LightSource(60, 120)
rgb = np.ones((func_eval_grid.shape[0], func_eval_grid.shape[1], 3))
illuminated_surface = light.shade_rgb(rgb * np.array([0, 0.0, 1.0]), func_eval_grid)

# Plot surface.
ax.plot_surface(
x_grid,
y_grid,
func_eval_grid,
alpha=alpha,
linewidth=0,
antialiased=False,
# cmap=cm.coolwarm,
facecolors=illuminated_surface,
)

# Plot contour.
if contour_z_offset is not None:
if contours is not None:
cset = ax.contour(
x_grid,
y_grid,
func_eval_grid,
contours,
zdir="z",
offset=contour_z_offset,
cmap=cm.coolwarm,
)
else:
cset = ax.contour(
x_grid,
y_grid,
func_eval_grid,
zdir="z",
offset=contour_z_offset,
cmap=cm.coolwarm,
)

# Set domain.
xmin = np.min(x_grid)
xmax = np.max(x_grid)
ymin = np.min(y_grid)
ymax = np.max(y_grid)

# # Plot samples.
if samples is not None and vals is not None:
xplot = samples[:, 0]
yplot = samples[:, 1]
# Manually remove samples outside of plot region
# (since Matplotlib clipping cannot do this in 3D; see
# https://github.com/matplotlib/matplotlib/issues/749).
xplot[xplot < xmin] = np.nan
xplot[xplot > xmax] = np.nan
yplot[yplot < ymin] = np.nan
yplot[yplot > ymax] = np.nan
zplot = vals
ax.scatter(xplot, yplot, zplot, c="r", s=5, marker=".")

# Define additional plot settings.
ax.set_xlim(xmin, xmax)
ax.set_ylim(ymin, ymax)
ax.view_init(elev=15.0, azim=110.0)
ax.set_xlabel("$x_0$")
ax.set_ylabel("$x_1$")
ax.set_zlim(zmin=contour_z_offset)

return ax


def plot_getdist(samples, labels=None):
"""
Plot triangle plot of marginalised distributions using getdist package.
Expand Down

0 comments on commit 4fee8d5

Please sign in to comment.