Skip to content

Commit

Permalink
Use cellxgene download in HubModel when needed (#1884)
Browse files Browse the repository at this point in the history
* wip

* fix

* address pr feedback

* address feedback
  • Loading branch information
watiss authored Feb 1, 2023
1 parent 744ba4c commit 6cb0709
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 9 deletions.
1 change: 1 addition & 0 deletions scvi/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,5 @@
"organize_multiome_anndatas",
"pbmc_seurat_v4_cite_seq",
"add_dna_sequence",
"cellxgene",
]
8 changes: 5 additions & 3 deletions scvi/data/_built_in_data/_cellxgene.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Optional, Union

import requests
from anndata import AnnData, read_h5ad
Expand All @@ -14,7 +14,8 @@ def _load_cellxgene_dataset(
collection_id: Optional[str] = None,
filename: Optional[str] = None,
save_path: str = "data/",
) -> AnnData:
return_path: bool = False,
) -> Union[AnnData, str]:
"""
Loads a file from `cellxgene <https://cellxgene.cziscience.com/>`_ portal.
Expand Down Expand Up @@ -61,6 +62,7 @@ def _load_cellxgene_dataset(
filename = "local.h5ad"
_download(presigned_url, save_path, filename)
file_path = os.path.join(save_path, filename)
if return_path:
return file_path
adata = read_h5ad(file_path)

return adata
6 changes: 4 additions & 2 deletions scvi/data/_datasets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, Union

import anndata

Expand Down Expand Up @@ -123,7 +123,8 @@ def cellxgene(
url: str,
filename: Optional[str] = None,
save_path: str = "data/",
) -> anndata.AnnData:
return_path: bool = False,
) -> Union[anndata.AnnData, str]:
"""
Loads a file from `cellxgene <https://cellxgene.cziscience.com/>`_ portal.
Expand All @@ -144,6 +145,7 @@ def cellxgene(
url=url,
filename=filename,
save_path=save_path,
return_path=return_path,
)


Expand Down
8 changes: 6 additions & 2 deletions scvi/hub/hub_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ class HubMetadata:
anndata_version
The version of anndata used during model training.
training_data_url
Link to the training data used to train the model, if it is too large to be uploaded to the hub.
Link to the training data used to train the model, if it is too large to be uploaded to the hub. This can be
a cellxgene explorer session url. However, it cannot be a self-hosted session -- it must be from the cellxgene
portal (https://cellxgene.cziscience.com/).
model_parent_module
The parent module of the model class. Defaults to `scvi.model`. Change this if you are using a model
class that is not in the `scvi.model` module, for example, if you are using a model class from a custom module.
Expand Down Expand Up @@ -102,7 +104,9 @@ class HubModelCardHelper:
data_is_minified
Whether the training data uploaded with the model has been minified.
training_data_url
Link to the training data used to train the model, if it is too large to be uploaded to the hub.
Link to the training data used to train the model, if it is too large to be uploaded to the hub. This can be
a cellxgene explorer session url. However, it cannot be a self-hosted session -- it must be from the cellxgene
portal (https://cellxgene.cziscience.com/).
training_code_url
Link to the code used to train the model.
model_parent_module
Expand Down
17 changes: 15 additions & 2 deletions scvi/hub/hub_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from huggingface_hub import HfApi, ModelCard, create_repo, snapshot_download
from rich.markdown import Markdown

from scvi.data import cellxgene
from scvi.data._download import _download
from scvi.hub.hub_metadata import HubMetadata, HubModelCardHelper
from scvi.model.base import BaseModelClass
Expand Down Expand Up @@ -297,15 +298,27 @@ def read_adata(self):
logger.info("No data found on disk. Skipping...")

def read_large_training_adata(self):
"""Downloads the large training adata, if it exists, then load it into memory. Otherwise, this is a no-op."""
"""
Downloads the large training adata, if it exists, then loads it into memory. Otherwise, this is a no-op.
Notes
-----
The large training data url can be a cellxgene explorer session url. However, it cannot be a self-hosted
session. In other words, it must be from the cellxgene portal (https://cellxgene.cziscience.com/).
"""
training_data_url = self.metadata.training_data_url
if training_data_url is not None:
logger.info(
f"Downloading large training dataset from this url:\n{training_data_url}..."
)
dn = Path(self._large_training_adata_path).parent.as_posix()
fn = Path(self._large_training_adata_path).name
_download(training_data_url, dn, fn)
url_parts = training_data_url.split("/")
url_last_part = url_parts[-2] if url_parts[-1] == "" else url_parts[-1]
if url_last_part.endswith(".cxg"):
_ = cellxgene(training_data_url, fn, dn, return_path=True)
else:
_download(training_data_url, dn, fn)
logger.info("Reading large training data...")
self._large_training_adata = anndata.read_h5ad(
self._large_training_adata_path
Expand Down

0 comments on commit 6cb0709

Please sign in to comment.