diff --git a/sgkit/stats/preprocessing.py b/sgkit/stats/preprocessing.py
index 195cada68..137f2a29c 100644
--- a/sgkit/stats/preprocessing.py
+++ b/sgkit/stats/preprocessing.py
@@ -1,4 +1,4 @@
-from typing import Hashable, Optional
+from typing import Hashable, Optional, Sequence, Union
 
 import dask.array as da
 import numpy as np
@@ -185,3 +185,91 @@ def filter_partial_calls(
     )
     new_ds[variables.call_genotype_complete].attrs["mixed_ploidy"] = mixed_ploidy
     return conditional_merge_datasets(ds, new_ds, merge)
+
+
+def mean_impute(
+    ds: Dataset,
+    variable: str,
+    dim: Union[Hashable, Sequence[Hashable]] = "samples",
+    merge: bool = True,
+) -> Dataset:
+    """Mean impute a masked variable.
+
+    Parameters
+    ----------
+    ds
+        Dataset containing the variable to be imputed.
+    variable
+        Input variable name.
+        Variables ``{variable}`` and ``{variable}_mask`` must be present in ``ds``.
+    dim
+        Dimension(s) along which the means are computed.
+    merge
+        If True (the default), merge the input dataset and the computed
+        output variables into a single dataset, otherwise return only
+        the imputed output variables.
+        See :ref:`dataset_merge` for more details.
+
+    Returns
+    -------
+    Dataset containing :data:`sgkit.variables.{variable}_imputed` in which masked entries are
+    replaced with the mean of the unmasked values along ``dim``.
+
+    Examples
+    --------
+
+    >>> import sgkit as sg, numpy as np
+    >>> from sgkit.stats.preprocessing import mean_impute
+    >>> ds = sg.simulate_genotype_call_dataset(n_variant=4, n_sample=10, seed=1, missing_pct=.1)
+    >>> sg.display_genotypes(ds) # doctest: +NORMALIZE_WHITESPACE
+    samples    S0   S1   S2   S3   S4   S5   S6   S7   S8   S9
+    variants
+    0         1/0  1/0  ./0  1/1  0/1  1/0  0/0  0/0  ./.  1/0
+    1         ./1  0/0  1/0  1/1  ./0  1/1  1/1  1/1  0/1  0/0
+    2         ./0  1/1  1/.  0/1  0/1  0/1  1/0  ./1  1/0  0/0
+    3         0/1  0/1  0/1  0/1  1/1  1/1  0/0  1/1  0/1  1/0
+
+    >>> ds["call_dosage"] = ds.call_genotype.sum(dim="ploidy").astype(float)
+    >>> ds["call_dosage_mask"] = ds.call_genotype_mask.any(dim='ploidy')
+    >>> ds["call_dosage"] = ds["call_dosage"].where(~ds["call_dosage_mask"], np.nan)
+    >>> ds["call_dosage"] # doctest: +NORMALIZE_WHITESPACE
+    <xarray.DataArray 'call_dosage' (variants: 4, samples: 10)>
+    array([[ 1.,  1., nan,  2.,  1.,  1.,  0.,  0., nan,  1.],
+           [nan,  0.,  1.,  2., nan,  2.,  2.,  2.,  1.,  0.],
+           [nan,  2., nan,  1.,  1.,  1.,  1., nan,  1.,  0.],
+           [ 1.,  1.,  1.,  1.,  2.,  2.,  0.,  2.,  1.,  1.]])
+    Dimensions without coordinates: variants, samples
+
+    >>> ds = mean_impute(ds, variable='call_dosage', dim='samples')
+    >>> ds["call_dosage_imputed"] # doctest: +NORMALIZE_WHITESPACE
+    <xarray.DataArray 'call_dosage_imputed' (variants: 4, samples: 10)>
+    array([[1.   , 1.   , 0.875, 2.   , 1.   , 1.   , 0.   , 0.   , 0.875,
+            1.   ],
+           [1.25 , 0.   , 1.   , 2.   , 1.25 , 2.   , 2.   , 2.   , 1.   ,
+            0.   ],
+           [1.   , 2.   , 1.   , 1.   , 1.   , 1.   , 1.   , 1.   , 1.   ,
+            0.   ],
+           [1.   , 1.   , 1.   , 1.   , 2.   , 2.   , 0.   , 2.   , 1.   ,
+            1.   ]])
+    Dimensions without coordinates: variants, samples
+    Attributes:
+        comment:  Dosages imputed, encoded as floats, with NaN indicating a missi...
+
+    """
+
+    variables.validate(ds, variable)
+    variables.validate(ds, f"{variable}_mask")
+    # Keep observed (unmasked) entries; fill masked ones with the mean over `dim`.
+    unmasked = ~ds[f"{variable}_mask"]
+    new_ds = create_dataset(
+        {
+            f"{variable}_imputed": (
+                ds[variable].dims,
+                ds[variable]
+                .where(unmasked, ds[variable].where(unmasked).mean(dim=dim))
+                .data,
+            )
+        }
+    )
+
+    return conditional_merge_datasets(ds, new_ds, merge)
diff --git a/sgkit/variables.py b/sgkit/variables.py
index c1dbb5e9a..70fe0d9bb 100644
--- a/sgkit/variables.py
+++ b/sgkit/variables.py
@@ -214,6 +214,15 @@ def _check_field(
     )
 )
 
+call_dosage_imputed, call_dosage_imputed_spec = SgkitVariables.register_variable(
+    ArrayLikeSpec(
+        "call_dosage_imputed",
+        kind="f",
+        ndim=2,
+        __doc__="""Dosages imputed, encoded as floats, with NaN indicating a missing value.""",
+    )
+)
+
 call_dosage_mask, call_dosage_mask_spec = SgkitVariables.register_variable(
     ArrayLikeSpec(
         "call_dosage_mask",
@@ -293,6 +302,16 @@ def _check_field(
     )
 )
 
+call_genotype_imputed, call_genotype_imputed_spec = SgkitVariables.register_variable(
+    ArrayLikeSpec(
+        "call_genotype_imputed",
+        kind="f",
+        ndim=3,
+        __doc__="""
+Call genotype imputed """,
+    )
+)
+
 (
     call_genotype_probability,
     call_genotype_probability_spec,
@@ -305,6 +324,19 @@ def _check_field(
     )
 )
 
+(
+    call_genotype_probability_imputed,
+    call_genotype_probability_imputed_spec,
+) = SgkitVariables.register_variable(
+    ArrayLikeSpec(
+        "call_genotype_probability_imputed",
+        kind="f",
+        ndim=3,
+        __doc__="""Genotype probabilities Imputed.""",
+    )
+)
+
+
 (
     call_genotype_probability_mask,
     call_genotype_probability_mask_spec,