Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Missing channel handling in full llhd simplification #13

Merged
merged 13 commits into from
Sep 19, 2024
6 changes: 3 additions & 3 deletions .zenodo.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
{
"description": "pyhf plug-in for spey package",
"license": "MIT",
"title": "SpeysideHEP/spey-pyhf: v0.1.5",
"version": "v0.1.5",
"title": "SpeysideHEP/spey-pyhf: v0.1.6",
"version": "v0.1.6",
"upload_type": "software",
"creators": [
{
Expand All @@ -29,7 +29,7 @@
},
{
"scheme": "url",
"identifier": "https://github.com/SpeysideHEP/spey-pyhf/tree/v0.1.5",
"identifier": "https://github.com/SpeysideHEP/spey-pyhf/tree/v0.1.6",
"relation": "isSupplementTo"
},
{
Expand Down
6 changes: 6 additions & 0 deletions docs/releases/changelog-v0.1.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
* Improve undefined channel handling in the patchset
([#12](https://github.com/SpeysideHEP/spey-pyhf/pull/12))

* Improve undefined channel handling in the patchset for full likelihood simplification.
([#13](https://github.com/SpeysideHEP/spey-pyhf/pull/13))

* Add modifier check to signal injection.
([#13](https://github.com/SpeysideHEP/spey-pyhf/pull/13))

## Bug fixes

* Bugfix in `simplify` module, where signal injector was not initiated properly.
Expand Down
12 changes: 11 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
url="https://github.com/SpeysideHEP/spey-pyhf",
project_urls={
"Bug Tracker": "https://github.com/SpeysideHEP/spey-pyhf/issues",
"Documentation": "https://spey-pyhf.readthedocs.io",
"Repository": "https://github.com/SpeysideHEP/spey-pyhf",
"Homepage": "https://github.com/SpeysideHEP/spey-pyhf",
"Download": f"https://github.com/SpeysideHEP/spey-pyhf/archive/refs/tags/v{version}.tar.gz",
},
download_url=f"https://github.com/SpeysideHEP/spey-pyhf/archive/refs/tags/v{version}.tar.gz",
author="Jack Y. Araz",
Expand All @@ -50,8 +54,14 @@
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Topic :: Scientific/Engineering :: Physics",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
extras_require={
"dev": ["pytest>=7.1.2", "pytest-cov>=3.0.0", "twine>=3.7.1", "wheel>=0.37.1"],
Expand Down
2 changes: 1 addition & 1 deletion src/spey_pyhf/_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Version of the spey - pyhf plugin"""

__version__ = "0.1.5"
__version__ = "0.1.6"
42 changes: 28 additions & 14 deletions src/spey_pyhf/helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ def __dir__():
return __all__


# pylint: disable=W1203, W1201, C0103

log = logging.getLogger("Spey")


Expand Down Expand Up @@ -168,6 +170,7 @@ def inject_signal(
Args:
channel (``Text``): channel name
data (``List[float]``): signal yields
modifiers (``List[Dict]``): uncertainties. If None, default modifiers will be added.

Raises:
``ValueError``: If channel does not exist or number of yields does not match
Expand All @@ -184,9 +187,20 @@ def inject_signal(
f"{self.bin_map[channel]} expected, {len(data)} received."
)

default_modifiers = _default_modifiers(self.poi_name[0][1])
if modifiers is not None:
for mod in default_modifiers:
if mod not in modifiers:
log.warning(
f"Modifier `{mod['name']}` with type `{mod['type']}` is missing"
f" from the input. Adding `{mod['name']}`"
)
log.debug(f"Adding modifier: {mod}")
modifiers.append(mod)

self._signal_dict[channel] = data
self._signal_modifiers[channel] = (
_default_modifiers(self.poi_name[0][1]) if modifiers is None else modifiers
default_modifiers if modifiers is None else modifiers
)

@property
Expand Down Expand Up @@ -237,7 +251,7 @@ def reset_signal(self) -> None:

def add_patch(self, signal_patch: List[Dict]) -> None:
"""Inject signal patch"""
self._signal_dict, self._to_remove = self.patch_to_map(
self._signal_dict, self._signal_modifiers, self._to_remove = self.patch_to_map(
signal_patch=signal_patch, return_remove_list=True
)

Expand Down Expand Up @@ -272,7 +286,10 @@ def remove_list(self) -> List[Text]:

def patch_to_map(
self, signal_patch: List[Dict], return_remove_list: bool = False
) -> Union[Tuple[Dict[Text, Dict], List[Text]], Dict[Text, Dict]]:
) -> Union[
Tuple[Dict[Text, Dict], Dict[Text, Dict], List[Text]],
Tuple[Dict[Text, Dict], Dict[Text, Dict]],
]:
"""
Convert JSONPatch into signal map

Expand All @@ -288,23 +305,20 @@ def patch_to_map(
.. versionadded:: 0.1.5

Returns:
``Tuple[Dict[Text, Dict], List[Text]]`` or ``Dict[Text, Dict]``:
``Tuple[Dict[Text, Dict], Dict[Text, Dict], List[Text]]`` or ``Tuple[Dict[Text, Dict], Dict[Text, Dict]]``:
signal map including the data and modifiers and the list of channels to be removed.
"""
signal_map = {}
to_remove = []
signal_map, modifier_map, to_remove = {}, {}, []
for item in signal_patch:
path = int(item["path"].split("/")[2])
channel_name = self["channels"][path]["name"]
if item["op"] == "add":
signal_map[channel_name] = {
"data": item["value"]["data"],
"modifiers": item["value"].get(
"modifiers", _default_modifiers(poi_name=self.poi_name[0][1])
),
}
signal_map[channel_name] = item["value"]["data"]
modifier_map[channel_name] = item["value"].get(
"modifiers", _default_modifiers(poi_name=self.poi_name[0][1])
)
elif item["op"] == "remove":
to_remove.append(channel_name)
if return_remove_list:
return signal_map, to_remove
return signal_map
return signal_map, modifier_map, to_remove
return signal_map, modifier_map
102 changes: 81 additions & 21 deletions src/spey_pyhf/simplify.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Interface to convert pyhf likelihoods to simplified likelihood framework"""
import copy
import logging
import warnings
from typing import Callable, List, Optional, Text, Union, Literal
from contextlib import contextmanager
from typing import Callable, List, Literal, Optional, Text, Union

import numpy as np
import spey
Expand All @@ -18,6 +20,11 @@ def __dir__():
return []


# pylint: disable=W1203, R0903

log = logging.getLogger("Spey")


class ConversionError(Exception):
"""Conversion error class"""

Expand All @@ -41,6 +48,22 @@ def func(vector: np.ndarray) -> float:
return func


@contextmanager
def _disable_logging(highest_level: int = logging.CRITICAL):
"""
Temporary disable logging implementation, this should move into Spey

Args:
highest_level (``int``, default ``logging.CRITICAL``): highest level to be set in logging
"""
previous_level = logging.root.manager.disable
logging.disable(highest_level)
try:
yield
finally:
logging.disable(previous_level)


class Simplify(spey.ConverterBase):
r"""
An interface to convert pyhf full statistical model prescription into simplified likelihood
Expand Down Expand Up @@ -175,9 +198,10 @@ def __call__(
}[fittype]

interpreter = WorkspaceInterpreter(bkgonly_model)
bin_map = interpreter.bin_map

# configure signal patch map with respect to channel names
signal_patch_map = interpreter.patch_to_map(signal_patch)
signal_patch_map, signal_modifiers_map = interpreter.patch_to_map(signal_patch)

# Prepare a JSON patch to separate control and validation regions
# These regions are generally marked as CR and VR
Expand All @@ -190,25 +214,26 @@ def __call__(
)

for channel in interpreter.get_channels(control_region_indices):
interpreter.inject_signal(
channel,
[0.0] * len(signal_patch_map[channel]["data"]),
signal_patch_map[channel]["modifiers"]
if include_modifiers_in_control_model
else None,
)
if channel in signal_patch_map and channel in signal_modifiers_map:
interpreter.inject_signal(
channel,
[0.0] * bin_map[channel],
signal_modifiers_map[channel]
if include_modifiers_in_control_model
else None,
)

pdf_wrapper = spey.get_backend("pyhf")
control_model = pdf_wrapper(
background_only_model=bkgonly_model, signal_patch=interpreter.make_patch()
)
with _disable_logging():
control_model = pdf_wrapper(
background_only_model=bkgonly_model, signal_patch=interpreter.make_patch()
)

# Extract the nuisance parameters that maximises the likelihood at mu=0
fit_opts = control_model.prepare_for_fit(expected=expected)
_, fit_param = fit(
**fit_opts,
initial_parameters=None,
bounds=None,
fixed_poi_value=0.0,
)

Expand All @@ -234,13 +259,33 @@ def __call__(
)

# Retreive pyhf models and compare parameter maps
stat_model_pyhf = statistical_model.backend.model()[1]
if include_modifiers_in_control_model:
stat_model_pyhf = statistical_model.backend.model()[1]
else:
# Remove the nuisance parameters from the signal patch
# Note that even if the signal yields are zero, nuisance parameters
# do contribute to the statistical model and some models may be highly
# sensitive to the shape and size of the nuisance parameters.
with _disable_logging():
tmp_interpreter = copy.deepcopy(interpreter)
for channel, data in signal_patch_map.items():
tmp_interpreter.inject_signal(channel=channel, data=data)
tmp_model = spey.get_backend("pyhf")(
background_only_model=bkgonly_model,
signal_patch=tmp_interpreter.make_patch(),
)
stat_model_pyhf = tmp_model.backend.model()[1]
del tmp_model, tmp_interpreter
control_model_pyhf = control_model.backend.model()[1]
is_nuisance_map_different = (
stat_model_pyhf.config.par_map != control_model_pyhf.config.par_map
)
fit_opts = statistical_model.prepare_for_fit(expected=expected)
suggested_fixed = fit_opts["model_configuration"].suggested_fixed
log.debug(
"Number of parameters to be fitted during the scan: "
f"{fit_opts['model_configuration'].npar - len(fit_param)}"
)

samples = []
warnings_list = []
Expand Down Expand Up @@ -290,7 +335,9 @@ def __call__(
_, new_params = fit(
**current_fit_opts,
initial_parameters=init_params.tolist(),
bounds=None,
bounds=current_fit_opts[
"model_configuration"
].suggested_bounds,
)
warnings_list += w

Expand All @@ -304,13 +351,16 @@ def __call__(
# Some of the samples can lead to problems while sampling from a poisson distribution.
# e.g. poisson requires positive lambda values to sample from. If sample leads to a negative
# lambda value continue sampling to avoid that point.
log.debug("Problem with the sample generation")
log.debug(
f"Nuisance parameters: {current_nui_params if new_params is None else new_params}"
)
continue

if len(warnings_list) > 0:
warnings.warn(
message=f"{len(warnings_list)} warning(s) generated during sampling."
" This might be due to edge cases in nuisance parameter sampling.",
category=RuntimeWarning,
log.warning(
f"{len(warnings_list)} warning(s) generated during sampling."
" This might be due to edge cases in nuisance parameter sampling."
)

samples = np.vstack(samples)
Expand All @@ -323,9 +373,19 @@ def __call__(

# NOTE: model spec might be modified within the pyhf workspace, thus
# yields needs to be reordered properly before constructing the simplified likelihood
signal_yields = []
signal_yields, missing_channels = [], []
for channel_name in stat_model_pyhf.config.channels:
signal_yields += signal_patch_map[channel_name]["data"]
try:
signal_yields += signal_patch_map[channel_name]
except KeyError:
missing_channels.append(channel_name)
signal_yields += [0.0] * bin_map[channel_name]
if len(missing_channels) > 0:
log.warning(
"Following channels are not in the signal patch,"
f" will be set to zero: {', '.join(missing_channels)}"
)

# NOTE background yields are first moments in simplified framework not the yield values
# in the full statistical model!
background_yields = np.mean(samples, axis=0)
Expand Down
Loading