Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hinting #192

Open
wants to merge 29 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
cf58e66
Initial type-hinting attempt
ggalloni Aug 27, 2024
bf001b5
Add try-except statements for ccl
ggalloni Aug 28, 2024
338cc84
Add type checking functions in `utils`
ggalloni Aug 29, 2024
2461901
Check types from yaml files
ggalloni Aug 29, 2024
5a12e70
Fix test
ggalloni Aug 29, 2024
ed2c841
Add tests with wrong types
ggalloni Aug 29, 2024
cec70c9
fix
ggalloni Sep 3, 2024
df199fb
compatibilty
ggalloni Sep 3, 2024
57de021
ops
ggalloni Sep 3, 2024
cbd7764
dynamic import
ggalloni Sep 3, 2024
62cf91d
Remove `CCL` and `Tracer` from hints
ggalloni Sep 3, 2024
0696e00
Avoid repetition of types and refactor
ggalloni Sep 3, 2024
d6863c0
Move type checking functions within `Cobaya`
ggalloni Sep 5, 2024
ff79ce9
Switch-on type enforcing
ggalloni Sep 5, 2024
4700060
Fix for `_fast_chi_square`
ggalloni Sep 5, 2024
501d279
Clean code
ggalloni Sep 10, 2024
3564490
Merge branch 'master' into type_hinting
ggalloni Sep 10, 2024
2ddec66
Remove some over-redundant hints
ggalloni Sep 10, 2024
a36a5dd
test conda install for `pyccl`, see [issue](https://github.com/LSSTDE…
ggalloni Sep 24, 2024
89fca30
test OS dependency
ggalloni Sep 24, 2024
3d4ec8c
ops
ggalloni Sep 24, 2024
aa76f23
test
ggalloni Sep 24, 2024
0e7a449
revert
ggalloni Sep 24, 2024
97679b3
Merge remote-tracking branch 'upstream/master' into type_hinting
ggalloni Oct 2, 2024
028955e
Fix `cobaya` version in requirements
ggalloni Oct 2, 2024
133e059
oopsie
ggalloni Oct 2, 2024
caba78b
Test new refactor of Cobaya type checking
ggalloni Nov 6, 2024
5fdfe05
Correct attribute name
ggalloni Nov 6, 2024
ba59550
Point to `master` branch of `cobaya`
ggalloni Nov 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dill
fuzzywuzzy
astropy
getdist
cobaya
cobaya @ git+https://github.com/CobayaSampler/cobaya.git@master
pyccl
sacc
fgspectra>=1.1.0
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dependencies = [
"pandas", # to remove
"pytest-cov",
"astropy",
"cobaya",
"cobaya @ git+https://github.com/CobayaSampler/cobaya.git@master",
"sacc",
"fgspectra >= 1.1.0",
"pyccl >= 3.0; platform_system!='Windows'",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dill
fuzzywuzzy
astropy
getdist
cobaya
cobaya @ git+https://github.com/CobayaSampler/cobaya.git@master
pyccl >= 3.0; platform_system!='Windows'
sacc
fgspectra>=1.1.0
Expand Down
21 changes: 17 additions & 4 deletions soliket/bias/bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,34 @@
function (have a look at the linear bias model for ideas).
"""

from typing import Any, Dict, List, Optional, Set, Tuple, Union

import numpy as np
from cobaya.theory import Theory


class Bias(Theory):
"""Parent class for bias models."""

kmax: Union[int, float]
nonlinear: bool
z: Union[float, List[float], np.ndarray]
extra_args: Optional[dict]
params: dict

_enforce_types: bool = True

_logz = np.linspace(-3, np.log10(1100), 150)
_default_z_sampling = 10 ** _logz
_default_z_sampling[0] = 0

def initialize(self):
self._var_pairs = set()
self._var_pairs: Set[Tuple[str, str]] = set()

def get_requirements(self):
def get_requirements(self) -> Dict[str, Any]:
return {}

def must_provide(self, **requirements):
def must_provide(self, **requirements) -> Dict[str, Any]:
options = requirements.get("linear_bias") or {}

self.kmax = max(self.kmax, options.get("kmax", self.kmax))
Expand All @@ -69,7 +79,7 @@ def must_provide(self, **requirements):
assert len(self._var_pairs) < 2, "Bias doesn't support other Pk yet"
return needs

def _get_Pk_mm(self):
def _get_Pk_mm(self) -> np.ndarray:
self.k, self.z, Pk_mm = \
self.provider.get_Pk_grid(var_pair=list(self._var_pairs)[0],
nonlinear=self.nonlinear)
Expand All @@ -89,6 +99,9 @@ class Linear_bias(Bias):
Has one free parameter, :math:`b_\mathrm{lin}` (``b_lin``).
"""

_enforce_types: bool = True
params: dict

def calculate(self, state: dict, want_derived: bool = True,
**params_values_dict):
Pk_mm = self._get_Pk_mm()
Expand Down
13 changes: 7 additions & 6 deletions soliket/cash/cash.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Tuple
import numpy as np
from cobaya.likelihood import Likelihood
from .cash_data import CashCData
Expand All @@ -9,25 +9,26 @@

class CashCLikelihood(Likelihood):
name: str = "Cash-C"
datapath = Optional[str]
datapath: Optional[str] = None

def initialize(self):
_enforce_types: bool = True

def initialize(self):
x, N = self._get_data()
self.data = CashCData(self.name, N)

def _get_data(self):
def _get_data(self) -> Tuple[np.ndarray, np.ndarray]:
data = np.loadtxt(self.datapath, unpack=False)
N = data[:, -1] # assume data stored like column_stack([z, q, N])
x = data[:, :-1]
return x, N

def _get_theory(self, **kwargs):
def _get_theory(self, **kwargs) -> np.ndarray:
if "cash_test_logp" in kwargs:
return np.arange(kwargs["cash_test_logp"])
else:
raise NotImplementedError

def logp(self, **params_values):
def logp(self, **params_values) -> float:
theory = self._get_theory(**params_values)
return self.data.loglike(theory)
15 changes: 11 additions & 4 deletions soliket/cash/cash_data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from typing import Union
import numpy as np
from scipy.special import factorial


def cash_c_logpdf(theory, data, usestirling=True):
def cash_c_logpdf(
theory: Union[np.ndarray, float],
data: Union[np.ndarray, float],
usestirling: bool = True
) -> float:
data = np.asarray(data, dtype=int)

ln_fac = np.zeros_like(data, dtype=float)
Expand All @@ -24,13 +29,15 @@ class CashCData:
"""Named multi-dimensional Cash-C distributed data
"""

def __init__(self, name, N, usestirling=True):
def __init__(
self, name: str, N: Union[np.ndarray, float], usestirling: bool = True
):
self.name = str(name)
self.data = N
self.usestirling = usestirling

def __len__(self):
def __len__(self) -> int:
return len(self.data)

def loglike(self, theory):
def loglike(self, theory: Union[np.ndarray, float]) -> float:
return cash_c_logpdf(theory, self.data)
21 changes: 13 additions & 8 deletions soliket/ccl/ccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,21 +79,26 @@
# https://cobaya.readthedocs.io/en/devel/theories_and_dependencies.html

import numpy as np
from typing import Sequence
from cobaya.theory import Theory
from typing import Dict, List, Optional, Sequence, Union
from cobaya.theory import Provider, Theory
from cobaya.tools import LoggedError


class CCL(Theory):
"""A theory code wrapper for CCL."""
kmax: Union[int, float]
nonlinear: bool
z: Union[float, List[float], np.ndarray]
extra_args: Optional[dict]

_enforce_types: bool = True

_logz = np.linspace(-3, np.log10(1100), 150)
_default_z_sampling = 10 ** _logz
_default_z_sampling[0] = 0
kmax: float
z: np.ndarray
nonlinear: bool
provider: Provider

def initialize(self) -> None:
def initialize(self):
try:
import pyccl as ccl
except ImportError:
Expand Down Expand Up @@ -125,7 +130,7 @@ def must_provide(self, **requirements) -> dict:
np.atleast_1d(self.z))))

