Skip to content

Commit

Permalink
added interpolation capability, uses less space when saving results
Browse files Browse the repository at this point in the history
  • Loading branch information
Johannes Ulf Lange committed Feb 18, 2020
1 parent 3f904ec commit 8aae6b0
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 19 deletions.
3 changes: 1 addition & 2 deletions scripts/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
rp_bins = np.logspace(-1, 1, 20)

halocat = CachedHaloCatalog(simname='bolplanck')
halotab = TabCorr.tabulate(halocat, wp, rp_bins, pi_max=40,
period=halocat.Lbox)
halotab = TabCorr.tabulate(halocat, wp, rp_bins, pi_max=40)

# We can save the result for later use.
halotab.write('bolplanck.hdf5')
Expand Down
4 changes: 2 additions & 2 deletions tabcorr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .tabcorr import TabCorr
from .tabcorr import *

__all__ = ["TabCorr"]
__all__ = ["TabCorr", "interpolate_predict"]
219 changes: 204 additions & 15 deletions tabcorr/tabcorr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import h5py
import numpy as np
import itertools
from sys import getsizeof
from scipy.spatial import Delaunay
from astropy.table import Table, vstack
from halotools.empirical_models import HodModelFactory, model_defaults
from halotools.empirical_models import TrivialPhaseSpace, Zheng07Cens
Expand Down Expand Up @@ -38,7 +38,7 @@ def tabulate(cls, halocat, tpcf, *tpcf_args,
verbose=False, redshift_space_distortions=True,
cens_prof_model=None, sats_prof_model=None, project_xyz=False,
cosmology_ref=None, **tpcf_kwargs):
r"""
"""
Tabulates correlation functions for halos such that galaxy correlation
functions can be calculated rapidly.
Expand Down Expand Up @@ -135,7 +135,7 @@ def tabulate(cls, halocat, tpcf, *tpcf_args,
If True, the coordinates will be projected along all three spatial
axes. By default, only the projection onto the z-axis is used.
*tpcf_kwargs : dict, optional
**tpcf_kwargs : dict, optional
Keyword arguments passed to the ``tpcf`` function.
Returns
Expand Down Expand Up @@ -343,7 +343,7 @@ def tabulate(cls, halocat, tpcf, *tpcf_args,

@classmethod
def read(cls, fname):
r"""
"""
Reads tabulated correlation functions from the disk.
Parameters
Expand All @@ -365,16 +365,25 @@ def read(cls, fname):
for key in fstream.attrs.keys():
halotab.attrs[key] = fstream.attrs[key]

halotab.tpcf_matrix = fstream['tpcf_matrix'].value
tpcf_matrix = fstream['tpcf_matrix'][()]
if halotab.attrs['mode'] == 'auto' and len(tpcf_matrix.shape) == 2:
halotab.tpcf_matrix = []
for i in range(tpcf_matrix.shape[0]):
halotab.tpcf_matrix.append(array_to_symmetric_matrix(
tpcf_matrix[i]))
halotab.tpcf_matrix = np.array(halotab.tpcf_matrix)
else:
halotab.tpcf_matrix = tpcf_matrix

halotab.tpcf_args = []
for key in fstream['tpcf_args'].keys():
halotab.tpcf_args.append(fstream['tpcf_args'][key].value)
halotab.tpcf_args.append(fstream['tpcf_args'][key][()])
halotab.tpcf_args = tuple(halotab.tpcf_args)
halotab.tpcf_kwargs = {}
if 'tpcf_kwargs' in fstream:
for key in fstream['tpcf_kwargs'].keys():
halotab.tpcf_kwargs[key] = fstream['tpcf_kwargs'][key].value
halotab.tpcf_shape = tuple(fstream['tpcf_shape'].value)
halotab.tpcf_kwargs[key] = fstream['tpcf_kwargs'][key][()]
halotab.tpcf_shape = tuple(fstream['tpcf_shape'][()])
fstream.close()

halotab.gal_type = Table.read(fname, path='gal_type')
Expand All @@ -383,8 +392,9 @@ def read(cls, fname):

return halotab

def write(self, fname, overwrite=False, max_args_size=1000000):
r"""
def write(self, fname, overwrite=False, max_args_size=1000000,
matrix_dtype=np.float32):
"""
Writes tabulated correlation functions to the disk.
Parameters
Expand All @@ -408,7 +418,16 @@ def write(self, fname, overwrite=False, max_args_size=1000000):
for key in keys:
fstream.attrs[key] = self.attrs[key]

fstream['tpcf_matrix'] = self.tpcf_matrix
if self.attrs['mode'] == 'auto':
tpcf_matrix_write = []
for i in range(self.tpcf_matrix.shape[0]):
tpcf_matrix_write.append(symmetric_matrix_to_array(
self.tpcf_matrix[i]))
fstream['tpcf_matrix'] = np.array(tpcf_matrix_write,
dtype=matrix_dtype)
else:
fstream['tpcf_matrix'] = self.tpcf_matrix.astype(matrix_dtype)

for i, arg in enumerate(self.tpcf_args):
if (type(arg) is not np.ndarray or
np.prod(arg.shape) < max_args_size):
Expand All @@ -422,8 +441,8 @@ def write(self, fname, overwrite=False, max_args_size=1000000):

self.gal_type.write(fname, path='gal_type', append=True)

def predict(self, model, separate_gal_type=False):
r"""
def predict(self, model, separate_gal_type=False, **occ_kwargs):
"""
Predicts the number density and correlation function for a certain
model.
Expand All @@ -437,6 +456,10 @@ def predict(self, model, separate_gal_type=False):
If True, the return values are dictionaries divided by each galaxy
types contribution to the output result.
**occ_kwargs : dict, optional
Keyword arguments passed to the ``mean_occupation`` functions
of the model.
Returns
-------
ngal : numpy.array or dict
Expand Down Expand Up @@ -491,11 +514,11 @@ def predict(self, model, separate_gal_type=False):
mean_occupation[mask] = model.mean_occupation_centrals(
prim_haloprop=self.gal_type['prim_haloprop'][mask],
sec_haloprop_percentile=(
self.gal_type['sec_haloprop_percentile'][mask]))
self.gal_type['sec_haloprop_percentile'][mask]), **occ_kwargs)
mean_occupation[~mask] = model.mean_occupation_satellites(
prim_haloprop=self.gal_type['prim_haloprop'][~mask],
sec_haloprop_percentile=(
self.gal_type['sec_haloprop_percentile'][~mask]))
self.gal_type['sec_haloprop_percentile'][~mask]), **occ_kwargs)

