From 317d5b702b8936c6d681883801d85fc3e990db3e Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Tue, 27 Feb 2024 15:00:06 +0000 Subject: [PATCH 1/3] add gl sampling (f00 quad test failing) --- s2fft/sampling/s2_samples.py | 27 ++++++++++++--------- s2fft/utils/quadrature.py | 47 ++++++++++++++++++++++++++++++++++-- tests/test_quadrature.py | 4 +-- tests/test_samples.py | 9 ++++--- 4 files changed, 68 insertions(+), 19 deletions(-) diff --git a/s2fft/sampling/s2_samples.py b/s2fft/sampling/s2_samples.py index dbc6bec3..398bec5b 100644 --- a/s2fft/sampling/s2_samples.py +++ b/s2fft/sampling/s2_samples.py @@ -10,7 +10,7 @@ def ntheta(L: int = None, sampling: str = "mw", nside: int = None) -> int: Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. @@ -31,7 +31,7 @@ def ntheta(L: int = None, sampling: str = "mw", nside: int = None) -> int: f"Sampling scheme sampling={sampling} with L={L} not supported" ) - if sampling.lower() == "mw": + if sampling.lower() in ["mw", "gl"]: return L elif sampling.lower() == "mwss": @@ -93,7 +93,7 @@ def nphi_equiang(L: int, sampling: str = "mw") -> int: L (int): Harmonic band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl"}. Defaults to "mw". Raises: ValueError: HEALPix sampling scheme. @@ -104,7 +104,7 @@ def nphi_equiang(L: int, sampling: str = "mw") -> int: int: Number of :math:`\phi` samples. """ - if sampling.lower() == "mw": + if sampling.lower() in ["mw", "gl"]: return 2 * L - 1 elif sampling.lower() == "mwss": @@ -129,7 +129,7 @@ def ftm_shape(L: int, sampling: str = "mw", nside: int = None) -> Tuple[int, int L (int): Harmonic band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. @@ -144,7 +144,7 @@ def ftm_shape(L: int, sampling: str = "mw", nside: int = None) -> Tuple[int, int if sampling.lower() in ["mwss", "healpix"]: return ntheta(L, sampling, nside), 2 * L - elif sampling.lower() in ["mw", "dh"]: + elif sampling.lower() in ["mw", "dh", "gl"]: return ntheta(L, sampling, nside), 2 * L - 1 else: @@ -203,7 +203,7 @@ def thetas(L: int = None, sampling: str = "mw", nside: int = None) -> np.ndarray Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. @@ -211,6 +211,9 @@ def thetas(L: int = None, sampling: str = "mw", nside: int = None) -> np.ndarray Returns: np.ndarray: Array of :math:`\theta` samples for given sampling scheme. """ + if sampling.lower() == "gl": + return np.flip(np.arccos(np.polynomial.legendre.leggauss(L)[0])) + t = np.arange(0, ntheta(L=L, sampling=sampling, nside=nside)).astype(np.float64) return t2theta(t, L, sampling, nside) @@ -228,7 +231,7 @@ def t2theta( Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. @@ -346,7 +349,7 @@ def phis_equiang(L: int, sampling: str = "mw") -> np.ndarray: L (int, optional): Harmonic band-limit. sampling (str, optional): Sampling scheme. Supported equiangular sampling - schemes include {"mw", "mwss", "dh"}. Defaults to "mw". + schemes include {"mw", "mwss", "dh", "gl"}. Defaults to "mw". Returns: np.ndarray: Array of :math:`\phi` samples for given sampling scheme. @@ -365,7 +368,7 @@ def p2phi_equiang(L: int, p: int, sampling: str = "mw") -> np.ndarray: p (int): :math:`\phi` index. sampling (str, optional): Sampling scheme. Supported equiangular sampling - schemes include {"mw", "mwss", "dh"}. Defaults to "mw". + schemes include {"mw", "mwss", "dh", "gl"}. Defaults to "mw". Raises: ValueError: HEALPix sampling not support (only equiangular schemes supported). @@ -376,7 +379,7 @@ def p2phi_equiang(L: int, p: int, sampling: str = "mw") -> np.ndarray: np.ndarray: :math:`\phi` sample(s) for given sampling scheme. """ - if sampling.lower() == "mw": + if sampling.lower() in ["mw", "gl"]: return 2 * p * np.pi / (2 * L - 1) elif sampling.lower() == "mwss": @@ -431,7 +434,7 @@ def f_shape(L: int = None, sampling: str = "mw", nside: int = None) -> Tuple[int L (int, optional): Harmonic band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. diff --git a/s2fft/utils/quadrature.py b/s2fft/utils/quadrature.py index 34d1262e..d4a0d262 100644 --- a/s2fft/utils/quadrature.py +++ b/s2fft/utils/quadrature.py @@ -16,7 +16,7 @@ def quad_weights_transform( L (int): Harmonic band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mwss", "dh", "healpix}. Defaults to "mwss". + {"mwss", "dh", "gl", "healpix}. Defaults to "mwss". spin (int, optional): Harmonic spin. Defaults to 0. @@ -38,6 +38,9 @@ def quad_weights_transform( elif sampling.lower() == "dh": return quad_weights_dh(L) + elif sampling.lower() == "gl": + return quad_weights_gl(L) + elif sampling.lower() == "healpix": return quad_weights_hp(nside) @@ -56,7 +59,7 @@ def quad_weights( Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". spin (int, optional): Harmonic spin. Defaults to 0. @@ -80,6 +83,9 @@ def quad_weights( elif sampling.lower() == "dh": return quad_weights_dh(L) + elif sampling.lower() == "gl": + return quad_weights_gl(L) + elif sampling.lower() == "healpix": return quad_weights_hp(nside) @@ -112,6 +118,43 @@ def quad_weights_hp(nside: int) -> np.ndarray: return hp_weights +def quad_weights_gl(L: int) -> np.ndarray: + r"""Compute GL quadrature weights for :math:`\theta` and :math:`\phi` integration. + + Args: + L (int): Harmonic band-limit. + + Returns: + np.ndarray: Weights computed for each :math:`\theta` (weights are identical + as :math:`\phi` varies for given :math:`\theta`). + """ + x1, x2 = -1.0, 1.0 + ntheta = samples.ntheta(L, "gl") + weights = np.zeros(ntheta, dtype=np.float64) + + m = int((L + 1) / 2) + x1 = 0.5 * (x2 - x1) + + for i in range(1, m + 1): + z = np.cos(np.pi * (i - 0.25) / (L + 0.5)) + z1 = 2.0 + while np.abs(z - z1) > 1e-14: + p1 = 1.0 + p2 = 0.0 + for j in range(1, L + 1): + p3 = p2 + p2 = p1 + p1 = ((2.0 * j - 1.0) * z * p2 - (j - 1.0) * p3) / j + pp = L * (z * p1 - p2) / (z * z - 1.0) + z1 = z + z = z1 - p1 / pp + + weights[i - 1] = 2.0 * x1 / ((1.0 - z**2) * pp * pp) + weights[L + 1 - i - 1] = weights[i - 1] + + return weights + + def quad_weights_dh(L: int) -> np.ndarray: r"""Compute DH quadrature weights for :math:`\theta` and :math:`\phi` integration. diff --git a/tests/test_quadrature.py b/tests/test_quadrature.py index dfe22aa3..db55d8ab 100644 --- a/tests/test_quadrature.py +++ b/tests/test_quadrature.py @@ -6,12 +6,12 @@ @pytest.mark.parametrize("L", [5, 6]) -@pytest.mark.parametrize("sampling", ["mw", "mwss"]) +@pytest.mark.parametrize("sampling", ["gl"]) +# @pytest.mark.parametrize("sampling", ["mw", "mwss", "dh", "gl"]) def test_quadrature_mw_weights(flm_generator, L: int, sampling: str): spin = 0 q = quadrature.quad_weights(L, sampling, spin) - flm = flm_generator(L, spin, reality=False) f = spherical.inverse(flm, L, spin, sampling) diff --git a/tests/test_samples.py b/tests/test_samples.py index 44f18687..95aa9a1a 100644 --- a/tests/test_samples.py +++ b/tests/test_samples.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize("L", [15, 16]) -@pytest.mark.parametrize("sampling", ["mw", "mwss", "dh"]) +@pytest.mark.parametrize("sampling", ["mw", "mwss", "dh", "gl"]) def test_samples_n_and_angles(L: int, sampling: str): # Test ntheta and nphi ntheta = samples.ntheta(L, sampling) @@ -18,8 +18,11 @@ def test_samples_n_and_angles(L: int, sampling: str): assert (ntheta, nphi) == pytest.approx((ntheta_ssht, nphi_ssht)) # Test thetas and phis - t = np.arange(0, ntheta) - thetas = samples.t2theta(t, L, sampling) + if sampling.lower() == "gl": + thetas = samples.thetas(L, sampling) + else: + t = np.arange(0, ntheta) + thetas = samples.t2theta(t, L, sampling) p = np.arange(0, nphi) phis = samples.p2phi_equiang(L, p, sampling) thetas_ssht, phis_ssht = ssht.sample_positions(L, sampling.upper()) From 95ea4f591650ef4a30970f059960977f5f8cc390 Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Mon, 4 Mar 2024 10:07:51 +0000 Subject: [PATCH 2/3] add Gauss-Legendre sampling (numpy and jax) --- s2fft/_version.py | 16 +++++++++ s2fft/base_transforms/spherical.py | 2 +- s2fft/base_transforms/wigner.py | 2 +- s2fft/precompute_transforms/construct.py | 8 ++--- s2fft/precompute_transforms/spherical.py | 12 +++---- s2fft/precompute_transforms/wigner.py | 12 +++---- s2fft/sampling/so3_samples.py | 19 +++++----- s2fft/transforms/otf_recursions.py | 8 ++--- s2fft/transforms/spherical.py | 12 +++---- s2fft/transforms/wigner.py | 12 +++---- s2fft/utils/quadrature.py | 36 +++++++++---------- s2fft/utils/quadrature_jax.py | 46 ++++++++++++++++++++++-- tests/test_quadrature.py | 3 +- tests/test_spherical_base.py | 2 +- tests/test_spherical_precompute.py | 2 +- tests/test_wigner_base.py | 2 +- tests/test_wigner_precompute.py | 2 +- 17 files changed, 125 insertions(+), 71 deletions(-) create mode 100644 s2fft/_version.py diff --git a/s2fft/_version.py b/s2fft/_version.py new file mode 100644 index 00000000..684acf06 --- /dev/null +++ b/s2fft/_version.py @@ -0,0 +1,16 @@ +# file generated by setuptools_scm +# don't change, don't track in version control +TYPE_CHECKING = False +if TYPE_CHECKING: + from typing import Tuple, Union + VERSION_TUPLE = Tuple[Union[int, str], ...] +else: + VERSION_TUPLE = object + +version: str +__version__: str +__version_tuple__: VERSION_TUPLE +version_tuple: VERSION_TUPLE + +__version__ = version = '1.0.1.dev12' +__version_tuple__ = version_tuple = (1, 0, 1, 'dev12') diff --git a/s2fft/base_transforms/spherical.py b/s2fft/base_transforms/spherical.py index 4625e7bc..6bfaf5de 100644 --- a/s2fft/base_transforms/spherical.py +++ b/s2fft/base_transforms/spherical.py @@ -27,7 +27,7 @@ def inverse( spin (int, optional): Harmonic spin. Defaults to 0. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. diff --git a/s2fft/base_transforms/wigner.py b/s2fft/base_transforms/wigner.py index ba5295bf..b7ed60fe 100644 --- a/s2fft/base_transforms/wigner.py +++ b/s2fft/base_transforms/wigner.py @@ -33,7 +33,7 @@ def inverse( L_lower (int, optional): Harmonic lower bound. Defaults to 0. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to diff --git a/s2fft/precompute_transforms/construct.py b/s2fft/precompute_transforms/construct.py index 95667db3..87b42720 100644 --- a/s2fft/precompute_transforms/construct.py +++ b/s2fft/precompute_transforms/construct.py @@ -33,7 +33,7 @@ def spin_spherical_kernel( Defaults to False. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int): HEALPix Nside resolution parameter. Only required if sampling="healpix". @@ -103,7 +103,7 @@ def spin_spherical_kernel_jax( Defaults to False. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int): HEALPix Nside resolution parameter. Only required if sampling="healpix". @@ -181,7 +181,7 @@ def wigner_kernel( Defaults to False. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int): HEALPix Nside resolution parameter. Only required if sampling="healpix". @@ -252,7 +252,7 @@ def wigner_kernel_jax( Defaults to False. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". nside (int): HEALPix Nside resolution parameter. Only required if sampling="healpix". diff --git a/s2fft/precompute_transforms/spherical.py b/s2fft/precompute_transforms/spherical.py index 056200eb..f1ef5b34 100644 --- a/s2fft/precompute_transforms/spherical.py +++ b/s2fft/precompute_transforms/spherical.py @@ -32,7 +32,7 @@ def inverse( kernel (np.ndarray, optional): Wigner-d kernel. Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -85,7 +85,7 @@ def inverse_transform( L (int): Harmonic band-limit. sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. + {"mw", "mwss", "dh", "gl", "healpix"}. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -149,7 +149,7 @@ def inverse_transform_jax( L (int): Harmonic band-limit. sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. + {"mw", "mwss", "dh", "gl", "healpix"}. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -211,7 +211,7 @@ def forward( kernel (np.ndarray, optional): Wigner-d kernel. Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -264,7 +264,7 @@ def forward_transform( L (int): Harmonic band-limit. sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. + {"mw", "mwss", "dh", "gl", "healpix"}. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -332,7 +332,7 @@ def forward_transform_jax( L (int): Harmonic band-limit. sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. + {"mw", "mwss", "dh", "gl", "healpix"}. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. diff --git a/s2fft/precompute_transforms/wigner.py b/s2fft/precompute_transforms/wigner.py index 5cb119ef..31f13b0e 100644 --- a/s2fft/precompute_transforms/wigner.py +++ b/s2fft/precompute_transforms/wigner.py @@ -32,7 +32,7 @@ def inverse( kernel (np.ndarray, optional): Wigner-d kernel. Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -86,7 +86,7 @@ def inverse_transform( N (int): Directional band-limit. sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. + {"mw", "mwss", "dh", "gl", "healpix"}. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -147,7 +147,7 @@ def inverse_transform_jax( N (int): Directional band-limit. sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. + {"mw", "mwss", "dh", "gl", "healpix"}. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -227,7 +227,7 @@ def forward( kernel (np.ndarray, optional): Wigner-d kernel. Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -279,7 +279,7 @@ def forward_transform( N (int): Directional band-limit. sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. + {"mw", "mwss", "dh", "gl", "healpix"}. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. @@ -360,7 +360,7 @@ def forward_transform_jax( N (int): Directional band-limit. sampling (str): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. + {"mw", "mwss", "dh", "gl", "healpix"}. reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. diff --git a/s2fft/sampling/so3_samples.py b/s2fft/sampling/so3_samples.py index 5c56b550..63a29784 100644 --- a/s2fft/sampling/so3_samples.py +++ b/s2fft/sampling/so3_samples.py @@ -23,7 +23,7 @@ def f_shape( N (int): Directional band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. @@ -35,15 +35,12 @@ def f_shape( Tuple[int,int,int]: Shape of pixel-space sampling of rotation group :math:`SO(3)`. """ - if sampling in ["mw", "mwss", "dh"]: + if sampling in ["mw", "mwss", "dh", "gl"]: return _ngamma(N), _nbeta(L, sampling), _nalpha(L, sampling) elif sampling.lower() == "healpix": return _ngamma(N), 12 * nside**2 - elif sampling.lower() == "healpix": - return 12 * nside**2, _ngamma(N) - else: raise ValueError(f"Sampling scheme sampling={sampling} not supported") @@ -76,7 +73,7 @@ def fnab_shape( N (int): Directional band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl"}. Defaults to "mw". nside (int, optional): HEALPix Nside resolution parameter. @@ -91,7 +88,7 @@ def fnab_shape( if sampling.lower() in ["mwss", "healpix"]: return _ngamma(N), samples.ntheta(L, sampling, nside), 2 * L - elif sampling.lower() in ["mw", "dh"]: + elif sampling.lower() in ["mw", "dh", "gl"]: return _ngamma(N), samples.ntheta(L, sampling, nside), 2 * L - 1 else: @@ -121,7 +118,7 @@ def _nalpha(L: int, sampling: str = "mw") -> int: L (int): Harmonic band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl"}. Defaults to "mw". Raises: ValueError: Unknown sampling scheme. @@ -129,7 +126,7 @@ def _nalpha(L: int, sampling: str = "mw") -> int: Returns: int: Number of :math:`\alpha` samples. """ - if sampling.lower() in ["mw", "dh"]: + if sampling.lower() in ["mw", "dh", "gl"]: return 2 * L - 1 elif sampling.lower() == "mwss": @@ -146,7 +143,7 @@ def _nbeta(L: int, sampling: str = "mw") -> int: L (int): Harmonic band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl"}. Defaults to "mw". Raises: ValueError: Unknown sampling scheme. @@ -154,7 +151,7 @@ def _nbeta(L: int, sampling: str = "mw") -> int: Returns: int: Number of :math:`\beta` samples. """ - if sampling.lower() == "mw": + if sampling.lower() in ["mw", "gl"]: return L elif sampling.lower() == "mwss": diff --git a/s2fft/transforms/otf_recursions.py b/s2fft/transforms/otf_recursions.py index 74b80b9c..530c3501 100644 --- a/s2fft/transforms/otf_recursions.py +++ b/s2fft/transforms/otf_recursions.py @@ -46,7 +46,7 @@ def inverse_latitudinal_step( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to @@ -195,7 +195,7 @@ def inverse_latitudinal_step_jax( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to @@ -441,7 +441,7 @@ def forward_latitudinal_step( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to @@ -597,7 +597,7 @@ def forward_latitudinal_step_jax( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". reality (bool, optional): Whether the signal on the sphere is real. If so, conjugate symmetry is exploited to reduce computational costs. Defaults to diff --git a/s2fft/transforms/spherical.py b/s2fft/transforms/spherical.py index 406326b7..c11d9a2c 100644 --- a/s2fft/transforms/spherical.py +++ b/s2fft/transforms/spherical.py @@ -40,7 +40,7 @@ def inverse( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -112,7 +112,7 @@ def inverse_numpy( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -211,7 +211,7 @@ def inverse_jax( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -334,7 +334,7 @@ def forward( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -406,7 +406,7 @@ def forward_numpy( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -533,7 +533,7 @@ def forward_jax( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". diff --git a/s2fft/transforms/wigner.py b/s2fft/transforms/wigner.py index 60ec8563..d05451b5 100644 --- a/s2fft/transforms/wigner.py +++ b/s2fft/transforms/wigner.py @@ -43,7 +43,7 @@ def inverse( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -124,7 +124,7 @@ def inverse_numpy( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -211,7 +211,7 @@ def inverse_jax( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -366,7 +366,7 @@ def forward( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -445,7 +445,7 @@ def forward_numpy( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". @@ -541,7 +541,7 @@ def forward_jax( if sampling="healpix". Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". method (str, optional): Execution mode in {"numpy", "jax"}. Defaults to "numpy". diff --git a/s2fft/utils/quadrature.py b/s2fft/utils/quadrature.py index d4a0d262..571effcf 100644 --- a/s2fft/utils/quadrature.py +++ b/s2fft/utils/quadrature.py @@ -135,24 +135,24 @@ def quad_weights_gl(L: int) -> np.ndarray: m = int((L + 1) / 2) x1 = 0.5 * (x2 - x1) - for i in range(1, m + 1): - z = np.cos(np.pi * (i - 0.25) / (L + 0.5)) - z1 = 2.0 - while np.abs(z - z1) > 1e-14: - p1 = 1.0 - p2 = 0.0 - for j in range(1, L + 1): - p3 = p2 - p2 = p1 - p1 = ((2.0 * j - 1.0) * z * p2 - (j - 1.0) * p3) / j - pp = L * (z * p1 - p2) / (z * z - 1.0) - z1 = z - z = z1 - p1 / pp - - weights[i - 1] = 2.0 * x1 / ((1.0 - z**2) * pp * pp) - weights[L + 1 - i - 1] = weights[i - 1] - - return weights + i = np.arange(1, m + 1) + z = np.cos(np.pi * (i - 0.25) / (L + 0.5)) + z1 = 2.0 + while np.max(np.abs(z - z1)) > 1e-14: + p1 = 1.0 + p2 = 0.0 + for j in range(1, L + 1): + p3 = p2 + p2 = p1 + p1 = ((2.0 * j - 1.0) * z * p2 - (j - 1.0) * p3) / j + pp = L * (z * p1 - p2) / (z * z - 1.0) + z1 = z + z = z1 - p1 / pp + + weights[i - 1] = 2.0 * x1 / ((1.0 - z**2) * pp * pp) + weights[L + 1 - i - 1] = weights[i - 1] + + return weights * 2 * np.pi / (2 * L - 1) def quad_weights_dh(L: int) -> np.ndarray: diff --git a/s2fft/utils/quadrature_jax.py b/s2fft/utils/quadrature_jax.py index 787c0946..b82d1116 100644 --- a/s2fft/utils/quadrature_jax.py +++ b/s2fft/utils/quadrature_jax.py @@ -20,7 +20,7 @@ def quad_weights_transform( L (int): Harmonic band-limit. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mwss", "dh", "healpix}. Defaults to "mwss". + {"mwss", "dh", "gl", "healpix}. Defaults to "mwss". nside (int, optional): HEALPix Nside resolution parameter. Only required if sampling="healpix". Defaults to None. @@ -40,6 +40,9 @@ def quad_weights_transform( elif sampling.lower() == "dh": return quad_weights_dh(L) + elif sampling.lower() == "gl": + return quad_weights_gl(L) + elif sampling.lower() == "healpix": return quad_weights_hp(nside) @@ -58,7 +61,7 @@ def quad_weights(L: int = None, sampling: str = "mw", nside: int = None) -> jnp. Defaults to None. sampling (str, optional): Sampling scheme. Supported sampling schemes include - {"mw", "mwss", "dh", "healpix"}. Defaults to "mw". + {"mw", "mwss", "dh", "gl", "healpix"}. Defaults to "mw". spin (int, optional): Harmonic spin. Defaults to 0. @@ -82,6 +85,9 @@ def quad_weights(L: int = None, sampling: str = "mw", nside: int = None) -> jnp. elif sampling.lower() == "dh": return quad_weights_dh(L) + elif sampling.lower() == "gl": + return quad_weights_gl(L) + elif sampling.lower() == "healpix": return quad_weights_hp(nside) @@ -111,6 +117,42 @@ def quad_weights_hp(nside: int) -> jnp.ndarray: return jnp.ones(rings, dtype=jnp.float64) * 4 * jnp.pi / npix +def quad_weights_gl(L: int) -> jnp.ndarray: + r"""Compute GL quadrature weights for :math:`\theta` and :math:`\phi` integration. + + Args: + L (int): Harmonic band-limit. + + Returns: + jnp.ndarray: Weights computed for each :math:`\theta` (weights are identical + as :math:`\phi` varies for given :math:`\theta`). + """ + x1, x2 = -1.0, 1.0 + ntheta = samples.ntheta(L, "gl") + weights = jnp.zeros(ntheta, dtype=jnp.float64) + + m = int((L + 1) / 2) + x1 = 0.5 * (x2 - x1) + i = jnp.arange(1, m + 1) + z = jnp.cos(jnp.pi * (i - 0.25) / (L + 0.5)) + z1 = 2.0 + while jnp.max(jnp.abs(z - z1)) > 1e-14: + p1 = 1.0 + p2 = 0.0 + for j in range(1, L + 1): + p3 = p2 + p2 = p1 + p1 = ((2.0 * j - 1.0) * z * p2 - (j - 1.0) * p3) / j + pp = L * (z * p1 - p2) / (z * z - 1.0) + z1 = z + z = z1 - p1 / pp + + weights[i - 1] = 2.0 * x1 / ((1.0 - z**2) * pp * pp) + weights[L + 1 - i - 1] = weights[i - 1] + + return weights * 2 * jnp.pi / (2 * L - 1) + + @partial(jit, static_argnums=(0)) def quad_weights_dh(L: int) -> jnp.ndarray: r"""Compute DH quadrature weights for :math:`\theta` and :math:`\phi` integration. diff --git a/tests/test_quadrature.py b/tests/test_quadrature.py index db55d8ab..c67eb5b3 100644 --- a/tests/test_quadrature.py +++ b/tests/test_quadrature.py @@ -6,8 +6,7 @@ @pytest.mark.parametrize("L", [5, 6]) -@pytest.mark.parametrize("sampling", ["gl"]) -# @pytest.mark.parametrize("sampling", ["mw", "mwss", "dh", "gl"]) +@pytest.mark.parametrize("sampling", ["mw", "mwss", "dh", "gl"]) def test_quadrature_mw_weights(flm_generator, L: int, sampling: str): spin = 0 diff --git a/tests/test_spherical_base.py b/tests/test_spherical_base.py index b888f0a2..9c79df75 100644 --- a/tests/test_spherical_base.py +++ b/tests/test_spherical_base.py @@ -11,7 +11,7 @@ spin_to_test = [-2, 0, 1] nside_to_test = [4, 5] L_to_nside_ratio = [2, 3] -sampling_to_test = ["mw", "mwss", "dh"] +sampling_to_test = ["mw", "mwss", "dh", "gl"] method_to_test = ["direct", "sov", "sov_fft", "sov_fft_vectorized"] reality_to_test = [False, True] diff --git a/tests/test_spherical_precompute.py b/tests/test_spherical_precompute.py index 0f6b19c0..97c6dfb3 100644 --- a/tests/test_spherical_precompute.py +++ b/tests/test_spherical_precompute.py @@ -8,7 +8,7 @@ spin_to_test = [-2, 0, 1] nside_to_test = [4, 5] L_to_nside_ratio = [2, 3] -sampling_to_test = ["mw", "mwss", "dh"] +sampling_to_test = ["mw", "mwss", "dh", "gl"] reality_to_test = [True, False] methods_to_test = ["numpy", "jax"] diff --git a/tests/test_wigner_base.py b/tests/test_wigner_base.py index 64404ac3..e1f5b67a 100644 --- a/tests/test_wigner_base.py +++ b/tests/test_wigner_base.py @@ -9,7 +9,7 @@ N_to_test = [2, 3] L_lower_to_test = [0, 2] sampling_schemes_so3 = ["mw", "mwss"] -sampling_schemes = ["mw", "mwss", "dh"] +sampling_schemes = ["mw", "mwss", "dh", "gl"] reality_to_test = [False, True] diff --git a/tests/test_wigner_precompute.py b/tests/test_wigner_precompute.py index 84bef0ea..9853660e 100644 --- a/tests/test_wigner_precompute.py +++ b/tests/test_wigner_precompute.py @@ -10,7 +10,7 @@ nside_to_test = [4, 6] L_to_nside_ratio = [2] reality_to_test = [False, True] -sampling_schemes = ["mw", "mwss", "dh"] +sampling_schemes = ["mw", "mwss", "dh", "gl"] methods_to_test = ["numpy", "jax"] From 43b21988dbbe8984bed65d543014722bdffb8955 Mon Sep 17 00:00:00 2001 From: CosmoMatt Date: Mon, 4 Mar 2024 11:14:36 +0000 Subject: [PATCH 3/3] include GL in readme --- README.md | 5 +++-- docs/index.rst | 6 ++++-- requirements/requirements-core.txt | 5 ++++- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index c52bf3a5..a1fed65f 100644 --- a/README.md +++ b/README.md @@ -57,8 +57,9 @@ isolattitude sampling scheme. A number of sampling schemes are currently supported. The equiangular sampling schemes of [McEwen & Wiaux -(2012)](https://arxiv.org/abs/1110.6298) and [Driscoll & Healy -(1995)](https://www.sciencedirect.com/science/article/pii/S0196885884710086) +(2012)](https://arxiv.org/abs/1110.6298), [Driscoll & Healy +(1995)](https://www.sciencedirect.com/science/article/pii/S0196885884710086) +and [Gauss-Legendre (1986)](https://link.springer.com/article/10.1007/BF02519350) are supported, which exhibit associated sampling theorems and so harmonic transforms can be computed to machine precision. Note that the McEwen & Wiaux sampling theorem reduces the Nyquist rate on the sphere diff --git a/docs/index.rst b/docs/index.rst index 78c61339..c7b072fd 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -21,7 +21,8 @@ Algorithms |:zap:| ``S2FFT`` leverages new algorithmic structures that can he highly parallelised and distributed, and so map very well onto the architecture of hardware accelerators (i.e. GPUs and TPUs). In particular, these algorithms are based on new Wigner-d recursions -that are stable to high angular resolution :math:`L`. The diagram below illustrates the recursions (for further details see Price & McEwen 2023). +that are stable to high angular resolution :math:`L`. The diagram below illustrates the +recursions (for further details see Price & McEwen 2023). .. image:: ./assets/figures/Wigner_recursion_github_docs.png @@ -46,7 +47,8 @@ Sampling |:earth_africa:| The structure of the algorithms implemented in ``S2FFT`` can support any isolattitude sampling scheme. A number of sampling schemes are currently supported. -The equiangular sampling schemes of `McEwen & Wiaux (2012) `_ and `Driscoll & Healy (1995) `_ are supported, which exhibit associated sampling theorems and so harmonic transforms can be computed to machine precision. Note that the McEwen & Wiaux sampling theorem reduces the Nyquist rate on the sphere by a factor of two compared to the Driscoll & Healy approach, halving the number of spherical samples required. +The equiangular sampling schemes of `McEwen & Wiaux (2012) `_, +`Driscoll & Healy (1995) `_, and `Gauss-Legendre (1986) `_ are supported, which exhibit associated sampling theorems and so harmonic transforms can be computed to machine precision. Note that the McEwen & Wiaux sampling theorem reduces the Nyquist rate on the sphere by a factor of two compared to the Driscoll & Healy approach, halving the number of spherical samples required. The popular `HEALPix `_ sampling scheme (`Gorski et al. 2005 `_) is also supported. The HEALPix sampling does not exhibit a sampling theorem and so the corresponding harmonic transforms do not achieve machine precision but exhibit some error. However, the HEALPix sampling provides pixels of equal areas, which has many practical advantages. diff --git a/requirements/requirements-core.txt b/requirements/requirements-core.txt index a2a369e2..7c86b0ea 100644 --- a/requirements/requirements-core.txt +++ b/requirements/requirements-core.txt @@ -3,4 +3,7 @@ numpy>=1.20 colorlog pyyaml jax>=0.3.13 -jaxlib \ No newline at end of file +jaxlib + +# Remove when subpackage functionality is fixed. +torch \ No newline at end of file