diff --git a/requirements.txt b/requirements.txt index 3014054..0e76558 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,6 @@ wbtools==3.0.10 scikit-learn sent2vec-prebuilt nltk~=3.6.3 -transformers~=4.35.2 \ No newline at end of file +transformers~=4.35.2 +datasets +optimum \ No newline at end of file diff --git a/src/backend/sentence_classification_api/sentence_classification_api.py b/src/backend/sentence_classification_api/sentence_classification_api.py index 3979c51..0222cb1 100755 --- a/src/backend/sentence_classification_api/sentence_classification_api.py +++ b/src/backend/sentence_classification_api/sentence_classification_api.py @@ -6,6 +6,7 @@ import falcon from falcon import HTTPStatus from transformers import AutoModelForSequenceClassification, TextClassificationPipeline, AutoTokenizer +from datasets import Dataset logger = logging.getLogger(__name__) @@ -33,17 +34,18 @@ def __init__(self): @staticmethod def load_tokenizers(tokenizers_path): logger.info("Loading tokenizers...") - sentence_tokenizer_all_info_expression = AutoTokenizer.from_pretrained(f"{tokenizers_path}/all_info_expression") + sentence_tokenizer_all_info_expression = AutoTokenizer.from_pretrained( + f"{tokenizers_path}/model_biobert_expression/fully_curatable", use_fast=True) sentence_tokenizer_curatable_expression = AutoTokenizer.from_pretrained( - f"{tokenizers_path}/curatable_expression.joblib") + f"{tokenizers_path}/model_biobert_expression/partially_curatable", use_fast=True) sentence_tokenizer_language_expression = AutoTokenizer.from_pretrained( - f"{tokenizers_path}/language_expression.joblib") + f"{tokenizers_path}/model_biobert_expression/language_related", use_fast=True) sentence_tokenizer_all_info_kinase = AutoTokenizer.from_pretrained( - f"{tokenizers_path}/all_info_kinase.joblib") + f"{tokenizers_path}/model_biobert_kinaseact/fully_curatable", use_fast=True) sentence_tokenizer_curatable_kinase = AutoTokenizer.from_pretrained( - 
f"{tokenizers_path}/curatable_kinase.joblib") + f"{tokenizers_path}/model_biobert_kinaseact/partially_curatable", use_fast=True) sentence_tokenizer_language_kinase = AutoTokenizer.from_pretrained( - f"{tokenizers_path}/language_kinase.joblib") + f"{tokenizers_path}/model_biobert_kinaseact/language_related", use_fast=True) logger.info("All sentence classifiers loaded") return { "expression": { @@ -61,12 +63,18 @@ def load_tokenizers(tokenizers_path): @staticmethod def load_sentence_classifiers(models_path): logger.info("Loading sentence classifiers...") - sentence_classifier_all_info_expression = AutoModelForSequenceClassification.from_pretrained(f"{models_path}/all_info_expression") - sentence_classifier_curatable_expression = AutoModelForSequenceClassification.from_pretrained(f"{models_path}/curatable_expression.joblib") - sentence_classifier_language_expression = AutoModelForSequenceClassification.from_pretrained(f"{models_path}/language_expression.joblib") - sentence_classifier_all_info_kinase = AutoModelForSequenceClassification.from_pretrained(f"{models_path}/all_info_kinase.joblib") - sentence_classifier_curatable_kinase = AutoModelForSequenceClassification.from_pretrained(f"{models_path}/curatable_kinase.joblib") - sentence_classifier_language_kinase = AutoModelForSequenceClassification.from_pretrained(f"{models_path}/language_kinase.joblib") + sentence_classifier_all_info_expression = AutoModelForSequenceClassification.from_pretrained( + f"{models_path}/model_biobert_expression/fully_curatable").to_bettertransformer() + sentence_classifier_curatable_expression = AutoModelForSequenceClassification.from_pretrained( + f"{models_path}/model_biobert_expression/partially_curatable").to_bettertransformer() + sentence_classifier_language_expression = AutoModelForSequenceClassification.from_pretrained( + f"{models_path}/model_biobert_expression/language_related").to_bettertransformer() + sentence_classifier_all_info_kinase = 
AutoModelForSequenceClassification.from_pretrained( + f"{models_path}/model_biobert_kinaseact/fully_curatable").to_bettertransformer() + sentence_classifier_curatable_kinase = AutoModelForSequenceClassification.from_pretrained( + f"{models_path}/model_biobert_kinaseact/partially_curatable").to_bettertransformer() + sentence_classifier_language_kinase = AutoModelForSequenceClassification.from_pretrained( + f"{models_path}/model_biobert_kinaseact/language_related").to_bettertransformer() logger.info("All sentence classifiers loaded") return { "expression": { @@ -84,34 +92,49 @@ def load_sentence_classifiers(models_path): def on_post(self, req, resp, req_type): if req_type != "classify_sentences" or "sentences" not in req.media: raise falcon.HTTPError(falcon.HTTP_BAD_REQUEST) - classes_all_info_expression = TextClassificationPipeline( - model=self.sentence_classifiers["expression"]["all_info"], - tokenizer=self.sentence_tokenizers["expression"]["all_info"])(req["media"]["sentences"]) - classes_curatable_expression = TextClassificationPipeline( - model=self.sentence_classifiers["expression"]["curatable"], - tokenizer=self.sentence_tokenizers["expression"]["curatable"])(req["media"]["sentences"]) - classes_language_expression = TextClassificationPipeline( - model=self.sentence_classifiers["expression"]["language"], - tokenizer=self.sentence_tokenizers["expression"]["language"])(req["media"]["sentences"]) - classes_all_info_kinase = TextClassificationPipeline( - model=self.sentence_classifiers["kinase"]["all_info"], - tokenizer=self.sentence_tokenizers["kinase"]["all_info"])(req["media"]["sentences"]) - classes_curatable_kinase = TextClassificationPipeline( - model=self.sentence_classifiers["kinase"]["curatable"], - tokenizer=self.sentence_tokenizers["kinase"]["curatable"])(req["media"]["sentences"]) - classes_language_kinase = TextClassificationPipeline( - model=self.sentence_classifiers["kinase"]["language"], - 
tokenizer=self.sentence_tokenizers["kinase"]["language"])(req["media"]["sentences"]) + logger.info("started sentences classification...") + dataset = Dataset.from_dict({"text": req.media["sentences"]}) + classes_all_info_expression = [int(classification["label"] == "LABEL_1") for classification in + TextClassificationPipeline( + model=self.sentence_classifiers["expression"]["all_info"], + tokenizer=self.sentence_tokenizers["expression"]["all_info"])( + dataset["text"], batch_size=32)] + classes_curatable_expression = [int(classification["label"] == "LABEL_1") for classification in + TextClassificationPipeline( + model=self.sentence_classifiers["expression"]["curatable"], + tokenizer=self.sentence_tokenizers["expression"]["curatable"])( + dataset["text"], batch_size=32)] + classes_language_expression = [int(classification["label"] == "LABEL_1") for classification in + TextClassificationPipeline( + model=self.sentence_classifiers["expression"]["language"], + tokenizer=self.sentence_tokenizers["expression"]["language"])( + dataset["text"], batch_size=32)] + classes_all_info_kinase = [int(classification["label"] == "LABEL_1") for classification in + TextClassificationPipeline( + model=self.sentence_classifiers["kinase"]["all_info"], + tokenizer=self.sentence_tokenizers["kinase"]["all_info"])( + dataset["text"], batch_size=32)] + classes_curatable_kinase = [int(classification["label"] == "LABEL_1") for classification in + TextClassificationPipeline( + model=self.sentence_classifiers["kinase"]["curatable"], + tokenizer=self.sentence_tokenizers["kinase"]["curatable"])( + dataset["text"], batch_size=32)] + classes_language_kinase = [int(classification["label"] == "LABEL_1") for classification in + TextClassificationPipeline( + model=self.sentence_classifiers["kinase"]["language"], + tokenizer=self.sentence_tokenizers["kinase"]["language"])( + dataset["text"], batch_size=32)] + logger.info("finished sentences classification.") classes = { "expression": { - "all_info": 
classes_all_info_expression.tolist(), - "curatable": classes_curatable_expression.tolist(), - "language": classes_language_expression.tolist() + "all_info": classes_all_info_expression, + "curatable": classes_curatable_expression, + "language": classes_language_expression }, "kinase": { - "all_info": classes_all_info_kinase.tolist(), - "curatable": classes_curatable_kinase.tolist(), - "language": classes_language_kinase.tolist() + "all_info": classes_all_info_kinase, + "curatable": classes_curatable_kinase, + "language": classes_language_kinase } } resp.body = f'{{"classes": {json.dumps(classes)}}}' diff --git a/src/frontend/curator_dashboard/src/pages/SentenceClassification.js b/src/frontend/curator_dashboard/src/pages/SentenceClassification.js index 5e55f18..c5fce8d 100644 --- a/src/frontend/curator_dashboard/src/pages/SentenceClassification.js +++ b/src/frontend/curator_dashboard/src/pages/SentenceClassification.js @@ -1,4 +1,4 @@ -import React, {useState} from 'react'; +import React, {useEffect, useState} from 'react'; import {Button, FormControl, Spinner, Tab, Table, Tabs} from "react-bootstrap"; import {withRouter} from "react-router-dom"; import queryString from "query-string"; @@ -17,15 +17,26 @@ const SentenceClassification = () => { const [resultType, setResultType] = useState(1); const [dataType, setDataType] = useState('expression'); const [isSpreadsheetLoading, setIsSpreadsheetLoading] = useState(false); + const [isQueryEnabled, setIsQueryEnabled] = useState(false); - let paperID = undefined; - let url = document.location.toString(); - if (url.match("\\?")) { - paperID = queryString.parse(document.location.search).paper_id - } - dispatch(setSelectedPaperID(paperID)); - const queryRes = useQuery('fulltext' + paperID, () => - axios.post(process.env.REACT_APP_API_DB_READ_ADMIN_ENDPOINT + "/converted_text", {paper_id: paperID})); + useEffect(() => { + let paperID; + let url = document.location.toString(); + if (url.match("\\?")) { + paperID = 
queryString.parse(document.location.search).paper_id +         } +         dispatch(setSelectedPaperID(paperID)); +         setIsQueryEnabled(true); +     }, [dispatch]); + +     const paperID = queryString.parse(document.location.search).paper_id; +     const queryRes = useQuery(['fulltext', paperID], () => +         axios.post(process.env.REACT_APP_API_DB_READ_ADMIN_ENDPOINT + "/converted_text", {paper_id: paperID}), +         { +             refetchOnWindowFocus: false, +             staleTime: 1000 * 60 * 5, +             enabled: isQueryEnabled +         });      return(