ngal = mean_occupation * self.gal_type['n_h'].data

Expand Down Expand Up @@ -539,3 +562,169 @@ def predict(self, model, separate_gal_type=False):
xi * mask, axis=1).reshape(self.tpcf_shape)

return ngal_dict, xi_dict


def interpolate_predict(tabcorr_arr, x, xi, model, extrapolate=True,
**occ_kwargs):
"""
Linearly interprets the predictions from multiple TabCorr instances. Each
TabCorr instance will have a corresponding x-value and this function
finds a linear/barycentric interpolation for an intermediate point xi. For
example, this function can be used for predict correlation functions
for continues choices of concentrations for satellites.
Parameters
----------
tabcorr_arr : array_like
TabCorr instances used to interpolate.
x : array_like
Array of shape (npts) or (npts, ndim). Here, npts is the number of
TabCorr instances and ndim the number of dimensions over which
we will interpolate.
xi : float or array_like
x-value at which to interpolate. If x has shape (npts), xi must have
be a float. On the other hand, if x has shape (npts, ndim), xi must
be an array of len ndim.
model : HodModelFactory
Instance of ``halotools.empirical_models.HodModelFactory``
describing the model for which predictions are made.
extrapolate : boolean, optional
Whether to allow extrapolation beyond points sampled by x. If set to
False, attempting to extrapolate will result in a RuntimeError.
**occ_kwargs : dict, optional
Keyword arguments passed to the ``mean_occupation`` functions of
the model.
Returns
-------
ngal : numpy.array or dict
Array containing the number densities for each galaxy type stored in
self.gal_type. The total galaxy number density is the sum of all
elements of this array.
xi : numpy.array or dict
Array storing the prediction for the correlation function.
"""

