diff --git a/vec_inf/cli/_cli.py b/vec_inf/cli/_cli.py
index bcfba41..c002eea 100644
--- a/vec_inf/cli/_cli.py
+++ b/vec_inf/cli/_cli.py
@@ -82,8 +82,8 @@ def cli():
 )
 @click.option(
     "--model-weights-parent-dir",
-    type=str,
-    default="/model-weights",
+    type=Optional[str],
+    default=None,
     help="Path to parent directory containing model weights, default to '/model-weights' for supported models",
 )
 @click.option(
@@ -131,6 +131,10 @@ def launch(
     if model_name in models_df["model_name"].values:
         default_args = utils.load_default_args(models_df, model_name)
 
+        model_type = default_args.pop("model_type")
+        if model_type == "Text Embedding":
+            launch_cmd += " --slurm-script embed.slurm"
+
         for arg in default_args:
             if arg in locals() and locals()[arg] is not None:
                 default_args[arg] = locals()[arg]
@@ -155,6 +159,9 @@ def launch(
     output_dict = {"slurm_job_id": slurm_job_id}
 
     for line in output_lines:
+        if ": " not in line:
+            continue
+
         key, value = line.split(": ")
         table.add_row(key, value)
         output_dict[key.lower().replace(" ", "_")] = value
@@ -336,7 +343,9 @@ def metrics(slurm_job_id: int, log_dir: Optional[str] = None) -> None:
 
     with Live(refresh_per_second=1, console=CONSOLE) as live:
         while True:
-            out_logs = utils.read_slurm_log(slurm_job_name, slurm_job_id, "out", log_dir)
+            out_logs = utils.read_slurm_log(
+                slurm_job_name, slurm_job_id, "out", log_dir
+            )
             metrics = utils.get_latest_metric(out_logs)
             table = utils.create_table(key_title="Metric", value_title="Value")
             for key, value in metrics.items():
diff --git a/vec_inf/cli/_utils.py b/vec_inf/cli/_utils.py
index 6eb5c8e..b070d98 100644
--- a/vec_inf/cli/_utils.py
+++ b/vec_inf/cli/_utils.py
@@ -135,7 +135,6 @@ def load_default_args(models_df: pd.DataFrame, model_name: str) -> dict:
     row_data = models_df.loc[models_df["model_name"] == model_name]
     default_args = row_data.iloc[0].to_dict()
     default_args.pop("model_name")
-    default_args.pop("model_type")
     return default_args
 
 
diff --git a/vec_inf/embed.slurm b/vec_inf/embed.slurm
new file mode 100644
index 0000000..5addb38
--- /dev/null
+++ b/vec_inf/embed.slurm
@@ -0,0 +1,41 @@
+#!/bin/bash
+#SBATCH --cpus-per-task=16
+#SBATCH --mem=64G
+
+# Load CUDA, change to the cuda version on your environment if different
+source /opt/lmod/lmod/init/profile
+module load cuda-12.3
+nvidia-smi
+
+source ${SRC_DIR}/find_port.sh
+
+# Write server url to file
+hostname=${SLURMD_NODENAME}
+vllm_port_number=$(find_available_port $hostname 8080 65535)
+
+echo "Server address: http://${hostname}:${vllm_port_number}/v1"
+echo "http://${hostname}:${vllm_port_number}/v1" > ${VLLM_BASE_URL_FILENAME}
+
+# Activate vllm venv
+if [ "$VENV_BASE" = "singularity" ]; then
+    export SINGULARITY_IMAGE=/projects/aieng/public/vector-inference_0.3.4.sif
+    export VLLM_NCCL_SO_PATH=/vec-inf/nccl/libnccl.so.2.18.1
+    module load singularity-ce/3.8.2
+    singularity exec $SINGULARITY_IMAGE ray stop
+    singularity exec --nv \
+        --bind ${MODEL_WEIGHTS_PARENT_DIR}:${MODEL_WEIGHTS_PARENT_DIR} \
+        --bind ${SRC_DIR}:${SRC_DIR} \
+        $SINGULARITY_IMAGE \
+        python3.10 ${SRC_DIR}/embeddings/openai_api_server.py \
+        --model ${VLLM_MODEL_WEIGHTS} \
+        --port ${vllm_port_number} \
+        --trust-remote-code \
+        --max-num-seqs ${VLLM_MAX_NUM_SEQS}
+else
+    source ${VENV_BASE}/bin/activate
+    python3 ${SRC_DIR}/embeddings/openai_api_server.py \
+        --model ${VLLM_MODEL_WEIGHTS} \
+        --port ${vllm_port_number} \
+        --trust-remote-code \
+        --max-num-seqs ${VLLM_MAX_NUM_SEQS}
+fi
\ No newline at end of file
diff --git a/vec_inf/embeddings/__init__.py b/vec_inf/embeddings/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/vec_inf/embeddings/openai_api_server.py b/vec_inf/embeddings/openai_api_server.py
new file mode 100644
index 0000000..c154c5a
--- /dev/null
+++ b/vec_inf/embeddings/openai_api_server.py
@@ -0,0 +1,169 @@
+import argparse
+import asyncio
+import base64
+from asyncio import Queue
+from typing import List, Optional, Union
+import sys
+
+import torch
+import uvicorn
+from fastapi import FastAPI, Response
+from pydantic import BaseModel
+from transformers import AutoModel, AutoTokenizer
+
+
+# Define request and response models
+class EmbeddingsRequest(BaseModel):
+    input: Union[str, List[str]]
+    model: str
+    encoding_format: Optional[str] = "float"  # Default to 'float'
+    user: Optional[str] = None
+
+
+class EmbeddingData(BaseModel):
+    object: str
+    embedding: Union[List[float], str]  # Can be a list of floats or a base64 string
+    index: int
+
+
+class EmbeddingsResponse(BaseModel):
+    object: str
+    data: List[EmbeddingData]
+    model: str
+    usage: dict
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model")
+parser.add_argument("--port", type=int)
+parser.add_argument("--max-num-seqs", type=int)
+parser.add_argument("--trust-remote-code", action="store_true")
+args = parser.parse_args()
+
+
+# Initialize the FastAPI app
+app = FastAPI()
+
+# Load the tokenizer and model from HuggingFace
+tokenizer = AutoTokenizer.from_pretrained(args.model)
+model = AutoModel.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
+
+# Initialize the request queue and batch processing parameters
+request_queue = Queue()
+BATCH_TIMEOUT = 0.01  # in seconds
+
+
+@app.post("/v1/embeddings")
+async def create_embeddings(request: EmbeddingsRequest):
+    """
+    Handle incoming embedding requests by adding them to the processing queue.
+    """
+    # Create a Future to hold the result
+    future = asyncio.get_event_loop().create_future()
+    # Put the request into the queue
+    await request_queue.put((request, future))
+    # Wait for the result
+    result = await future
+    return result
+
+
+@app.get("/health")
+def status_check():
+    """
+    Returns 200.
+    """
+    return Response(status_code=200)
+
+
+async def process_queue():
+    """
+    Continuously process requests from the queue in batches.
+    """
+    while True:
+        requests_futures = []
+        try:
+            # Wait for at least one item
+            request_future = await request_queue.get()
+            requests_futures.append(request_future)
+            # Now, try to get more items with a timeout
+            try:
+                while len(requests_futures) < args.max_num_seqs:
+                    request_future = await asyncio.wait_for(
+                        request_queue.get(), timeout=BATCH_TIMEOUT
+                    )
+                    requests_futures.append(request_future)
+            except asyncio.TimeoutError:
+                pass
+        except Exception:
+            continue
+        # Process the batch
+        requests = [rf[0] for rf in requests_futures]
+        futures = [rf[1] for rf in requests_futures]
+        # Collect input texts and track counts
+        batched_input_texts = []
+        input_counts = []
+        encoding_formats = []
+        for request in requests:
+            input_text = request.input
+            if isinstance(input_text, str):
+                input_text = [input_text]
+            input_counts.append(len(input_text))
+            batched_input_texts.extend(input_text)
+            encoding_formats.append(request.encoding_format)
+        # Tokenize and compute embeddings
+        inputs = tokenizer(
+            batched_input_texts, padding=True, truncation=True, return_tensors="pt"
+        )
+        with torch.no_grad():
+            outputs = model(**inputs)
+        embeddings = outputs.last_hidden_state.mean(dim=1).tolist()
+        # Split embeddings back to individual requests
+        idx = 0
+        for request, future, count, encoding_format in zip(
+            requests, futures, input_counts, encoding_formats
+        ):
+            request_embeddings = embeddings[idx : idx + count]
+            idx += count
+            # Prepare response
+            data = []
+            for i, embedding in enumerate(request_embeddings):
+                if encoding_format == "base64":
+                    # Convert list of floats to bytes
+                    embedding_bytes = (
+                        torch.tensor(embedding, dtype=torch.float32).numpy().tobytes()
+                    )
+                    # Encode bytes to base64 string
+                    embedding_base64 = base64.b64encode(embedding_bytes).decode("utf-8")
+                    data.append(
+                        EmbeddingData(
+                            object="embedding", embedding=embedding_base64, index=i
+                        )
+                    )
+                else:
+                    data.append(
+                        EmbeddingData(object="embedding", embedding=embedding, index=i)
+                    )
+            response = EmbeddingsResponse(
+                object="list",
+                data=data,
+                model=request.model,
+                usage={
+                    "prompt_tokens": len(inputs["input_ids"]),  # type: ignore
+                    "total_tokens": len(inputs["input_ids"]),  # type: ignore
+                },
+            )
+            # Set the result
+            future.set_result(response)
+
+
+@app.on_event("startup")
+async def startup_event():
+    """
+    Start the background task to process the request queue.
+    """
+    asyncio.create_task(process_queue())
+
+
+if __name__ == "__main__":
+    print("INFO: Application startup complete.", file=sys.stderr)
+    uvicorn.run(app, host="0.0.0.0", port=args.port)
diff --git a/vec_inf/launch_server.sh b/vec_inf/launch_server.sh
index 4476ca9..ec1988d 100755
--- a/vec_inf/launch_server.sh
+++ b/vec_inf/launch_server.sh
@@ -17,6 +17,7 @@ while [[ "$#" -gt 0 ]]; do
         --data-type) data_type="$2"; shift ;;
         --venv) venv="$2"; shift ;;
         --log-dir) log_dir="$2"; shift ;;
+        --slurm-script) slurm_script="$2"; shift ;;
         --model-weights-parent-dir) model_weights_parent_dir="$2"; shift ;;
         --pipeline-parallelism) pipeline_parallelism="$2"; shift ;;
         *) echo "Unknown parameter passed: $1"; exit 1 ;;
@@ -44,6 +45,7 @@ export VLLM_MAX_MODEL_LEN=$max_model_len
 export VLLM_MAX_LOGPROBS=$vocab_size
 export VLLM_DATA_TYPE=$data_type
 export VENV_BASE=$venv
+export SLURM_SCRIPT=$slurm_script
 export LOG_DIR=$log_dir
 export MODEL_WEIGHTS_PARENT_DIR=$model_weights_parent_dir
 
@@ -53,6 +55,12 @@ else
     export VLLM_MAX_NUM_SEQS=256
 fi
 
+if [ -n "$slurm_script" ]; then
+    export SLURM_SCRIPT=$slurm_script
+else
+    export SLURM_SCRIPT="vllm.slurm"
+fi
+
 if [ -n "$pipeline_parallelism" ]; then
     export PIPELINE_PARALLELISM=$pipeline_parallelism
 else
@@ -121,4 +129,4 @@ sbatch --job-name $JOB_NAME \
     --time $WALLTIME \
     --output $LOG_DIR/$JOB_NAME.%j.out \
     --error $LOG_DIR/$JOB_NAME.%j.err \
-    $SRC_DIR/${is_special}vllm.slurm
+    $SRC_DIR/${is_special}${SLURM_SCRIPT}
diff --git a/vec_inf/models/models.csv b/vec_inf/models/models.csv
index 421856a..d48efa3 100644
--- a/vec_inf/models/models.csv
+++ b/vec_inf/models/models.csv
@@ -65,3 +65,5 @@ Qwen2.5-72B-Instruct,Qwen2.5,72B-Instruct,LLM,4,1,152064,16384,256,true,m2,08:00
 Pixtral-12B-2409,Pixtral,12B-2409,VLM,1,1,131072,8192,256,true,m2,08:00:00,a40,auto,singularity,default,/model-weights
 bge-multilingual-gemma2,bge,multilingual-gemma2,Text Embedding,1,1,256002,4096,256,true,m2,08:00:00,a40,auto,singularity,default,/model-weights
 e5-mistral-7b-instruct,e5,mistral-7b-instruct,Text Embedding,1,1,32000,4096,256,true,m2,08:00:00,a40,auto,singularity,default,/model-weights
+all-MiniLM-L6-v2,sentence-transformers,all-MiniLM-L6-v2,Text Embedding,1,1,30522,512,256,true,m2,08:00:00,a40,auto,singularity,default,/fs01/projects/llm/
+bge-base-en-v1.5,BAAI,base-en-v1.5,Text Embedding,1,1,30522,512,256,true,m2,08:00:00,a40,auto,singularity,default,/fs01/projects/llm/
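For reference, a minimal client sketch against the /v1/embeddings endpoint that openai_api_server.py exposes. This is not part of the diff: the base URL is a placeholder (the real one is written to the file named by VLLM_BASE_URL_FILENAME and echoed in the Slurm output), the model name is just one of the Text Embedding rows in models.csv, and the requests/numpy dependencies are assumed to exist on the client side.

# Hypothetical client sketch, not included in this change.
import base64

import numpy as np
import requests

base_url = "http://gpu-node:8080/v1"  # placeholder; copy the real URL from the job output

payload = {
    "input": ["A sentence to embed.", "Another sentence."],
    "model": "e5-mistral-7b-instruct",
    "encoding_format": "base64",  # or "float", the server's default
}
resp = requests.post(f"{base_url}/embeddings", json=payload, timeout=60)
resp.raise_for_status()

for item in resp.json()["data"]:
    emb = item["embedding"]
    if isinstance(emb, str):
        # base64 responses carry raw float32 bytes, as produced by the server
        emb = np.frombuffer(base64.b64decode(emb), dtype=np.float32)
    print(item["index"], len(emb))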