From 832582614cdff308eb393954a52fdd41bf0db828 Mon Sep 17 00:00:00 2001 From: "Hauzer S. Lee" Date: Fri, 18 Oct 2024 23:31:38 +0800 Subject: [PATCH] feat: localai client updated --- llmpa/clients/base.py | 125 +++++++++++++++++++++++++++++++++-- llmpa/clients/localai.py | 74 +++++++++------------ llmpa/fileparser/__init__.py | 1 + llmpa/fileparser/__main__.py | 2 +- llmpa/fileparser/base.py | 9 ++- llmpa/fileparser/mimetype.py | 67 ++++++++++--------- llmpa/fileparser/text.py | 4 +- requirements.txt | 1 + 8 files changed, 201 insertions(+), 82 deletions(-) diff --git a/llmpa/clients/base.py b/llmpa/clients/base.py index 751cf8b..44c5b85 100644 --- a/llmpa/clients/base.py +++ b/llmpa/clients/base.py @@ -1,22 +1,139 @@ +import requests +from requests.exceptions import HTTPError, Timeout, RequestException from typing import Optional, List class ClientBase: - def __init__(self, base_url: str, timeout: Optional[int] = 10): + def __init__( + self, + base_url: str, + api_key: Optional[str] = None, + verify_ssl: Optional[bool] = True, + timeout: Optional[int] = 10, + ): self.base_url = base_url + if self.base_url.endswith("/"): + self.base_url = self.base_url[:-1] + self.api_key = api_key + self.verify_ssl = verify_ssl self.timeout = timeout + self.headers = { + "Authorization": f"Bearer {self.api_key}" if self.api_key else None, + "Content-Type": "application/json", + } + + def _make_request( + self, + method, + endpoint, + data=None, + json=None, + params=None, + extra_headers=None, + timeout: Optional[int] = None, + ): + """ + Make an HTTP request to the LocalAI server. + + :param method: HTTP method (GET, POST, PUT, DELETE). + :param endpoint: The API endpoint to hit. + :param data: The request body for POST/PUT requests. + :param params: Query parameters for GET/DELETE requests. + :param extra_headers: Additional headers for the request. + :return: Parsed JSON response or None if an error occurred. 
+ """ + url = f"{self.base_url}{endpoint}" + headers = {**self.headers, **(extra_headers or {})} + + print( + f"data={data}, json={json}, params={params}, extra_headers={extra_headers}" + ) + try: + response = requests.request( + method=method.upper(), + url=url, + data=data, + json=json, + params=params, + headers=headers, + verify=self.verify_ssl, + timeout=timeout or self.timeout, + ) + response.raise_for_status() # Raise HTTPError for bad responses + return response + except HTTPError as http_err: + print(f"HTTP error occurred: {http_err}") + except Timeout as timeout_err: + print(f"Timeout error: {timeout_err}") + except RequestException as req_err: + print(f"Request error: {req_err}") + except Exception as err: + print(f"An error occurred: {err}") + return None + + def get( + self, endpoint, params=None, extra_headers=None, timeout: Optional[int] = None + ): + return self._make_request( + "GET", endpoint, params=params, extra_headers=extra_headers, timeout=timeout + ) + + def post( + self, + endpoint, + data=None, + json=None, + extra_headers=None, + timeout: Optional[int] = None, + ): + return self._make_request( + "POST", + endpoint, + data=data, + json=json, + extra_headers=extra_headers, + timeout=timeout, + ) + + def put( + self, + endpoint, + data=None, + json=None, + extra_headers=None, + timeout: Optional[int] = None, + ): + return self._make_request( + "PUT", + endpoint, + data=data, + json=json, + extra_headers=extra_headers, + timeout=timeout, + ) + + def delete( + self, endpoint, params=None, extra_headers=None, timeout: Optional[int] = None + ): + return self._make_request( + "DELETE", + endpoint, + params=params, + extra_headers=extra_headers, + timeout=timeout, + ) def generate( self, - model: str, prompt: str, + model: Optional[str] = None, max_tokens: Optional[int] = 150, temperature: Optional[float] = 1.0, ) -> Optional[str]: raise NotImplementedError - def embedding_text(self, model: str, text: str) -> Optional[List[float]]: + def 
embedding_text(self, text: str, model: str) -> Optional[List[float]]: raise NotImplementedError - def embedding_image(self, model: str, filepath: str) -> Optional[List[float]]: + def embedding_file(self, filepath: str, model: str) -> Optional[List[float]]: raise NotImplementedError diff --git a/llmpa/clients/localai.py b/llmpa/clients/localai.py index 3bd7d96..2f946bb 100644 --- a/llmpa/clients/localai.py +++ b/llmpa/clients/localai.py @@ -1,12 +1,12 @@ -import requests import json +from requests.exceptions import HTTPError, Timeout, RequestException from typing import Optional, List from .base import ClientBase class LocalAIClient(ClientBase): - def __init__(self, base_url: str, timeout: Optional[int] = 10): + def __init__(self, base_url: str, api_key=None, verify_ssl=True, timeout=10): """ Initializes the LocalAI client with the specified server base URL. @@ -14,24 +14,7 @@ def __init__(self, base_url: str, timeout: Optional[int] = 10): base_url (str): The base URL of the LocalAI server (e.g., "http://localhost:8080"). timeout (int, optional): Timeout for the HTTP requests. Default is 10 seconds. """ - super(LocalAIClient, self).__init__(base_url, timeout) - - def get_model_info(self, model: str) -> Optional[dict]: - """ - Retrieves model information from the LocalAI server. - - Returns: - dict: Model information as a dictionary, or None if the request fails. - """ - try: - response = requests.get( - f"{self.base_url}/v1/models/{model}", timeout=self.timeout - ) - response.raise_for_status() - return response.json() - except requests.exceptions.RequestException as e: - print(f"Error retrieving model information: {e}") - return None + super(LocalAIClient, self).__init__(base_url, api_key, verify_ssl, timeout) def list_available_models(self) -> Optional[list]: """ @@ -41,17 +24,20 @@ def list_available_models(self) -> Optional[list]: list: List of available models, or None if the request fails. 
""" try: - response = requests.get(f"{self.base_url}/v1/models", timeout=self.timeout) + response = self.get("/v1/models", timeout=self.timeout) + print(response.json()) + if not response: + return None response.raise_for_status() - return response.json().get("models", []) - except requests.exceptions.RequestException as e: + return response.json().get("data", []) + except RequestException as e: print(f"Error retrieving list of models: {e}") return None def generate( self, - model: str, prompt: str, + model: Optional[str] = None, max_tokens: Optional[int] = 150, temperature: Optional[float] = 1.0, ) -> Optional[str]: @@ -77,12 +63,13 @@ def generate( } try: - response = requests.post( - f"{self.base_url}/v1/generate", - headers=headers, - data=json.dumps(payload), + response = self.post( + "/v1/chat/completions", + json=payload, timeout=self.timeout, ) + if not response: + return None response.raise_for_status() # Extract and return the generated text from the response @@ -93,11 +80,11 @@ def generate( print("No valid response received from the model.") return None - except requests.exceptions.RequestException as e: + except RequestException as e: print(f"Error during the request to LocalAI: {e}") return None - def embedding(self, model: str, text: str) -> Optional[List[float]]: + def embedding(self, text: str, model: str) -> Optional[List[float]]: """ Sends a request to the LocalAI server to generate embeddings from the input text. 
@@ -113,12 +100,14 @@ def embedding(self, model: str, text: str) -> Optional[List[float]]: payload = {"model": model, "input": text} try: - response = requests.post( - f"{self.base_url}/v1/embeddings", - headers=headers, - data=json.dumps(payload), + response = self.post( + "/embeddings", + extra_headers=headers, + json=payload, timeout=self.timeout, ) + if not response: + return None response.raise_for_status() result = response.json() @@ -128,7 +117,7 @@ def embedding(self, model: str, text: str) -> Optional[List[float]]: print("No valid embedding received from the model.") return None - except requests.exceptions.RequestException as e: + except RequestException as e: print(f"Error during the request to LocalAI for embedding: {e}") return None @@ -136,19 +125,16 @@ def embedding(self, model: str, text: str) -> Optional[List[float]]: # Example usage: if __name__ == "__main__": # Initialize the client - client = LocalAIClient(base_url="http://localhost:8080", model="your-model-name") + client = LocalAIClient(base_url="http://localhost:58080") # Example 1: Generating text prompt = "Tell me a story about a brave knight." - generated_text = client.generate(prompt, max_tokens=100, temperature=0.7) + generated_text = client.generate( + prompt, max_tokens=100, temperature=0.7, model="text-embedding-ada-002" + ) if generated_text: print(f"Generated Text: {generated_text}") - # Example 2: Get model information - model_info = client.get_model_info() - if model_info: - print(f"Model Information: {json.dumps(model_info, indent=2)}") - # Example 3: List available models models = client.list_available_models() if models: @@ -156,9 +142,9 @@ def embedding(self, model: str, text: str) -> Optional[List[float]]: # Example 4: Get embeddings for input text input_text = "Artificial intelligence is transforming industries." 
- embedding = client.embedding(input_text) + embedding = client.embedding(input_text, "text-embedding-ada-002") if embedding: print(f"Embedding: {embedding}") # Example 5: Update the model used by the client - client.update_model("another-model-name") + # client.update_model("another-model-name") diff --git a/llmpa/fileparser/__init__.py b/llmpa/fileparser/__init__.py index 7c75f25..1a0294d 100644 --- a/llmpa/fileparser/__init__.py +++ b/llmpa/fileparser/__init__.py @@ -20,6 +20,7 @@ "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": xlsx.XlsxFileParser, "text/csv": csv.CsvFileParser, "text/plain": text.TextFileParser, + "text/x-ini": text.TextFileParser, "image/gif": image.ImageFileParser, "image/jpeg": image.ImageFileParser, "image/png": image.ImageFileParser, diff --git a/llmpa/fileparser/__main__.py b/llmpa/fileparser/__main__.py index 275ef5a..057569f 100644 --- a/llmpa/fileparser/__main__.py +++ b/llmpa/fileparser/__main__.py @@ -5,7 +5,7 @@ def main(): import sys if len(sys.argv) < 2: - print("Usage: document.py ") + print("Usage: fileparser ") sys.exit(1) filepath = sys.argv[1] tokenize(filepath) diff --git a/llmpa/fileparser/base.py b/llmpa/fileparser/base.py index 07d389c..2157c16 100644 --- a/llmpa/fileparser/base.py +++ b/llmpa/fileparser/base.py @@ -1,10 +1,17 @@ +from .mimetype import detect, mimetypes_names + + class BaseFileParser: def __init__(self, file_path): self.file_path = file_path - self.file_type = None + self.file_type = detect(file_path) + self.file_mimetype_name = mimetypes_names[self.file_type] def parse(self): raise NotImplementedError def tokenize(self): raise NotImplementedError + + def prompt_for_tokenizing(self): + return f"this is a {self.file_mimetype_name} file, tokenize it, and return the embeddings" diff --git a/llmpa/fileparser/mimetype.py b/llmpa/fileparser/mimetype.py index 35eaa90..5430bb8 100644 --- a/llmpa/fileparser/mimetype.py +++ b/llmpa/fileparser/mimetype.py @@ -3,44 +3,49 @@ mime = 
magic.Magic(mime=True) -supported_mimetypes = [ - "text/plain", - "text/csv", - "text/html", - "text/xml", - "text/x-c", ## c, go ## XXX: add file extension checking - "text/x-c++", ## cpp - "text/x-java-source", ## java - "text/x-go", ## FAKE: not support by python-magic, but checking file extension - "text/x-matlab", ## matlab - "text/x-perl", ## perl - "text/x-php", ## php - "text/x-python", ## python, /etc/mime.types - "text/x-script.python", ## python, by magic - "application/javascript", ## js - "application/json", - # "application/x-python-code", ## python, /etc/mime.types, pyc pyo - "application/pdf", - "application/vnd.openxmlformats-officedocument.presentationml.presentation", ## pptx - "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", ## xlsx - "audio/mpeg", ## mp3 - "audio/x-wav", ## wav - "image/gif", ## gif - "image/jpeg", ## jpg - "image/png", ## png - "image/svg+xml", ## svg - "video/mp4", ## mp4 - "video/quicktime", ## mov - "video/x-msvideo", ## avi -] +mimetypes_names = { + "text/plain": "Plain Text", + "text/csv": "CSV (Comma Separated Values)", + "text/html": "HTML", + "text/xml": "XML", + "text/x-c": "C Source Code", # c, go # XXX: Add file extension checking + "text/x-c++": "C++ Source Code", # cpp + "text/x-java-source": "Java Source Code", # java + "text/x-go": "Go Source Code", # FAKE: Not supported by python-magic, but checking file extension + "text/x-ini": "INI Configuration", # FAKE: Not detected by python-magic, but checking file extension + "text/x-matlab": "MATLAB Source Code", # matlab + "text/x-perl": "Perl Source Code", # perl + "text/x-php": "PHP Source Code", # php + "text/x-python": "Python Source Code", # python, /etc/mime.types + "text/x-script.python": "Python Script", # python, by magic + "application/javascript": "JavaScript", # js + "application/json": "JSON", + # "application/x-python-code": "Compiled Python Code", # python, /etc/mime.types, pyc pyo + "application/pdf": "PDF", + 
"application/vnd.openxmlformats-officedocument.presentationml.presentation": "PPTX Presentation", # pptx + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": "XLSX Spreadsheet", # xlsx + "audio/mpeg": "MPEG Audio", # mp3 + "audio/x-wav": "WAVE Audio", # wav + "image/gif": "GIF Image", # gif + "image/jpeg": "JPEG Image", # jpg + "image/png": "PNG Image", # png + "image/svg+xml": "SVG Image", # svg + "video/mp4": "MP4 Video", # mp4 + "video/quicktime": "QuickTime Video", # mov + "video/x-msvideo": "AVI Video", # avi +} + +supported_mimetypes_names = mimetypes_names def detect(filepath: str, follow_symlinks: bool = True) -> str: resolved_filepath = follow_symlinks and os.path.realpath(filepath) or filepath mimetype = mime.from_file(resolved_filepath) if os.path.isfile(filepath) else None - if mimetype in supported_mimetypes: + if mimetype in supported_mimetypes_names: if mimetype == "text/x-c" and filepath.endswith(".go"): return "text/x-go" + if mimetype == "text/plain" and filepath.endswith(".ini"): + return "text/x-ini" return mimetype return None diff --git a/llmpa/fileparser/text.py b/llmpa/fileparser/text.py index 7ea58cc..a81ec90 100644 --- a/llmpa/fileparser/text.py +++ b/llmpa/fileparser/text.py @@ -9,4 +9,6 @@ def parse(self): print("Parsing text file") def tokenize(self): - print("Tokenizing text file") + print( + f"Tokenizing file: {self.file_path} ({self.file_type}): {self.file_mimetype_name}" + ) diff --git a/requirements.txt b/requirements.txt index 15a7899..bb1a407 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,3 +17,4 @@ black==24.10.0 bcrypt==4.2.0 uvicorn==0.32.0 asgiref==3.8.1 +requests[security]==2.32.3