feat(example): add the bge embedding example #54

Merged 1 commit on Nov 7, 2023
86 changes: 86 additions & 0 deletions advanced/embedding/baai_bge/example_usage.py
@@ -0,0 +1,86 @@
import subprocess
import time
import socket

from leptonai.client import Client, local, current # noqa: F401


def is_port_open(host, port):
"""Check if a port is open on a given host."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
try:
s.connect((host, port))
return True
except socket.error:
return False


def wait_for_port(host, port, interval=5):
"""Wait for a port to be connectable."""
while True:
if is_port_open(host, port):
print(f"Port {port} on {host} is now connectable!")
break
else:
print(
f"Port {port} on {host} is not ready yet. Retrying in"
f" {interval} seconds..."
)
time.sleep(interval)


def main():
# launches "python main.py" in a subprocess so we can use the client
# to test it.
#
print("Launching the photon in a subprocess on port 8080...")
p = subprocess.Popen(
["python", "main.py"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
)
wait_for_port("localhost", 8080)

    # Note: launching a local subprocess is not necessary if you are running
    # the photon on the Lepton server. To run it there, you can do
    #   lep photon run -n bge -m main.py --resource-shape gpu.a10
    # and then, instead of local(), create the client as
    #   c = Client(current(), "bge")
    # where current() is a helper function that returns the current workspace.

    c = Client(local())
    # c = Client(current(), "bge")
    print("\nThe client has the following endpoints:")
    print(c.paths())
    print("For the encode endpoint, the docstring is as follows:")
    print("***begin docstring***")
    print(c.encode.__doc__)
    print("***end docstring***")

print("\n\nRunning the encode endpoint...")
query = "The quick brown fox jumps over the lazy dog."
ret = c.encode(sentences=query)
print("The result is (truncated, showing first 5):")
print(ret[:5])
print(f"(the full result is a list of {len(ret)} floats)")

print("\n\nRunning the rank endpoint...")
sentences = [
"the fox jumps over the dog",
"the photon is a particle and a wave",
"let the record show that the shipment has arrived",
"the cat jumps on the fox",
]
rank, score = c.rank(query=query, sentences=sentences)
print("The rank and score are respectively:")
print([(r, s) for r, s in zip(rank, score)])
print(f"The query is: {query}")
print("The sentences, ordered from closest to furthest, are:")
print([sentences[i] for i in rank])

print("Finished. Closing everything.")
# Closes the subprocess
p.terminate()


if __name__ == "__main__":
    main()
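As a side note, the same endpoints can also be exercised without the leptonai Client. A minimal sketch, assuming the photon from main.py is already serving on localhost:8080 and that handlers are exposed as POST endpoints taking JSON-encoded keyword arguments (which is what the leptonai Client does under the hood):

import requests

# Encode a single sentence; the response body is the JSON-serialized embedding.
resp = requests.post(
    "http://localhost:8080/encode",
    json={"sentences": "The quick brown fox jumps over the lazy dog."},
)
resp.raise_for_status()
embedding = resp.json()
print(f"Got an embedding of {len(embedding)} floats")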
106 changes: 106 additions & 0 deletions advanced/embedding/baai_bge/main.py
@@ -0,0 +1,106 @@
import os
from typing import List, Union, Tuple

from leptonai.photon import Photon, HTTPException


# Transcribed from https://github.com/FlagOpen/FlagEmbedding/tree/master#model-list
AVAILABLE_MODELS_AND_INSTRUCTIONS = {
    "BAAI/llm-embedder": None,
    "BAAI/bge-reranker-large": None,
    "BAAI/bge-reranker-base": None,
    "BAAI/bge-large-en-v1.5": (
        "Represent this sentence for searching relevant passages: "
    ),
    "BAAI/bge-base-en-v1.5": (
        "Represent this sentence for searching relevant passages: "
    ),
    "BAAI/bge-small-en-v1.5": (
        "Represent this sentence for searching relevant passages: "
    ),
    "BAAI/bge-large-zh-v1.5": "为这个句子生成表示以用于检索相关文章:",
    "BAAI/bge-base-zh-v1.5": "为这个句子生成表示以用于检索相关文章:",
    "BAAI/bge-small-zh-v1.5": "为这个句子生成表示以用于检索相关文章:",
    "BAAI/bge-large-en": "Represent this sentence for searching relevant passages: ",
    "BAAI/bge-base-en": "Represent this sentence for searching relevant passages: ",
    "BAAI/bge-small-en": "Represent this sentence for searching relevant passages: ",
    "BAAI/bge-large-zh": "为这个句子生成表示以用于检索相关文章:",
    "BAAI/bge-base-zh": "为这个句子生成表示以用于检索相关文章:",
    "BAAI/bge-small-zh": "为这个句子生成表示以用于检索相关文章:",
}


class BGEEmbedding(Photon):
"""
The BGE embedding model from BAAI.
"""

requirement_dependency = [
"FlagEmbedding",
]

# manage the max concurrency of the photon. This is the number of requests
# that can be handled at the same time.
handler_max_concurrency = 4

DEFAULT_MODEL_NAME = "BAAI/bge-large-en-v1.5"
DEFAULT_QUERY_INSTRUCTION = AVAILABLE_MODELS_AND_INSTRUCTIONS[DEFAULT_MODEL_NAME]
DEFAULT_USE_FP16 = True
DEFAULT_NORMALIZE_EMBEDDINGS = True

    def init(self):
        from FlagEmbedding import FlagModel
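        # The model and its options are read from environment variables, so
        # the same photon code can serve any of the models listed above.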

        model_name = os.environ.get("MODEL_NAME", self.DEFAULT_MODEL_NAME)
        if model_name not in AVAILABLE_MODELS_AND_INSTRUCTIONS:
            raise ValueError(
                f"Model name {model_name} not found. Available models:"
                f" {list(AVAILABLE_MODELS_AND_INSTRUCTIONS.keys())}"
            )
        query_instruction = os.environ.get(
            "QUERY_INSTRUCTION", self.DEFAULT_QUERY_INSTRUCTION
        )
        # Environment variables are strings, so parse the boolean flags
        # explicitly; any value other than "0" or "false" counts as true.
        use_fp16 = (
            os.environ.get("USE_FP16", str(self.DEFAULT_USE_FP16)).lower()
            not in ("0", "false")
        )
        normalize_embeddings = (
            os.environ.get(
                "NORMALIZE_EMBEDDINGS", str(self.DEFAULT_NORMALIZE_EMBEDDINGS)
            ).lower()
            not in ("0", "false")
        )
        self._model = FlagModel(
            model_name,
            query_instruction_for_retrieval=query_instruction,
            use_fp16=use_fp16,
            normalize_embeddings=normalize_embeddings,
        )

    @Photon.handler
    def encode(
        self, sentences: Union[str, List[str]]
    ) -> Union[List[float], List[List[float]]]:
        """
        Encodes the given sentence (or list of sentences) into embeddings.
        """
        embeddings = self._model.encode(sentences)
        return embeddings.tolist()

    @Photon.handler
    def rank(self, query: str, sentences: List[str]) -> Tuple[List[int], List[float]]:
        """
        Ranks the given sentences by relevance to the query, using the inner
        product of the embeddings, and returns the sorted indices together
        with the corresponding similarity scores. Raises an error if the
        model was not initialized with normalize_embeddings=True.
        """
        if not self._model.normalize_embeddings:
            raise HTTPException(
                status_code=500,
                detail="Model must have normalize_embeddings=True to use rank.",
            )
        embeddings = self._model.encode([query] + sentences)
        query_embedding = embeddings[0]
        sentence_embeddings = embeddings[1:]
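        # With unit-normalized embeddings, the inner product with the query is
        # exactly cosine similarity, so larger values mean closer sentences.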
        inner_product = query_embedding @ sentence_embeddings.T
        sorted_indices = inner_product.argsort()[::-1]
        return sorted_indices.tolist(), inner_product[sorted_indices].tolist()


if __name__ == "__main__":
    # Launch the photon locally on port 8080.
    ph = BGEEmbedding()
    ph.launch(port=8080)
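For intuition on the rank logic above: with unit-normalized embeddings, the inner product is exactly cosine similarity. A minimal self-contained sketch with hypothetical, already-normalized vectors (not actual model output):

import numpy as np

q = np.array([0.6, 0.8])                # unit-norm query embedding
s = np.array([[0.8, 0.6], [1.0, 0.0]])  # unit-norm sentence embeddings
scores = q @ s.T                        # cosine similarities: [0.96, 0.6]
order = scores.argsort()[::-1]          # indices, most similar first
print(order.tolist(), scores[order].tolist())  # -> [0, 1] [0.96, 0.6]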