Skip to content

Commit

Permalink
feat(example): add the bge embedding example (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yangqing authored Nov 7, 2023
1 parent fcdfd27 commit dd41123
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 0 deletions.
86 changes: 86 additions & 0 deletions advanced/embedding/baai_bge/example_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import subprocess
import time
import socket

from leptonai.client import Client, local, current # noqa: F401


def is_port_open(host, port):
"""Check if a port is open on a given host."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(1)
try:
s.connect((host, port))
return True
except socket.error:
return False


def wait_for_port(host, port, interval=5):
"""Wait for a port to be connectable."""
while True:
if is_port_open(host, port):
print(f"Port {port} on {host} is now connectable!")
break
else:
print(
f"Port {port} on {host} is not ready yet. Retrying in"
f" {interval} seconds..."
)
time.sleep(interval)


def main():
# launches "python main.py" in a subprocess so we can use the client
# to test it.
#
print("Launching the photon in a subprocess on port 8080...")
p = subprocess.Popen(
["python", "main.py"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
)
wait_for_port("localhost", 8080)

# Note: this is not necessary if you are running the photon in the lepton
# server. To run it in the server, you can do
# lep photon run -n bge -m main.py --resource-shape gpu.a10
# and then instead of using local, you can use the client as
# c = Client(current(), "bge")
# where current() is a helper function to get the current workspace.

c = Client(local())
# c = Client(current(), "bge")
print("\nThe client has the following endpoints:")
print(c.paths())
print("For the encode endpoint, the docstring is as follows:")
print("***begin docstring***")
print(c.encode.__doc__)
print("***end docstring***")

print("\n\nRunning the encode endpoint...")
query = "The quick brown fox jumps over the lazy dog."
ret = c.encode(sentences=query)
print("The result is (truncated, showing first 5):")
print(ret[:5])
print(f"(the full result is a list of {len(ret)} floats)")

print("\n\nRunning the rank endpoint...")
sentences = [
"the fox jumps over the dog",
"the photon is a particle and a wave",
"let the record show that the shipment has arrived",
"the cat jumps on the fox",
]
rank, score = c.rank(query=query, sentences=sentences)
print("The rank and score are respectively:")
print([(r, s) for r, s in zip(rank, score)])
print(f"The query is: {query}")
print("The sentences, ordered from closest to furthest, are:")
print([sentences[i] for i in rank])

print("Finished. Closing everything.")
# Closes the subprocess
p.terminate()


if __name__ == "__main__":
main()
106 changes: 106 additions & 0 deletions advanced/embedding/baai_bge/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import os
from typing import List, Union, Tuple

from leptonai.photon import Photon, HTTPException


# Transcribed from https://github.com/FlagOpen/FlagEmbedding/tree/master#model-list
AVAILABLE_MODELS_AND_INSTRUCTIONS = {
"BAAI/llm-embedder": None,
"BAAI/bge-reranker-large": None,
"BAAI/bge-reranker-base": None,
"BAAI/bge-large-en-v1.5": (
"Represent this sentence for searching relevant passages: "
),
"BAAI/bge-base-en-v1.5": (
"Represent this sentence for searching relevant passages: "
),
"BAAI/bge-small-en-v1.5": (
"Represent this sentence for searching relevant passages: "
),
"BAAI/bge-large-zh-v1.5": "为这个句子生成表示以用于检索相关文章:",
"BAAI/bge-base-zh-v1.5": "为这个句子生成表示以用于检索相关文章:",
"BAAI/bge-small-zh-v1.5": "为这个句子生成表示以用于检索相关文章:",
"BAAI/bge-large-en": "Represent this sentence for searching relevant passages: ",
"BAAI/bge-base-en": "Represent this sentence for searching relevant passages: ",
"BAAI/bge-small-en": "Represent this sentence for searching relevant passages: ",
"BAAI/bge-large-zh": "为这个句子生成表示以用于检索相关文章:",
"BAAI/bge-base-zh": "为这个句子生成表示以用于检索相关文章:",
"BAAI/bge-small-zh": "为这个句子生成表示以用于检索相关文章:",
}


class BGEEmbedding(Photon):
"""
The BGE embedding model from BAAI.
"""

requirement_dependency = [
"FlagEmbedding",
]

# manage the max concurrency of the photon. This is the number of requests
# that can be handled at the same time.
handler_max_concurrency = 4

DEFAULT_MODEL_NAME = "BAAI/bge-large-en-v1.5"
DEFAULT_QUERY_INSTRUCTION = AVAILABLE_MODELS_AND_INSTRUCTIONS[DEFAULT_MODEL_NAME]
DEFAULT_USE_FP16 = True
DEFAULT_NORMALIZE_EMBEDDINGS = True

def init(self):
from FlagEmbedding import FlagModel

model_name = os.environ.get("MODEL_NAME", self.DEFAULT_MODEL_NAME)
if model_name not in AVAILABLE_MODELS_AND_INSTRUCTIONS:
raise ValueError(
f"Model name {model_name} not found. Available models:"
f" {AVAILABLE_MODELS_AND_INSTRUCTIONS.keys()}"
)
query_instruction = os.environ.get(
"QUERY_INSTRUCTION", self.DEFAULT_QUERY_INSTRUCTION
)
use_fp16 = os.environ.get("USE_FP16", self.DEFAULT_USE_FP16)
normalize_embeddings = os.environ.get(
"NORMALIZE_EMBEDDINGS", self.DEFAULT_NORMALIZE_EMBEDDINGS
)
self._model = FlagModel(
model_name,
query_instruction_for_retrieval=query_instruction,
use_fp16=use_fp16,
normalize_embeddings=normalize_embeddings,
)

@Photon.handler
def encode(self, sentences: Union[str, List[str]]) -> List[float]:
"""
Encodes the current sentences into embeddings.
"""
embeddings = self._model.encode(sentences)
return embeddings.tolist()

@Photon.handler
def rank(self, query: str, sentences: List[str]) -> Tuple[List[int], List[float]]:
"""
Returns a ranked list of indices of the most relevant sentences. This uses
the inner product of the embeddings to rank the sentences. If the model is
not initialized as normalize_embeddings=True, this will raise an error. The
relative similarity scores are also returned.
"""
if not self._model.normalize_embeddings:
raise HTTPException(
status_code=500,
detail="Model must have normalize_embeddings=True to use rank.",
)
embeddings = self._model.encode([query] + sentences)
query_embedding = embeddings[0]
sentence_embeddings = embeddings[1:]
inner_product = query_embedding @ sentence_embeddings.T
sorted_indices = inner_product.argsort()[::-1]
return sorted_indices.tolist(), inner_product[sorted_indices].tolist()


if __name__ == "__main__":
# TODO: change the name of the class "MyPhoton" to the name of your photon
ph = BGEEmbedding()
ph.launch(port=8080)

0 comments on commit dd41123

Please sign in to comment.