try:
assert isinstance(x, list) or isinstance(x, np.ndarray)
x = np.array(x)
assert (x.ndim == 1) or (x.ndim == 2)
except AssertionError:
raise RuntimeError('x must be a one or two-dimensional array.')

if x.ndim == 1:
n_points = x.shape[0]
n_dim = 1
elif x.ndim == 2:
n_points = x.shape[0]
n_dim = x.shape[1]
if n_points <= n_dim:
raise RuntimeError('x must contain more points than dimensions.')

try:
if n_dim == 1:
assert isinstance(xi, float)
else:
assert isinstance(xi, list) or isinstance(xi, np.ndarray)
xi = np.array(xi)
assert xi.ndim == 1
assert len(xi) == n_dim
except AssertionError:
raise RuntimeError('xi must match the dimensionality of x.')

if n_points != len(tabcorr_arr):
raise RuntimeError('The length of tabcorr_arr does not match the ' +
'number of points provided via x.')

if n_dim > 1:

tri = Delaunay(x)
i = tri.find_simplex(xi)
if i != -1:
s = tri.simplices[i]
if i == -1:
if not extrapolate:
raise RuntimeError('x is outside of the interpolation range.')
else:
x_cm = np.mean(x[tri.simplices], axis=1)
i = np.argmin(np.sum((xi - x_cm)**2, axis=1)) # closest simplex

s = tri.simplices[i]
b = tri.transform[i, :-1].dot(xi - tri.transform[i, -1])
w = np.append(b, 1 - np.sum(b))

else:

if np.any(x < xi) and np.any(x > xi):
s = [np.ma.MaskedArray.argmax(np.ma.masked_array(x, mask=(x > xi))),
np.ma.MaskedArray.argmin(np.ma.masked_array(x, mask=(x < xi)))]
else:
if not extrapolate:
raise RuntimeError('x is outside of the interpolation range.')
else:
s = np.argsort(np.abs(xi - x))[:2]

w = np.array([x[s[1]] - xi, xi - x[s[0]]]) / (x[s[1]] - x[s[0]])

for i in range(len(s)):
ngal_i, xi_i = tabcorr_arr[s[i]].predict(model, **occ_kwargs)
if i == 0:
ngal = ngal_i * w[i]
xi = xi_i * w[i]
else:
ngal += ngal_i * w[i]
xi += xi_i * w[i]

return ngal, xi


def symmetric_matrix_to_array(matrix):

try:
assert matrix.shape[0] == matrix.shape[1]
assert np.all(matrix == matrix.T)
except AssertionError:
raise RuntimeError('The matrix you provided is not symmetric.')

array = np.zeros((np.prod(matrix.shape) + matrix.shape[0]) // 2,
dtype=matrix.dtype)

k = 0
for i in range(matrix.shape[0]):
for j in range(matrix.shape[1]):
if j > i:
continue
else:
array[k] = matrix[i, j]
k = k + 1

return array


def array_to_symmetric_matrix(array):

dim = int(np.rint(-0.5 + np.sqrt(0.25 + len(array) * 2)))

try:
assert len(array) == (dim**2 + dim) / 2
except AssertionError:
raise RuntimeError('The length of the array does not correspond to a' +
' symmetric matrix.')

matrix = np.zeros((dim, dim), dtype=array.dtype)

k = 0
for i in range(matrix.shape[0]):
matrix[i, :i+1] = array[k:k+i+1]
k = k + i + 1

matrix = matrix + matrix.T - matrix * np.identity(matrix.shape[0])

return matrix

0 comments on commit 8aae6b0

Please sign in to comment.