From b87c28e9f361eaf0ba1abff50e3ac76d60611fbe Mon Sep 17 00:00:00 2001 From: Emmanuel Noutahi Date: Mon, 18 Sep 2023 14:55:54 +0100 Subject: [PATCH 1/2] wip --- molfeat/trans/pretrained/hf_transformers.py | 22 ++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/molfeat/trans/pretrained/hf_transformers.py b/molfeat/trans/pretrained/hf_transformers.py index f254144..f2c5a46 100644 --- a/molfeat/trans/pretrained/hf_transformers.py +++ b/molfeat/trans/pretrained/hf_transformers.py @@ -13,13 +13,7 @@ from dataclasses import dataclass from loguru import logger -from transformers import EncoderDecoderModel -from transformers import AutoTokenizer -from transformers import AutoModel -from transformers import AutoConfig -from transformers import MODEL_MAPPING -from transformers import PreTrainedModel -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from molfeat.utils import requires from molfeat.trans.pretrained.base import PretrainedMolTransformer from molfeat.store.loader import PretrainedStoreModel from molfeat.store import ModelStore @@ -27,6 +21,15 @@ from molfeat.utils.converters import SmilesConverter from molfeat.utils.pooler import get_default_hgf_pooler +if requires.check("transformers"): + from transformers import EncoderDecoderModel + from transformers import AutoTokenizer + from transformers import AutoModel + from transformers import AutoConfig + from transformers import MODEL_MAPPING + from transformers import PreTrainedModel + from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + @dataclass class HFExperiment: @@ -260,6 +263,11 @@ def __init__( random_seed: random seed to use for reproducibility whenever a DNN pooler is used (e.g bert/roberta) """ + if not requires.check("transformers"): + raise ValueError( + "Cannot find transformers and/or tokenizers. It's required for this featurizer !" + ) + super().__init__( dtype=dtype, device=device, From ae789209bda8699e11a0a761f6f638e7c31f09bb Mon Sep 17 00:00:00 2001 From: Emmanuel Noutahi Date: Sat, 23 Sep 2023 14:30:41 -0400 Subject: [PATCH 2/2] make transformers optional --- tests/test_pretrained.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_pretrained.py b/tests/test_pretrained.py index a233576..c22d82b 100644 --- a/tests/test_pretrained.py +++ b/tests/test_pretrained.py @@ -124,7 +124,9 @@ def test_dgl_pretrained_cache(self): # add buffers self.assertLessEqual(cached_run, ori_run + time_buffer) - +@pytest.mark.xfail( + not requires.check("transformers"), reason="3rd party module transformers is missing" +) class TestHGFTransformer(ut.TestCase): r"""Test cases for FingerprintsTransformer""" smiles = [