Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(external): implement METHLYANVI for scBS-seq data #3066

Open
wants to merge 40 commits into
base: main
Choose a base branch
from

Conversation

ethanweinberger
Copy link
Contributor

@canergen per our email exchange, this PR adds my MethylANVI model implementation within scvi.external.methylvi.

Copy link

codecov bot commented Dec 3, 2024

Codecov Report

Attention: Patch coverage is 83.25472% with 71 lines in your changes missing coverage. Please review.

Project coverage is 82.96%. Comparing base (a435561) to head (074370a).

Files with missing lines Patch % Lines
src/scvi/external/methylvi/_methylanvi_module.py 76.63% 25 Missing ⚠️
src/scvi/external/methylvi/_methylvi_model.py 76.47% 16 Missing ⚠️
src/scvi/external/methylvi/_base_components.py 86.66% 14 Missing ⚠️
src/scvi/model/base/_training_mixin.py 86.56% 9 Missing ⚠️
src/scvi/external/methylvi/_methylanvi_model.py 89.06% 7 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff            @@
##             main    #3066    +/-   ##
========================================
  Coverage   82.95%   82.96%            
========================================
  Files         181      183     +2     
  Lines       15433    15692   +259     
========================================
+ Hits        12803    13019   +216     
- Misses       2630     2673    +43     
Files with missing lines Coverage Δ
src/scvi/data/fields/__init__.py 100.00% <100.00%> (ø)
src/scvi/data/fields/_scanvi.py 100.00% <100.00%> (ø)
src/scvi/dataloaders/_data_splitting.py 95.47% <ø> (ø)
src/scvi/dataloaders/_semi_dataloader.py 92.30% <ø> (ø)
src/scvi/external/__init__.py 100.00% <100.00%> (ø)
src/scvi/external/methylvi/__init__.py 100.00% <100.00%> (ø)
src/scvi/external/methylvi/_methylvi_module.py 80.00% <100.00%> (ø)
src/scvi/external/methylvi/_utils.py 85.18% <100.00%> (ø)
src/scvi/model/base/__init__.py 100.00% <100.00%> (ø)
src/scvi/external/methylvi/_methylanvi_model.py 89.06% <89.06%> (ø)
... and 4 more

@ethanweinberger ethanweinberger marked this pull request as ready for review December 9, 2024 19:34
@ethanweinberger ethanweinberger changed the title MethylANVI Model feat(external): implement METHLYANVI for scBS-seq data Dec 9, 2024
@ethanweinberger
Copy link
Contributor Author

Hi @canergen @ori-kron-wis

I'm removing the draft label here for now to get your feedback before proceeding further on this PR. In short, the goal of this PR is to add an implementation of the MethylANVI (MethylVI + scANVI) model from the MethylVI manuscript. Beyond just the code necessary for MethylANVI, I've also created some new mixin's to capture shared functions between models and avoid too much code duplication; if it's easier for you I'm happy to move these to separate PRs. I provide more details below:

  • MethylANVI model and module classes can be found in scvi/external/methylvi. To avoid duplicating shared functions between MethylVI/MethylANVI (e.g. get_normalized_methylation) I added a BSSeqMixin class in external/methylvi/_base_components.py. The BSSeqMixin essentially has the same role as RNASeqMixin for methylation.
  • To avoid duplicating functions that are identical for scANVI and MethylANVI, I also created a SemisupervisedTrainingMixin class in scvi/model/base/_training_mixin.py. This mixin currently provides implementations of the _set_indices_and_labels and train functions from scANVI. We could potentially abstract away more functions here (e.g. predict?), but I wanted to get your thoughts here before proceeding because this code touches things outside of external.
  • Minor: I added a MuData wrapper for the LabelsWithUnlabeledObsField for cell type labels.

) -> (np.ndarray | pd.DataFrame) | dict[str, np.ndarray | pd.DataFrame]:
r"""Convenience function to obtain normalized methylation values for a single context.

Only applicable to MuData models.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this limitation? It's anyhow only accessible with MuData?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough. I've removed this comment in the latest version.

r"""Convenience function to obtain normalized methylation values for a single context.

