From f23397aed9118d4b15f878cf1ca501937e2a0de4 Mon Sep 17 00:00:00 2001 From: Nathan LeRoy Date: Thu, 12 Dec 2024 13:03:57 -0500 Subject: [PATCH 1/2] fix tokenizer creation issue with ScEmbed (#198) --- geniml/region2vec/main.py | 7 ++++--- geniml/region2vec/utils.py | 7 ++----- geniml/scembed/main.py | 7 ++++--- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/geniml/region2vec/main.py b/geniml/region2vec/main.py index 5d210ac..7de3829 100644 --- a/geniml/region2vec/main.py +++ b/geniml/region2vec/main.py @@ -140,11 +140,12 @@ def _load_local_model(self, model_path: str, vocab_path: str, config_path: str): :param str model_path: Path to the model checkpoint. :param str vocab_path: Path to the vocabulary file. """ - _model, tokenizer, config = load_local_region2vec_model( - model_path, vocab_path, config_path - ) + _model, config = load_local_region2vec_model(model_path, config_path) + tokenizer = TreeTokenizer(vocab_path) + self._model = _model self.tokenizer = tokenizer + self.trained = True if POOLING_METHOD_KEY in config: self.pooling_method = config[POOLING_METHOD_KEY] diff --git a/geniml/region2vec/utils.py b/geniml/region2vec/utils.py index cbabbe6..01e23be 100644 --- a/geniml/region2vec/utils.py +++ b/geniml/region2vec/utils.py @@ -452,9 +452,8 @@ def export_region2vec_model( def load_local_region2vec_model( model_path: str, - vocab_path: str, config_path: str, -) -> Tuple[Region2Vec, TreeTokenizer, dict]: +) -> Tuple[Region2Vec, dict]: """ Load a region2vec model from a local directory @@ -462,8 +461,6 @@ def load_local_region2vec_model( :param str config_path: The path to the model config file :param str vocab_path: The path to the model vocabulary file """ - # init the tokenizer - only one option for now - tokenizer = TreeTokenizer(vocab_path) # load the model state dict (weights) params = torch.load(model_path) @@ -491,7 +488,7 @@ def load_local_region2vec_model( model.load_state_dict(params) - return model, tokenizer, config + return model, config class Region2VecDataset: diff --git a/geniml/scembed/main.py b/geniml/scembed/main.py index c456f3a..e9fed16 100755 --- a/geniml/scembed/main.py +++ b/geniml/scembed/main.py @@ -131,11 +131,12 @@ def _load_local_model(self, model_path: str, vocab_path: str, config_path: str): :param str model_path: Path to the model checkpoint. :param str vocab_path: Path to the vocabulary file. """ - _model, tokenizer, config = load_local_region2vec_model( - model_path, vocab_path, config_path - ) + _model, config = load_local_region2vec_model(model_path, config_path) + tokenizer = AnnDataTokenizer(vocab_path, verbose=True) + self._model = _model self.tokenizer = tokenizer + if POOLING_METHOD_KEY in config: self.pooling_method = config[POOLING_METHOD_KEY] From ded5b72789f708e332e85efddfdb0f32a65ea49f Mon Sep 17 00:00:00 2001 From: Nathan LeRoy Date: Thu, 12 Dec 2024 13:06:13 -0500 Subject: [PATCH 2/2] add data to gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index e766269..0f68ae4 100644 --- a/.gitignore +++ b/.gitignore @@ -197,4 +197,5 @@ qdrant_storage/ local_cache -lightning_logs \ No newline at end of file +lightning_logs +data/ \ No newline at end of file