# Patch summary: count_call_alleles now raises ValueError when the call_genotype
# variable has more than one chunk along its third axis (reported to the user as
# the "ploidy" dimension). The aggregation tests are updated to chunk only the
# "variants" and "samples" dimensions and to assert the new error is raised.
diff --git a/sgkit/stats/aggregation.py b/sgkit/stats/aggregation.py
index fe197c3bf..f76dbe691 100644
--- a/sgkit/stats/aggregation.py
+++ b/sgkit/stats/aggregation.py
@@ -77,6 +77,11 @@ def count_call_alleles(
     variables.validate(ds, {call_genotype: variables.call_genotype_spec})
     n_alleles = ds.sizes["alleles"]
     G = da.asarray(ds[call_genotype])
+    if G.numblocks[2] > 1:
+        raise ValueError(
+            f"Variable {call_genotype} must have only a single chunk in the ploidy dimension. "
+            "Consider rechunking to change the size of chunks."
+        )
     shape = (G.chunks[0], G.chunks[1], n_alleles)
     # use numpy array to avoid dask task dependencies between chunks
     N = np.empty(n_alleles, dtype=np.uint8)
diff --git a/sgkit/tests/test_aggregation.py b/sgkit/tests/test_aggregation.py
index 34df9eddd..e2719c1c1 100644
--- a/sgkit/tests/test_aggregation.py
+++ b/sgkit/tests/test_aggregation.py
@@ -139,8 +139,10 @@ def test_count_variant_alleles__chunked(using):
     calls = rs.randint(0, 1, size=(50, 10, 2))
     ds = get_dataset(calls)
     ac1 = count_variant_alleles(ds, using=using)
-    # Coerce from numpy to multiple chunks in all dimensions
-    ds["call_genotype"] = ds["call_genotype"].chunk(chunks=(5, 5, 1))
+    # Coerce from numpy to multiple chunks in all non-core dimensions
+    ds["call_genotype"] = ds["call_genotype"].chunk(
+        chunks={"variants": 5, "samples": 5}
+    )
     ac2 = count_variant_alleles(ds, using=using)
     assert isinstance(ac2["variant_allele_count"].data, da.Array)
     xr.testing.assert_equal(ac1, ac2)
@@ -273,6 +275,14 @@ def test_count_call_alleles__chunked():
     assert hasattr(ac2["call_allele_count"].data, "chunks")
     xr.testing.assert_equal(ac1, ac2)
 
+    # Multiple chunks in core dimension should fail
+    ds["call_genotype"] = ds["call_genotype"].chunk(chunks={"ploidy": 1})
+    with pytest.raises(
+        ValueError,
+        match="Variable call_genotype must have only a single chunk in the ploidy dimension",
+    ):
+        count_call_alleles(ds)
+
 
 def test_count_cohort_alleles__multi_variant_multi_sample():
     ds = get_dataset(
# Patch summary: in two popgen tests, the multi-chunk entry of the "chunks"
# parametrization changes its third element from (2, 2) to (4,), i.e. a single
# chunk of size 4 for that axis.
# NOTE(review): the first hunk originally added `(4)` (a plain int, not a tuple);
# it is normalized to `(4,)` here so both hunks add the same value. The blob id
# on the `index` line below predates that normalization.
diff --git a/sgkit/tests/test_popgen.py b/sgkit/tests/test_popgen.py
index 6bc06acd8..251e0568b 100644
--- a/sgkit/tests/test_popgen.py
+++ b/sgkit/tests/test_popgen.py
@@ -533,7 +533,7 @@ def test_Garud_h__raise_on_no_windows():
 
 
 @pytest.mark.filterwarnings("ignore::RuntimeWarning")
-@pytest.mark.parametrize("chunks", [((4,), (6,), (4,)), ((2, 2), (3, 3), (2, 2))])
+@pytest.mark.parametrize("chunks", [((4,), (6,), (4,)), ((2, 2), (3, 3), (4,))])
 def test_observed_heterozygosity(chunks):
     ds = simulate_genotype_call_dataset(
         n_variant=4,
@@ -599,7 +599,7 @@ def test_observed_heterozygosity(chunks):
 
 
 @pytest.mark.filterwarnings("ignore::RuntimeWarning")
-@pytest.mark.parametrize("chunks", [((4,), (6,), (4,)), ((2, 2), (3, 3), (2, 2))])
+@pytest.mark.parametrize("chunks", [((4,), (6,), (4,)), ((2, 2), (3, 3), (4,))])
 @pytest.mark.parametrize(
     "cohorts,expectation",
     [