feat(example): add the bge embedding example
Showing 2 changed files with 192 additions and 0 deletions.
@@ -0,0 +1,86 @@
import subprocess
import time
import socket

from leptonai.client import Client, local, current  # noqa: F401


def is_port_open(host, port):
    """Check if a port is open on a given host."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(1)
        try:
            s.connect((host, port))
            return True
        except socket.error:
            return False


def wait_for_port(host, port, interval=5):
    """Wait for a port to be connectable."""
    while True:
        if is_port_open(host, port):
            print(f"Port {port} on {host} is now connectable!")
            break
        else:
            print(
                f"Port {port} on {host} is not ready yet. Retrying in"
                f" {interval} seconds..."
            )
            time.sleep(interval)


def main():
    # launches "python main.py" in a subprocess so we can use the client
    # to test it.
    print("Launching the photon in a subprocess on port 8080...")
    p = subprocess.Popen(
        ["python", "main.py"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
    )
    wait_for_port("localhost", 8080)

    # Note: this is not necessary if you are running the photon in the lepton
    # server. To run it in the server, you can do
    #   lep photon run -n bge -m main.py --resource-shape gpu.a10
    # and then instead of using local(), you can create the client as
    #   c = Client(current(), "bge")
    # where current() is a helper function to get the current workspace.

    c = Client(local())
    # c = Client(current(), "bge")
    print("\nThe client has the following endpoints:")
    print(c.paths())
    print("For the encode endpoint, the docstring is as follows:")
    print("***begin docstring***")
    print(c.encode.__doc__)
    print("***end docstring***")

    print("\n\nRunning the encode endpoint...")
    query = "The quick brown fox jumps over the lazy dog."
    ret = c.encode(sentences=query)
    print("The result is (truncated, showing first 5):")
    print(ret[:5])
    print(f"(the full result is a list of {len(ret)} floats)")

    print("\n\nRunning the rank endpoint...")
    sentences = [
        "the fox jumps over the dog",
        "the photon is a particle and a wave",
        "let the record show that the shipment has arrived",
        "the cat jumps on the fox",
    ]
    rank, score = c.rank(query=query, sentences=sentences)
    print("The rank and score are respectively:")
    print([(r, s) for r, s in zip(rank, score)])
    print(f"The query is: {query}")
    print("The sentences, ordered from closest to furthest, are:")
    print([sentences[i] for i in rank])

    print("Finished. Closing everything.")
    # Closes the subprocess
    p.terminate()


if __name__ == "__main__":
    main()
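To try the example end to end, only the test client above needs to be invoked, since it launches the photon itself and polls port 8080 before issuing requests. A minimal sketch, assuming the client is saved as client.py (its file name is not shown in this diff) next to the main.py listed below, and assuming the dependencies install under these package names:

# Install the photon framework and the embedding library (assumed package names).
pip install leptonai FlagEmbedding

# Run the test client: it starts "python main.py" as a subprocess, waits for
# port 8080 to accept connections, then exercises the encode and rank endpoints.
python client.py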
@@ -0,0 +1,106 @@
import os
from typing import List, Union, Tuple

from leptonai.photon import Photon, HTTPException


# Transcribed from https://github.com/FlagOpen/FlagEmbedding/tree/master#model-list
AVAILABLE_MODELS_AND_INSTRUCTIONS = {
    "BAAI/llm-embedder": None,
    "BAAI/bge-reranker-large": None,
    "BAAI/bge-reranker-base": None,
    "BAAI/bge-large-en-v1.5": (
        "Represent this sentence for searching relevant passages: "
    ),
    "BAAI/bge-base-en-v1.5": (
        "Represent this sentence for searching relevant passages: "
    ),
    "BAAI/bge-small-en-v1.5": (
        "Represent this sentence for searching relevant passages: "
    ),
    "BAAI/bge-large-zh-v1.5": "为这个句子生成表示以用于检索相关文章:",
    "BAAI/bge-base-zh-v1.5": "为这个句子生成表示以用于检索相关文章:",
    "BAAI/bge-small-zh-v1.5": "为这个句子生成表示以用于检索相关文章:",
    "BAAI/bge-large-en": "Represent this sentence for searching relevant passages: ",
    "BAAI/bge-base-en": "Represent this sentence for searching relevant passages: ",
    "BAAI/bge-small-en": "Represent this sentence for searching relevant passages: ",
    "BAAI/bge-large-zh": "为这个句子生成表示以用于检索相关文章:",
    "BAAI/bge-base-zh": "为这个句子生成表示以用于检索相关文章:",
    "BAAI/bge-small-zh": "为这个句子生成表示以用于检索相关文章:",
}


class BGEEmbedding(Photon):
    """
    The BGE embedding model from BAAI.
    """

    requirement_dependency = [
        "FlagEmbedding",
    ]

    # Manages the max concurrency of the photon, i.e. the number of requests
    # that can be handled at the same time.
    handler_max_concurrency = 4

    DEFAULT_MODEL_NAME = "BAAI/bge-large-en-v1.5"
    DEFAULT_QUERY_INSTRUCTION = AVAILABLE_MODELS_AND_INSTRUCTIONS[DEFAULT_MODEL_NAME]
    DEFAULT_USE_FP16 = True
    DEFAULT_NORMALIZE_EMBEDDINGS = True

    def init(self):
        from FlagEmbedding import FlagModel

        model_name = os.environ.get("MODEL_NAME", self.DEFAULT_MODEL_NAME)
        if model_name not in AVAILABLE_MODELS_AND_INSTRUCTIONS:
            raise ValueError(
                f"Model name {model_name} not found. Available models:"
                f" {list(AVAILABLE_MODELS_AND_INSTRUCTIONS.keys())}"
            )
        query_instruction = os.environ.get(
            "QUERY_INSTRUCTION", self.DEFAULT_QUERY_INSTRUCTION
        )

        def _env_bool(name, default):
            # Environment variables are strings, so a value like "False" would
            # be truthy; parse the boolean flags explicitly instead.
            return os.environ.get(name, str(default)).lower() in ("1", "true", "yes")

        use_fp16 = _env_bool("USE_FP16", self.DEFAULT_USE_FP16)
        normalize_embeddings = _env_bool(
            "NORMALIZE_EMBEDDINGS", self.DEFAULT_NORMALIZE_EMBEDDINGS
        )
        self._model = FlagModel(
            model_name,
            query_instruction_for_retrieval=query_instruction,
            use_fp16=use_fp16,
            normalize_embeddings=normalize_embeddings,
        )

    @Photon.handler
    def encode(self, sentences: Union[str, List[str]]) -> List[float]:
        """
        Encodes the given sentences into embeddings.
        """
        embeddings = self._model.encode(sentences)
        return embeddings.tolist()

    @Photon.handler
    def rank(self, query: str, sentences: List[str]) -> Tuple[List[int], List[float]]:
        """
        Returns a ranked list of indices of the most relevant sentences. This uses
        the inner product of the embeddings to rank the sentences. If the model is
        not initialized with normalize_embeddings=True, this will raise an error.
        The relative similarity scores are also returned.
        """
        if not self._model.normalize_embeddings:
            raise HTTPException(
                status_code=500,
                detail="Model must have normalize_embeddings=True to use rank.",
            )
        embeddings = self._model.encode([query] + sentences)
        query_embedding = embeddings[0]
        sentence_embeddings = embeddings[1:]
        inner_product = query_embedding @ sentence_embeddings.T
        sorted_indices = inner_product.argsort()[::-1]
        return sorted_indices.tolist(), inner_product[sorted_indices].tolist()


if __name__ == "__main__":
    ph = BGEEmbedding()
    ph.launch(port=8080)
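Once main.py is running, the handlers can also be exercised over plain HTTP. The sketch below assumes that each @Photon.handler is exposed as a POST endpoint taking its keyword arguments as a JSON body, which is how the Python client above invokes them; the MODEL_NAME override follows the init() code:

# Start the photon locally, optionally overriding the model via the
# environment variables recognized in init().
MODEL_NAME=BAAI/bge-small-en-v1.5 python main.py

# Encode a single sentence; the response is a JSON list of floats.
curl -s http://localhost:8080/encode \
  -H "Content-Type: application/json" \
  -d '{"sentences": "The quick brown fox jumps over the lazy dog."}'

# Rank candidate sentences against a query; the response is a pair of
# [indices, scores], sorted from most to least similar.
curl -s http://localhost:8080/rank \
  -H "Content-Type: application/json" \
  -d '{"query": "The quick brown fox jumps over the lazy dog.", "sentences": ["the fox jumps over the dog", "the photon is a particle and a wave"]}'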