-
Notifications
You must be signed in to change notification settings - Fork 366
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Custom dataloader registry support #2932
base: main
Are you sure you want to change the base?
Conversation
…try' into ori-2907-custom-dataloader-registry
…module / registry big change
for more information, see https://pre-commit.ci
…un, we will later adjust this file
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2932 +/- ##
==========================================
- Coverage 84.81% 83.90% -0.92%
==========================================
Files 173 173
Lines 14793 14920 +127
==========================================
- Hits 12547 12518 -29
- Misses 2246 2402 +156
|
and fix the test for custom dataloaders
src/scvi/model/_scvi.py
Outdated
@setup_anndata_dsp.dedent | ||
def setup_datamodule( | ||
cls, | ||
datamodule, # TODO: what to put here? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should be pytorch.DataLoader right? Martin has done typing for it in the current code.
src/scvi/model/_scvi.py
Outdated
"state_registry": { | ||
"n_obs": datamodule.n_obs, | ||
"n_vars": datamodule.n_vars, | ||
"column_names": [str(i) for i in column_names], # TODO: from adata (czi)? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not following?
_validate_var_names(adata[modality], var_names[modality]) | ||
logger.debug("Subsetting query vars to reference vars.") | ||
adata._inplace_subset_var(var_names) | ||
_validate_var_names(adata, var_names) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need to verify for dataloaders that the gene names are matching.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And the order.
logger.debug("Subsetting query vars to reference vars.") | ||
adata._inplace_subset_var(var_names) | ||
_validate_var_names(adata, var_names) | ||
registry = attr_dict.pop("registry_") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This check and the ones below are independent of datamodule or AnnData, right? Remove the indent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure why every code here is displayed as modified?
@@ -202,7 +215,7 @@ def prepare_query_anndata( | |||
Query adata ready to use in `load_query_data` unless `return_reference_var_names` | |||
in which case a pd.Index of reference var names is returned. | |||
""" | |||
_, var_names, _ = _get_loaded_data(reference_model, device="cpu") | |||
_, var_names, _ = _get_loaded_data(reference_model, device="cpu", adata=adata) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this work with a dataloader?
@@ -350,15 +363,15 @@ def requires_grad(key): | |||
par.requires_grad = False | |||
|
|||
|
|||
def _get_loaded_data(reference_model, device=None): | |||
def _get_loaded_data(reference_model, device=None, adata=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need adata here?
self.registry_ = registry | ||
self.summary_stats = _get_summary_stats_from_registry(registry) | ||
elif self.__class__.__name__ == "GIMVI": | ||
# note some models do accept empty registry/adata (e.g: gimvi) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not following this one. What is the exception with GIMVI?
src/scvi/model/base/_base_model.py
Outdated
else: | ||
return self._adata_manager.get_from_registry(registry_key) | ||
|
||
# def get_from_registry(self, registry_key: str) -> np.ndarray | pd.DataFrame: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's this?
else: | ||
# Case where correct AnnDataManager is found, replay registration as necessary. | ||
adata_manager.validate() | ||
|
||
return adata | ||
|
||
def transfer_fields(self, adata: AnnOrMuData, **kwargs) -> AnnData: | ||
"""Transfer fields from a model to an AnnData object.""" | ||
if self.adata: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where do we need transfer_fields? can we make it work with datamodule?
@@ -627,8 +711,7 @@ def save( | |||
|
|||
# save the model state dict and the trainer state dict only | |||
model_state_dict = self.module.state_dict() | |||
|
|||
var_names = _get_var_names(self.adata, legacy_mudata_format=legacy_mudata_format) | |||
var_names = self.get_var_names(legacy_mudata_format=legacy_mudata_format) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need two get_var_names function?
"Saved model does not contain original setup inputs. " | ||
"Cannot load the original setup." | ||
) | ||
_validate_var_names(adata, var_names) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should be validated also for a dataloader.
|
||
def get_state_registry(self, registry_key: str) -> attrdict: | ||
"""Returns the state registry for the AnnDataField registered with this instance.""" | ||
return attrdict(self.registry_[_FIELD_REGISTRIES_KEY][registry_key][_STATE_REGISTRY_KEY]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this work with dataloader. Documentation should be updated then.
@@ -133,7 +133,10 @@ def _initialize_model(cls, adata, attr_dict): | |||
if "pretrained_model" in non_kwargs.keys(): | |||
non_kwargs.pop("pretrained_model") | |||
|
|||
model = cls(adata, **non_kwargs, **kwargs) | |||
if not adata: | |||
adata = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is adata false here? Do we need a default value for registry?
if max_epochs is None: | ||
if datamodule is None: | ||
if self.adata is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should take here n_obs from summary stats to make it compatible with a dataloader.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See below we don't need the if statement.
experiment_name = "mus_musculus" | ||
obs_value_filter = 'is_primary_data == True and tissue_general in ["kidney"] and nnz >= 3000' | ||
|
||
# This is under comments just to save time (selecting highly varkable genes): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove this block, we don't need it.
dataloader_kwargs={"num_workers": 0, "persistent_workers": False}, | ||
) | ||
|
||
# table of genes should be filtered by soma_joinid - but we should keep the encoded indexes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to.
model = scvi.model.SCVI(adata_orig, n_latent=10) | ||
model.train(max_epochs=1) | ||
|
||
# TODO: do we need to apply those functions to any census model as is? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not getting it.
_ = model.get_reconstruction_error(dataloader=dataloader) | ||
_ = model.get_latent_representation(dataloader=dataloader) | ||
|
||
scvi.model.SCVI.prepare_query_anndata(adata_orig, reference_model=model) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test this with a second model trained using dataloader
n_layers = 1 | ||
n_latent = 50 | ||
|
||
scvi.model._scvi.SCVI.setup_datamodule(datamodule) # takes time |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like this part. It's nice.
|
||
pprint(datamodule.registry) | ||
|
||
batch_size = 1024 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does batch size have an effect. I thought it's defined by the datamodule?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes it is redundant
# _ = model_census2.get_latent_representation() | ||
|
||
# takes time | ||
adata = cellxgene_census.get_anndata( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just download 10 cells - see below with obs_value_filter.
var_coords=hv_idx, | ||
) | ||
|
||
# TODO: do we need to put inside (or is it alrady pre-made) - perhaps need to tell CZI |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assume we need to make it.
adata.obs["batch"] = adata.obs[batch_keys].agg("".join, axis=1).astype("category") | ||
|
||
scvi.model.SCVI.prepare_query_anndata(adata, save_path) | ||
scvi.model.SCVI.load_query_data(registry=datamodule.registry, reference_model=save_path) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should have more tests that actually fail - using different genes without prepare_query_anndata and different batch categories. Assert that it fails.
|
||
scvi.model.SCVI.prepare_query_anndata(adata, model_census2) | ||
|
||
scvi.model.SCVI.setup_anndata(adata, batch_key="batch") # needed? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
checking that an AnnData model can be trained using datamodule. Do we really want it?
|
||
user_attributes_model_census3 = model_census3._get_user_attributes() | ||
pprint(user_attributes_model_census3) | ||
_ = model_census3.get_elbo() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
uses AnnData for inference?
scvi.model.SCVI.prepare_query_anndata(adata, model_census3) | ||
scvi.model.SCVI.load_query_data(adata, model_census3) | ||
|
||
datamodule_inference = CensusSCVIDataModule( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check here that using different genes and different batches fails. You can take much fewer cells here, like 1000.
# Create a dataloder of a CZI module | ||
datapipe = datamodule_inference.datapipe | ||
dataloader = experiment_dataloader(datapipe, num_workers=0, persistent_workers=False) | ||
mapped_dataloader = ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's this?
|
||
model = SCVI(adata, n_latent=n_latent) | ||
model.train(max_epochs=1) | ||
dataloader = model._make_data_loader(adata) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does model._make_data_loader exist for all models? We should then add the test to the other models as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the dataloader sufficient to also setup the model and does setup_datamodule work for it?
No description provided.