Skip to content

Commit

Permalink
Merge pull request #199 from databio/bug_fix_r2v_scembed
Browse files Browse the repository at this point in the history
fix tokenizer creation issue with ScEmbed (#198)
  • Loading branch information
nsheff authored Jan 3, 2025
2 parents 11ce521 + ded5b72 commit 94a5f74
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -197,4 +197,5 @@ qdrant_storage/

local_cache

lightning_logs
lightning_logs
data/
7 changes: 4 additions & 3 deletions geniml/region2vec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,12 @@ def _load_local_model(self, model_path: str, vocab_path: str, config_path: str):
:param str model_path: Path to the model checkpoint.
:param str vocab_path: Path to the vocabulary file.
"""
_model, tokenizer, config = load_local_region2vec_model(
model_path, vocab_path, config_path
)
_model, config = load_local_region2vec_model(model_path, config_path)
tokenizer = TreeTokenizer(vocab_path)

self._model = _model
self.tokenizer = tokenizer

self.trained = True
if POOLING_METHOD_KEY in config:
self.pooling_method = config[POOLING_METHOD_KEY]
Expand Down
7 changes: 2 additions & 5 deletions geniml/region2vec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,18 +452,15 @@ def export_region2vec_model(

def load_local_region2vec_model(
model_path: str,
vocab_path: str,
config_path: str,
) -> Tuple[Region2Vec, TreeTokenizer, dict]:
) -> Tuple[Region2Vec, dict]:
"""
Load a region2vec model from a local directory
:param str model_path: The path to the model checkpoint file
:param str config_path: The path to the model config file
:param str vocab_path: The path to the model vocabulary file
"""
# init the tokenizer - only one option for now
tokenizer = TreeTokenizer(vocab_path)

# load the model state dict (weights)
params = torch.load(model_path)
Expand Down Expand Up @@ -491,7 +488,7 @@ def load_local_region2vec_model(

model.load_state_dict(params)

return model, tokenizer, config
return model, config


class Region2VecDataset:
Expand Down
7 changes: 4 additions & 3 deletions geniml/scembed/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,12 @@ def _load_local_model(self, model_path: str, vocab_path: str, config_path: str):
:param str model_path: Path to the model checkpoint.
:param str vocab_path: Path to the vocabulary file.
"""
_model, tokenizer, config = load_local_region2vec_model(
model_path, vocab_path, config_path
)
_model, config = load_local_region2vec_model(model_path, config_path)
tokenizer = AnnDataTokenizer(vocab_path, verbose=True)

self._model = _model
self.tokenizer = tokenizer

if POOLING_METHOD_KEY in config:
self.pooling_method = config[POOLING_METHOD_KEY]

Expand Down

0 comments on commit 94a5f74

Please sign in to comment.