-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #51 from PathwayCommons/major-refactor
Major refactor
- Loading branch information
Showing
7 changed files
with
166 additions
and
130 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
__version__ = "0.1.0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from enum import Enum | ||
from typing import Tuple, List, Optional | ||
|
||
import torch | ||
import typer | ||
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer | ||
|
||
|
||
class Emoji(Enum): | ||
# Emoji's used in typer.secho calls | ||
# See: https://github.com/carpedm20/emoji/blob/master/emoji/unicode_codes.py | ||
SUCCESS = "\U00002705" | ||
WARNING = "\U000026A0" | ||
FAST = "\U0001F3C3" | ||
|
||
|
||
def get_device(cuda_device: int = -1) -> torch.device: | ||
"""Return a `torch.cuda` device if `torch.cuda.is_available()` and `cuda_device>=0`. | ||
Otherwise returns a `torch.cpu` device. | ||
""" | ||
if cuda_device != -1 and torch.cuda.is_available(): | ||
device = torch.device("cuda") | ||
typer.secho( | ||
f"{Emoji.FAST.value} Using CUDA device {torch.cuda.get_device_name()} with index" | ||
f" {torch.cuda.current_device()}.", | ||
fg=typer.colors.GREEN, | ||
bold=True, | ||
) | ||
else: | ||
device = torch.device("cpu") | ||
typer.secho( | ||
f"{Emoji.WARNING.value} Using CPU. Note that this will be many times slower than a GPU.", | ||
fg=typer.colors.YELLOW, | ||
bold=True, | ||
) | ||
return device | ||
|
||
|
||
def setup_model_and_tokenizer( | ||
pretrained_model_name_or_path: str, cuda_device: int = -1 | ||
) -> Tuple[PreTrainedTokenizer, PreTrainedModel]: | ||
"""Given a HuggingFace Transformer `pretrained_model_name_or_path`, return the corresponding | ||
model and tokenizer. Optionally, places the model on `cuda_device`, if available. | ||
""" | ||
device = get_device(cuda_device) | ||
# Load the Transformers tokenizer | ||
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) | ||
typer.secho( | ||
( | ||
f'{Emoji.SUCCESS.value} Tokenizer "{pretrained_model_name_or_path}" from Transformers' | ||
" loaded successfully." | ||
), | ||
fg=typer.colors.GREEN, | ||
bold=True, | ||
) | ||
# Load the Transformers model | ||
model = AutoModel.from_pretrained(pretrained_model_name_or_path) | ||
model = model.to(device) | ||
model.eval() | ||
typer.secho( | ||
( | ||
f'{Emoji.SUCCESS.value} Model "{pretrained_model_name_or_path}" from Transformers' | ||
" loaded successfully." | ||
), | ||
fg=typer.colors.GREEN, | ||
bold=True, | ||
) | ||
|
||
return tokenizer, model | ||
|
||
|
||
@torch.no_grad() | ||
def encode_with_transformer( | ||
text: List[str], | ||
tokenizer: PreTrainedTokenizer, | ||
model: PreTrainedModel, | ||
max_length: Optional[int] = None, | ||
mean_pool: bool = True, | ||
) -> torch.Tensor: | ||
|
||
inputs = tokenizer( | ||
text, padding=True, truncation=True, max_length=max_length, return_tensors="pt" | ||
) | ||
for name, tensor in inputs.items(): | ||
inputs[name] = tensor.to(model.device) | ||
attention_mask = inputs["attention_mask"] | ||
output = model(**inputs).last_hidden_state | ||
|
||
if mean_pool: | ||
embedding = torch.sum(output * attention_mask.unsqueeze(-1), dim=1) / torch.clamp( | ||
torch.sum(attention_mask, dim=1, keepdims=True), min=1e-9 | ||
) | ||
else: | ||
embedding = output[:, 0, :] | ||
|
||
return embedding |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from typing import Callable, List, Optional | ||
|
||
import torch | ||
from pydantic import BaseModel, validator | ||
from transformers import PreTrainedModel, PreTrainedTokenizer | ||
|
||
from semantic_search.ncbi import uids_to_docs | ||
|
||
UID = str | ||
|
||
# See: https://fastapi.tiangolo.com/tutorial/body/ for more details on creating a Request Body. | ||
|
||
|
||
class Model(BaseModel): | ||
tokenizer: PreTrainedModel = None | ||
model: PreTrainedTokenizer = None | ||
similarity: Callable[..., torch.Tensor] = None # type: ignore | ||
|
||
class Config: | ||
arbitrary_types_allowed = True | ||
|
||
|
||
class Document(BaseModel): | ||
uid: UID | ||
text: str | ||
|
||
|
||
class Query(BaseModel): | ||
query: Document | ||
documents: List[Document] = [] | ||
top_k: Optional[int] = None | ||
|
||
@validator("query", "documents", pre=True) | ||
def normalize_document(cls, v, field): | ||
if field.name == "query": | ||
v = [v] | ||
|
||
normalized_docs = [] | ||
for doc in v: | ||
if isinstance(doc, UID): | ||
normalized_docs.append(Document(**uids_to_docs([doc])[0])) | ||
else: | ||
normalized_docs.append(doc) | ||
return normalized_docs[0] if field.name == "query" else normalized_docs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,7 @@ | |
version="0.1.0", | ||
author="John Giorgi", | ||
author_email="[email protected]", | ||
description=("A simple semantic search engine powered by HuggingFace's Transformers library."), | ||
description=("A simple semantic search engine for scientific papers."), | ||
long_description=long_description, | ||
long_description_content_type="text/markdown", | ||
url="https://github.com/PathwayCommons/semantic-search", | ||
|
@@ -21,18 +21,18 @@ | |
"License :: OSI Approved :: Apache Software License", | ||
"Operating System :: OS Independent", | ||
"Programming Language :: Python :: 3", | ||
"Programming Language :: Python :: 3.6", | ||
"Programming Language :: Python :: 3.7", | ||
"Programming Language :: Python :: 3.8", | ||
"Programming Language :: Python :: 3.9", | ||
"Topic :: Scientific/Engineering :: Artificial Intelligence", | ||
"Typing :: Typed", | ||
], | ||
python_requires=">=3.7.0", | ||
install_requires=[ | ||
"fastapi>=0.62.0", | ||
"uvicorn>=0.13.0", | ||
"torch>=1.7.0", | ||
"transformers>=4.0.1,<4.4.0", | ||
"fastapi>=0.63.0", | ||
"uvicorn>=0.13.4", | ||
"torch>=1.7.1", | ||
"transformers>=4.3.3", | ||
"typer>=0.3.2", | ||
"python-dotenv>=0.15.0", | ||
"xmltodict>=0.12.0", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from semantic_search import __version__ | ||
|
||
|
||
def test_version(): | ||
assert __version__ == "0.1.0" |