Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More large-catalog speedups #397

Merged
merged 7 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions flarestack/core/injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import random
import zipfile
import zlib
from astropy.table import Table
from flarestack.shared import band_mask_cache_name
from flarestack.core.energy_pdf import EnergyPDF, read_e_pdf_dict
from flarestack.core.time_pdf import TimePDF, read_t_pdf_dict
Expand Down Expand Up @@ -128,7 +129,7 @@ def get_n_exp_single(self, source):
else:
name = source["source_name"]

return np.copy(self.n_exp[self.n_exp["source_name"] == name])
return self.n_exp[self.n_exp["source_name"] == name]

def get_expectation(self, source, scale):
return float(self.get_n_exp_single(source)["n_exp"]) * scale
Expand Down Expand Up @@ -228,7 +229,7 @@ class MCInjector(BaseInjector):

def __init__(self, season, sources, **kwargs):
kwargs = read_injector_dict(kwargs)
self._mc = season.get_mc()
self._mc = self.get_mc(season)
BaseInjector.__init__(self, season, sources, **kwargs)

self.injection_declination_bandwidth = self.inj_kwargs.pop(
Expand All @@ -243,6 +244,9 @@ def __init__(self, season, sources, **kwargs):
logger.warning("No Injection Arguments. Are you unblinding?")
pass

def get_mc(self, season):
return season.get_mc()

def select_mc_band(self, source):
"""For a given source, selects MC events within a declination band of
width +/- 5 degrees that contains the source. Then returns the MC data
Expand All @@ -260,7 +264,7 @@ def select_mc_band(self, source):

band_mask = self.get_band_mask(source, min_dec, max_dec)

return np.copy(self._mc[band_mask]), omega, band_mask
return self._mc[band_mask], omega, band_mask

def get_band_mask(self, source, min_dec, max_dec):
# Checks if the mask has already been evaluated for the source
Expand Down Expand Up @@ -522,6 +526,34 @@ def get_band_mask(self, source, min_dec, max_dec):
return self.band_mask_cache.getrow(entry["source_index"][0]).toarray()[0]


@MCInjector.register_subclass("table_injector")
class TableInjector(MCInjector):
"""
For even larger numbers of sources O(~1000), accessing every element of the
MC array in select_band_mask() once for every source in calculate_n_exp()
becomes a bottleneck. Store event fields in columns, ordered by declination,
making single-field accesses vastly cheaper and band masks practically free.
For 1000 sources, calculate_n_exp() is ~60x faster than MCInjector.
"""

def get_mc(self, season):
mc: np.ndarray = season.get_mc()
# Sort rows by trueDec, and store as columns in a Table
table = Table(mc[np.argsort(mc["trueDec"].copy())])
# Prevent in-place modifications
for k in table.columns:
table[k].setflags(write=False)
return table

def get_band_mask(self, source, min_dec, max_dec):
return slice(*np.searchsorted(self._mc["trueDec"], [min_dec, max_dec]))

def select_mc_band(self, source):
table, omega, band_mask = super().select_mc_band(source)
# allow individual columns to be replaced
return table.copy(copy_data=False), omega, band_mask


@MCInjector.register_subclass("effective_area_injector")
class EffectiveAreaInjector(BaseInjector):
"""Class for injecting signal events by relying on effective areas rather
Expand Down
126 changes: 82 additions & 44 deletions flarestack/core/llh.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import scipy.interpolate
from scipy import sparse
from astropy.table import Table
from typing import Optional
import pickle
from flarestack.shared import (
Expand Down Expand Up @@ -1370,29 +1371,60 @@ def __init__(self, season, sources, llh_dict):
+ "please change 'the spatial_pdf_name' accordingly"
)

def get_spatially_coincident_indices(self, data, source) -> np.ndarray:
"""
Get spatially coincident data for a single source, taking advantage of
the fact that data are sorted in dec
"""
width = np.deg2rad(self.spatial_box_width)

# Sets a declination band 5 degrees above and below the source
min_dec = max(-np.pi / 2.0, source["dec_rad"] - width)
max_dec = min(np.pi / 2.0, source["dec_rad"] + width)

# Accepts events lying within a 5 degree band of the source
dec_range = slice(*np.searchsorted(data["dec"], [min_dec, max_dec]))

# Sets the minimum value of cos(dec)
cos_factor = np.amin(np.cos([min_dec, max_dec]))

# Scales the width of the box in ra, to give a roughly constant
# area. However, if the width would have to be greater that +/- pi,
# then sets the area to be exactly 2 pi.
dPhi = np.amin([2.0 * np.pi, 2.0 * width / cos_factor])

# Accounts for wrapping effects at ra=0, calculates the distance
# of each event to the source.
ra_dist = np.fabs(
(data["ra"][dec_range] - source["ra_rad"] + np.pi) % (2.0 * np.pi) - np.pi
)
return np.nonzero(ra_dist < dPhi / 2.0)[0] + dec_range.start

def create_kwargs(self, data, pull_corrector, weight_f=None):
if weight_f is None:
raise Exception(
"Weight function not passed, but is required for "
"standard_overlapping LLH functions."
)

coincidence_matrix = sparse.lil_matrix(
(len(self.sources), len(data)), dtype=bool
)
# Keep data in an astropy Table (column-wise) to improve cache
# performance. Sort to allow get_spatially_coincident_indices to find
# declination bands by binary search.
data = Table(data[np.argsort(data[["dec", "ra"]])])

SoB_rows = [None] * len(self.sources)

kwargs = dict()

kwargs["n_all"] = float(len(data))

sources = self.sources

for i, source in enumerate(sources):
s_mask = self.select_spatially_coincident_data(data, [source])
# Treat sources in declination order to keep caches hot
order = np.argsort(self.sources[["dec_rad", "ra_rad"]])
for i in order:
source = self.sources[i]
idx = self.get_spatially_coincident_indices(data, source)

coincident_data = data[s_mask]

if len(coincident_data) > 0:
if len(idx) > 0:
# Only bother accepting neutrinos where the spacial
# likelihood is greater than 1e-21. This prevents 0s
# appearing in dynamic pull corrections, but also speeds
Expand All @@ -1403,6 +1435,7 @@ def create_kwargs(self, data, pull_corrector, weight_f=None):
# the spatial pdf is calculated (ie the KDE spline is evaluated) for gamma = 2
# If we want to be correct this should be in a loop for the get_gamma_support_points
# but that would add an extra dimension in the matrix so better not
coincident_data = data[idx]
if (
self.spatial_pdf.signal.SplineIs4D
and self.spatial_pdf.signal.KDE_eval_gamma is not None
Expand All @@ -1415,24 +1448,35 @@ def create_kwargs(self, data, pull_corrector, weight_f=None):
):
sig = self.signal_pdf(source, coincident_data, gamma=2.0)

nonzero_mask = sig > spatial_mask_threshold
s_mask[s_mask] *= nonzero_mask
nonzero_idx = np.nonzero(sig > spatial_mask_threshold)
column_indices = idx[nonzero_idx]

# build a single-row CSR matrix in canonical format
SoB_rows[i] = sparse.csr_matrix(
(
sig[nonzero_idx]
/ self.background_pdf(source, coincident_data[nonzero_idx]),
column_indices,
[0, len(column_indices)],
),
shape=(1, len(data)),
)
else:
SoB_rows[i] = sparse.csr_matrix((1, len(data)), dtype=float)

coincidence_matrix[i] = s_mask
SoB_only_matrix = sparse.vstack(SoB_rows, format="csr")

# Using Sparse matrixes
coincident_nu_mask = np.sum(coincidence_matrix, axis=0) > 0
coincident_nu_mask = np.array(coincident_nu_mask).ravel()
coincident_source_mask = np.sum(coincidence_matrix, axis=1) > 0
coincident_source_mask = np.array(coincident_source_mask).ravel()
coincident_nu_mask = np.asarray(np.sum(SoB_only_matrix, axis=0) != 0).ravel()
coincident_source_mask = np.asarray(
np.sum(SoB_only_matrix, axis=1) != 0
).ravel()

coincidence_matrix = (
coincidence_matrix[coincident_source_mask].T[coincident_nu_mask].T
SoB_only_matrix = (
SoB_only_matrix[coincident_source_mask].T[coincident_nu_mask].T
)
coincidence_matrix.tocsr()

coincident_data = data[coincident_nu_mask]
coincident_sources = sources[coincident_source_mask]
coincident_sources = self.sources[coincident_source_mask]

season_weight = lambda x: weight_f([1.0, x], self.season)[
coincident_source_mask
Expand All @@ -1450,18 +1494,6 @@ def create_kwargs(self, data, pull_corrector, weight_f=None):
"Creating gamma-independent SoB matrix for all srcs when 3D KDE or 4D w/ 'spatial_pdf_index'"
)

SoB_only_matrix = sparse.lil_matrix(coincidence_matrix.shape, dtype=float)

for i, src in enumerate(coincident_sources):
mask = (coincidence_matrix.getrow(i)).toarray()[0]
SoB_only_matrix[i, mask] = self.signal_pdf(
src, coincident_data[mask]
) / self.background_pdf( # gamma = None
src, coincident_data[mask]
)

SoB_only_matrix = SoB_only_matrix.tocsr()

def joint_SoB(dataset, gamma):
weight = np.array(season_weight(gamma))
weight /= np.sum(weight)
Expand All @@ -1477,20 +1509,26 @@ def joint_SoB(dataset, gamma):
weight = np.array(season_weight(gamma))
weight /= np.sum(weight)

# create an empty lil_matrix (good for matrix creation) with shape
# of coincidence_matrix and type float
SoB_matrix_sparse = sparse.lil_matrix(
coincidence_matrix.shape, dtype=float
)

# Build CSR matrix containing source_weight * S / B, with the
# same sparsity structure as SoB_only_matrix, taking advantage
# of the fact that the column indices are indices into `dataset`
data = np.empty_like(SoB_only_matrix.data)
for i, src in enumerate(coincident_sources):
mask = (coincidence_matrix.getrow(i)).toarray()[0]
SoB_matrix_sparse[i, mask] = (
row = slice(
SoB_only_matrix.indptr[i], SoB_only_matrix.indptr[i + 1]
)
masked_dataset = dataset[SoB_only_matrix.indices[row]]
data[row] = (
weight[i]
* self.signal_pdf(src, dataset[mask], gamma)
/ self.background_pdf(src, dataset[mask])
* self.signal_pdf(src, masked_dataset, gamma)
/ self.background_pdf(src, masked_dataset)
)

SoB_matrix_sparse = sparse.csr_matrix(
(data, SoB_only_matrix.indices, SoB_only_matrix.indptr),
shape=SoB_only_matrix,
)

SoB_sum = SoB_matrix_sparse.sum(axis=0)
return_value = np.array(SoB_sum).ravel()

Expand Down
2 changes: 1 addition & 1 deletion flarestack/core/minimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1267,7 +1267,7 @@ class LargeCatalogueMinimisationHandler(FixedWeightMinimisationHandler):

compatible_llh = ["standard_matrix", "std_matrix_kde_enabled"]
compatible_negative_n_s = False
compatible_injectors = ["low_memory_injector"]
compatible_injectors = ["low_memory_injector", "table_injector"]

def __init__(self, mh_dict):
FixedWeightMinimisationHandler.__init__(self, mh_dict)
Expand Down
8 changes: 1 addition & 7 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,7 @@
packages = flarestack
exclude = analyses

[mypy-astropy]
ignore_missing_imports = True

[mypy-astropy.coordinates]
ignore_missing_imports = True

[mypy-astropy.cosmology]
[mypy-astropy.*]
ignore_missing_imports = True

[mypy-healpy]
Expand Down
Loading