Skip to content

Commit

Permalink
Don't use local functions to wrap numba functions
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Nov 4, 2024
1 parent 9dd940e commit d7779a3
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
8 changes: 4 additions & 4 deletions sgkit/stats/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def variant_stats(
--------
:func:`count_variant_genotypes`
"""
from .aggregation_numba_fns import count_hom
from .aggregation_numba_fns import count_hom_new_axis

variables.validate(ds, {call_genotype: variables.call_genotype_spec})
mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False)
Expand All @@ -697,7 +697,7 @@ def variant_stats(
G = da.asarray(ds[call_genotype].data)
H = xr.DataArray(
da.map_blocks(
lambda *args: count_hom(*args)[:, np.newaxis, :],
count_hom_new_axis,
G,
np.zeros(3, np.uint64),
drop_axis=2,
Expand Down Expand Up @@ -796,7 +796,7 @@ def sample_stats(
ValueError
If the dataset contains mixed-ploidy genotype calls.
"""
from .aggregation_numba_fns import count_hom
from .aggregation_numba_fns import count_hom_new_axis

variables.validate(ds, {call_genotype: variables.call_genotype_spec})
mixed_ploidy = ds[call_genotype].attrs.get("mixed_ploidy", False)
Expand All @@ -805,7 +805,7 @@ def sample_stats(
GT = da.asarray(ds[call_genotype].transpose("samples", "variants", "ploidy").data)
H = xr.DataArray(
da.map_blocks(
lambda *args: count_hom(*args)[:, np.newaxis, :],
count_hom_new_axis,
GT,
np.zeros(3, np.uint64),
drop_axis=2,
Expand Down
6 changes: 6 additions & 0 deletions sgkit/stats/aggregation_numba_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# in a separate file here, and imported dynamically to avoid
# initial compilation overhead.

import numpy as np

from sgkit.accelerate import numba_guvectorize, numba_jit
from sgkit.typing import ArrayLike

Expand Down Expand Up @@ -102,3 +104,7 @@ def count_hom(
index = _classify_hom(genotypes[i])
if index >= 0:
out[index] += 1


def count_hom_new_axis(genotypes: ArrayLike, _: ArrayLike) -> ArrayLike:
return count_hom(genotypes, _)[:, np.newaxis, :]
4 changes: 1 addition & 3 deletions sgkit/stats/popgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,9 +595,7 @@ def pbs(
cohorts = cohorts or list(itertools.combinations(range(n_cohorts), 3)) # type: ignore
ct = _cohorts_to_array(cohorts, ds.indexes.get("cohorts_0", None))

p = da.map_blocks(
lambda t: _pbs_cohorts(t, ct), t, chunks=shape, new_axis=3, dtype=np.float64
)
p = da.map_blocks(_pbs_cohorts, t, ct, chunks=shape, new_axis=3, dtype=np.float64)
assert_array_shape(p, n_windows, n_cohorts, n_cohorts, n_cohorts)

new_ds = create_dataset(
Expand Down

0 comments on commit d7779a3

Please sign in to comment.