# Dictionary of the things CCL needs from CAMB/CLASS
needs = {}
needs: Dict[str, dict] = {}

if self.kmax:
self.nonlinear = self.nonlinear or options.get('nonlinear', False)
Expand Down Expand Up @@ -231,5 +236,5 @@ def calculate(self, state: dict, want_derived: bool = True,
for required_result, method in self._required_results.items():
state['CCL'][required_result] = method(cosmo)

def get_CCL(self):
def get_CCL(self) -> dict:
return self._current_state['CCL']
42 changes: 20 additions & 22 deletions soliket/clusters/clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,19 @@
p
"""
import os
from typing import Dict
import numpy as np
import pandas as pd
from scipy.interpolate import interp1d

from soliket.constants import C_KM_S
from soliket.clusters import massfunc as mf
from soliket.poisson import PoissonLikelihood

from .survey import SurveyData
from .sz_utils import szutils, trapezoid
from cobaya import LoggedError

C_KM_S = 2.99792e5
from cobaya.theory import Provider


class SZModel:
Expand All @@ -42,6 +43,7 @@ class ClusterLikelihood(PoissonLikelihood):
"""
name = "Clusters"
columns = ["tsz_signal", "z", "tsz_signal_err"]
provider: Provider

# data_name = resource_filename("soliket",
# "clusters/data/MFMF_WebSkyHalos_A10tSZ_3freq_tiles_mass.fits")
Expand Down Expand Up @@ -96,45 +98,44 @@ def get_requirements(self):
# # model.szk = SZTracer(cosmo)
# return model

def _get_catalog(self):
def _get_catalog(self) -> pd.DataFrame:
self.survey = SurveyData(
self.data_path, self.data_name
) # , MattMock=False,tiles=False)

self.szutils = szutils(self.survey)

df = pd.DataFrame(
return pd.DataFrame(
{
"z": self.survey.clst_z.byteswap().newbyteorder(),
"tsz_signal": self.survey.clst_y0.byteswap().newbyteorder(),
"tsz_signal_err": self.survey.clst_y0err.byteswap().newbyteorder(),
}
)
return df

def _get_om(self):
def _get_om(self) -> float:
return (self.provider.get_param("omch2") + self.provider.get_param("ombh2")) / (
(self.provider.get_param("H0") / 100.0) ** 2
)

def _get_ob(self):
def _get_ob(self) -> float:
return (self.provider.get_param("ombh2")) / (
(self.provider.get_param("H0") / 100.0) ** 2
)

def _get_Ez(self):
def _get_Ez(self) -> np.ndarray:
return self.provider.get_Hubble(self.zarr) / self.provider.get_param("H0")

def _get_Ez_interpolator(self):
def _get_Ez_interpolator(self) -> interp1d:
return interp1d(self.zarr, self._get_Ez())

def _get_DAz(self):
def _get_DAz(self) -> np.ndarray:
return self.provider.get_angular_diameter_distance(self.zarr)

def _get_DAz_interpolator(self):
def _get_DAz_interpolator(self) -> interp1d:
return interp1d(self.zarr, self._get_DAz())

def _get_HMF(self):
def _get_HMF(self) -> mf.HMF:
h = self.provider.get_param("H0") / 100.0

Pk_interpolator = self.provider.get_Pk_interpolator(
Expand All @@ -149,11 +150,9 @@ def _get_HMF(self):
) # self.provider.get_Hubble(self.zarr) / self.provider.get_param("H0")
om = self._get_om()

hmf = mf.HMF(om, Ez, pk=pks * h ** 3, kh=self.k / h, zarr=self.zarr)

return hmf
return mf.HMF(om, Ez, pk=pks * h ** 3, kh=self.k / h, zarr=self.zarr)

def _get_param_vals(self, **kwargs):
def _get_param_vals(self, **kwargs) -> Dict[str, float]:
# Read in scaling relation parameters
# scat = kwargs['scat']
# massbias = kwargs['massbias']
Expand Down Expand Up @@ -190,7 +189,7 @@ def _get_rate_fn(self, **kwargs):

h = self.provider.get_param("H0") / 100.0

def Prob_per_cluster(z, tsz_signal, tsz_signal_err):
def Prob_per_cluster(z, tsz_signal, tsz_signal_err) -> np.ndarray:
c_y = tsz_signal
c_yerr = tsz_signal_err
c_z = z
Expand All @@ -201,13 +200,12 @@ def Prob_per_cluster(z, tsz_signal, tsz_signal_err):

dn_dzdm = 10 ** np.squeeze(dn_dzdm_interp((np.log10(HMF.M), c_z))) * h ** 4.0

ans = trapezoid(dn_dzdm * Pfunc_ind, dx=np.diff(HMF.M, axis=0), axis=0)
return ans
return trapezoid(dn_dzdm * Pfunc_ind, dx=np.diff(HMF.M, axis=0), axis=0)

return Prob_per_cluster
# Implement a function that returns a rate function (function of (tsz_signal, z))

def _get_dVdz(self):
def _get_dVdz(self) -> np.ndarray:
DA_z = self.provider.get_angular_diameter_distance(self.zarr)

dV_dz = (
Expand All @@ -219,7 +217,7 @@ def _get_dVdz(self):
# dV_dz *= (self.provider.get_param("H0") / 100.0) ** 3.0 # was h0
return dV_dz

def _get_n_expected(self, **kwargs):
def _get_n_expected(self, **kwargs) -> float:
"""
Calculates expected number of clusters at the current parameter values.
"""
Expand Down Expand Up @@ -253,7 +251,7 @@ def _get_n_expected(self, **kwargs):

return Ntot

def _test_n_tot(self, **kwargs):
def _test_n_tot(self, **kwargs) -> float:
HMF = self._get_HMF()
# param_vals = self._get_param_vals(**kwargs)
# Ez_fn = self._get_Ez_interpolator()
Expand Down
Loading
Loading