Only applicable to MuData models.
use_posterior_mean: bool = True,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addition to scANVI? Makes sense.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean here? The use_posterior_mean parameter is already present in scANVI.

batch_index=batch,
use_posterior_mean=use_posterior_mean,
)
if self.module.classifier.logits:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need it here? This was for backward compatibility in scANVI.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be always legit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

y_pred = torch.cat(y_pred).numpy()
if not soft:
predictions = []
for p in y_pred:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do list comprehension?

@@ -289,3 +281,348 @@ def sample(
exprs[context] = dist.sample().cpu()

return exprs


class METHYLANVAE(METHYLVAE, BSSeqModuleMixin):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put it in two files?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! The two modules can now be found in _methylvi_module.py and _methylanvi_module.py. For consistency I also split the two models into separate files (_methylvi_model.py and _methylanvi_model.py).



class METHYLANVAE(METHYLVAE, BSSeqModuleMixin):
"""Single-cell annotation using variational inference.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Methyl should be in here. Currently it's the acronym for scANVI

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the catch! Fixed.

w_y[:, group_index] *= w_g[:, [i]]
else:
w_y = self.classifier(z)
return w_y
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we inherit the classifier from scANVI?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we would reuse the function from scANVI. Per @ori-kron-wis's comments I believe for now we should leave this as-is and clean it up in a subsequent PR with the SemisupervisedMixin.

@canergen
Copy link
Member

I like the idea of a SemisupervisedMixin class especially as we'll add more semisupervised models soon'ish. Predict should best case be part of it. However, the input to the classifier is slightly different for both scANVI and methylANVI. I'm also thinking about whether we should also abstract the module code for semisupervised models.
@ori-kron-wis can you also review this code? And provide your ideas on it?

@ori-kron-wis
Copy link
Collaborator

ori-kron-wis commented Dec 24, 2024

Im in favor of reducing code duplications and having BSSeq & SemisupervisedTraining mixins.

I would suggest, as you both already pointed out, that as we expect more models, that are currently under development with "current" scanvi code, to use the SemisupervisedTraining mixin, I would create a new PR just for the scanvi changes here and concentrate only on Methylanvi in this PR, so our future integration will be easier.
The scanvi PR can be checked out from this branch.

It will have some code duplications for now until all other models will move also to SemisupervisedTrainingmixin (and probably expand it beyond methylvi).

I also validated the scnavi changes here, and it looks the same as before.

@ethanweinberger
Copy link
Contributor Author

Hi @canergen @ori-kron-wis. Happy new year! I just finished modifying this PR to address your comments (including reverting the previous changes to scANVI). Tests are currently failing, but the failures appear unrelated to this PR (the tests are related to general data loading functions).

@canergen per your suggestion I added a predict function to the SemisupervisedMixin class. I tried to make the function flexible to handle different numbers of inputs without requiring too many changes in other classes (see my comments). Let me know what you think! I'd also be happy to add a corresponding semisupervised module mixin to this PR, or I can open another PR after to avoid putting too much in one PR.

y_pred = []
for _, tensors in enumerate(scdl):
inference_inputs = self.module._get_inference_input(tensors) # (n_obs, n_vars)
data_inputs = {key: inference_inputs[key] for key in self.module.data_input_keys}
Copy link
Contributor Author

@ethanweinberger ethanweinberger Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@canergen This line is the main change to allow the predict function to be reused across models with different numbers of "data inputs" (e.g. mc + cov for BS-seq vs just RNA counts for RNA-seq).

It comes at the cost of requiring that a new field (data_input_keys) be specified in the module class, but this would enable more code re-use for semisupervised models.

@ori-kron-wis
Copy link
Collaborator

ori-kron-wis commented Jan 6, 2025

@ethanweinberger thanks!
The fail tests are a result of scipy version update integration issue with anndata. I noticed them.
scverse/anndata#1811. Tests will fail until they will release their fix.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants