
Commit

feat: localai client updated
nuffin committed Oct 18, 2024
1 parent f896635 commit 8325826
Showing 8 changed files with 201 additions and 82 deletions.
125 changes: 121 additions & 4 deletions llmpa/clients/base.py
@@ -1,22 +1,139 @@
import requests
from requests.exceptions import HTTPError, Timeout, RequestException
from typing import Optional, List


class ClientBase:
-    def __init__(self, base_url: str, timeout: Optional[int] = 10):
+    def __init__(
+        self,
+        base_url: str,
+        api_key: Optional[str] = None,
+        verify_ssl: Optional[bool] = True,
+        timeout: Optional[int] = 10,
+    ):
self.base_url = base_url
if self.base_url.endswith("/"):
self.base_url = self.base_url[:-1]
self.api_key = api_key
self.verify_ssl = verify_ssl
self.timeout = timeout
self.headers = {
"Authorization": f"Bearer {self.api_key}" if self.api_key else None,
"Content-Type": "application/json",
}

def _make_request(
self,
method,
endpoint,
data=None,
json=None,
params=None,
extra_headers=None,
timeout: Optional[int] = None,
):
"""
Make an HTTP request to the LocalAI server.
:param method: HTTP method (GET, POST, PUT, DELETE).
:param endpoint: The API endpoint to hit.
:param data: The request body for POST/PUT requests.
:param params: Query parameters for GET/DELETE requests.
:param extra_headers: Additional headers for the request.
:return: Parsed JSON response or None if an error occurred.
"""
url = f"{self.base_url}{endpoint}"
headers = {**self.headers, **(extra_headers or {})}

        # debug: log the outgoing request arguments
        print(
            f"data={data}, json={json}, params={params}, extra_headers={extra_headers}"
        )
try:
response = requests.request(
method=method.upper(),
url=url,
data=data,
json=json,
params=params,
headers=headers,
verify=self.verify_ssl,
timeout=timeout or self.timeout,
)
response.raise_for_status() # Raise HTTPError for bad responses
return response
except HTTPError as http_err:
print(f"HTTP error occurred: {http_err}")
except Timeout as timeout_err:
print(f"Timeout error: {timeout_err}")
except RequestException as req_err:
print(f"Request error: {req_err}")
except Exception as err:
print(f"An error occurred: {err}")
return None

def get(
self, endpoint, params=None, extra_headers=None, timeout: Optional[int] = None
):
return self._make_request(
"GET", endpoint, params=params, extra_headers=extra_headers, timeout=timeout
)

def post(
self,
endpoint,
data=None,
json=None,
extra_headers=None,
timeout: Optional[int] = None,
):
return self._make_request(
"POST",
endpoint,
data=data,
json=json,
extra_headers=extra_headers,
timeout=timeout,
)

def put(
self,
endpoint,
data=None,
json=None,
extra_headers=None,
timeout: Optional[int] = None,
):
return self._make_request(
"PUT",
endpoint,
data=data,
json=json,
extra_headers=extra_headers,
timeout=timeout,
)

def delete(
self, endpoint, params=None, extra_headers=None, timeout: Optional[int] = None
):
return self._make_request(
"DELETE",
endpoint,
params=params,
extra_headers=extra_headers,
timeout=timeout,
)

    def generate(
        self,
-        model: str,
        prompt: str,
+        model: Optional[str] = None,
        max_tokens: Optional[int] = 150,
        temperature: Optional[float] = 1.0,
    ) -> Optional[str]:
raise NotImplementedError

-    def embedding_text(self, model: str, text: str) -> Optional[List[float]]:
+    def embedding_text(self, text: str, model: str) -> Optional[List[float]]:
raise NotImplementedError

-    def embedding_image(self, model: str, filepath: str) -> Optional[List[float]]:
+    def embedding_file(self, filepath: str, model: str) -> Optional[List[float]]:
raise NotImplementedError
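
The reworked base class turns every concrete client into a thin wrapper around the shared get/post/put/delete helpers, with auth headers, SSL verification, and timeouts handled in one place. A minimal sketch of a subclass (the ExampleClient name and /version endpoint are hypothetical, not from this repo):

    from typing import Optional

    from llmpa.clients.base import ClientBase


    class ExampleClient(ClientBase):
        def server_version(self) -> Optional[str]:
            # _make_request returns None on any error, so only the
            # happy path needs handling here.
            response = self.get("/version")
            if not response:
                return None
            return response.json().get("version")


    client = ExampleClient("http://localhost:8080", api_key="secret", timeout=5)
    print(client.server_version())

Because _make_request swallows exceptions and returns None, callers get a single failure signal instead of repeating try/except around every call.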
74 changes: 30 additions & 44 deletions llmpa/clients/localai.py
@@ -1,37 +1,20 @@
import requests
import json
from requests.exceptions import HTTPError, Timeout, RequestException
from typing import Optional, List

from .base import ClientBase


class LocalAIClient(ClientBase):
-    def __init__(self, base_url: str, timeout: Optional[int] = 10):
+    def __init__(self, base_url: str, api_key=None, verify_ssl=True, timeout=10):
"""
Initializes the LocalAI client with the specified server base URL.
Args:
base_url (str): The base URL of the LocalAI server (e.g., "http://localhost:8080").
            api_key (str, optional): API key sent as a Bearer token, if the server requires one.
            verify_ssl (bool, optional): Whether to verify TLS certificates. Default is True.
            timeout (int, optional): Timeout for the HTTP requests. Default is 10 seconds.
"""
-        super(LocalAIClient, self).__init__(base_url, timeout)
-
-    def get_model_info(self, model: str) -> Optional[dict]:
-        """
-        Retrieves model information from the LocalAI server.
-        Returns:
-            dict: Model information as a dictionary, or None if the request fails.
-        """
-        try:
-            response = requests.get(
-                f"{self.base_url}/v1/models/{model}", timeout=self.timeout
-            )
-            response.raise_for_status()
-            return response.json()
-        except requests.exceptions.RequestException as e:
-            print(f"Error retrieving model information: {e}")
-            return None
+        super(LocalAIClient, self).__init__(base_url, api_key, verify_ssl, timeout)

def list_available_models(self) -> Optional[list]:
"""
@@ -41,17 +24,20 @@ def list_available_models(self) -> Optional[list]:
list: List of available models, or None if the request fails.
"""
        try:
-            response = requests.get(f"{self.base_url}/v1/models", timeout=self.timeout)
+            response = self.get("/v1/models", timeout=self.timeout)
+            if not response:
+                return None
+            print(response.json())  # debug: raw /v1/models payload
            response.raise_for_status()
-            return response.json().get("models", [])
-        except requests.exceptions.RequestException as e:
+            return response.json().get("data", [])
+        except RequestException as e:
print(f"Error retrieving list of models: {e}")
return None

def generate(
self,
model: str,
prompt: str,
model: Optional[str] = None,
max_tokens: Optional[int] = 150,
temperature: Optional[float] = 1.0,
) -> Optional[str]:
@@ -77,12 +63,13 @@ def generate(
}

try:
-            response = requests.post(
-                f"{self.base_url}/v1/generate",
-                headers=headers,
-                data=json.dumps(payload),
+            response = self.post(
+                "/v1/chat/completions",
+                json=payload,
timeout=self.timeout,
)
if not response:
return None
response.raise_for_status()

# Extract and return the generated text from the response
@@ -93,11 +80,11 @@ def generate(
print("No valid response received from the model.")
return None

-        except requests.exceptions.RequestException as e:
+        except RequestException as e:
print(f"Error during the request to LocalAI: {e}")
return None

-    def embedding(self, model: str, text: str) -> Optional[List[float]]:
+    def embedding(self, text: str, model: str) -> Optional[List[float]]:
"""
Sends a request to the LocalAI server to generate embeddings from the input text.
@@ -113,12 +100,14 @@ def embedding(self, model: str, text: str) -> Optional[List[float]]:
payload = {"model": model, "input": text}

try:
-            response = requests.post(
-                f"{self.base_url}/v1/embeddings",
-                headers=headers,
-                data=json.dumps(payload),
+            response = self.post(
+                "/embeddings",
+                extra_headers=headers,
+                json=payload,
timeout=self.timeout,
)
if not response:
return None
response.raise_for_status()

result = response.json()
@@ -128,37 +117,34 @@ def embedding(self, model: str, text: str) -> Optional[List[float]]:
print("No valid embedding received from the model.")
return None

-        except requests.exceptions.RequestException as e:
+        except RequestException as e:
print(f"Error during the request to LocalAI for embedding: {e}")
return None


# Example usage:
if __name__ == "__main__":
# Initialize the client
client = LocalAIClient(base_url="http://localhost:8080", model="your-model-name")
client = LocalAIClient(base_url="http://localhost:58080")

# Example 1: Generating text
prompt = "Tell me a story about a brave knight."
-    generated_text = client.generate(prompt, max_tokens=100, temperature=0.7)
+    generated_text = client.generate(
+        prompt, max_tokens=100, temperature=0.7, model="text-embedding-ada-002"
+    )
if generated_text:
print(f"Generated Text: {generated_text}")

-    # Example 2: Get model information
-    model_info = client.get_model_info()
-    if model_info:
-        print(f"Model Information: {json.dumps(model_info, indent=2)}")

# Example 3: List available models
models = client.list_available_models()
if models:
print(f"Available Models: {models}")

# Example 4: Get embeddings for input text
input_text = "Artificial intelligence is transforming industries."
-    embedding = client.embedding(input_text)
+    embedding = client.embedding(input_text, "text-embedding-ada-002")
if embedding:
print(f"Embedding: {embedding}")

# Example 5: Update the model used by the client
-    client.update_model("another-model-name")
+    # client.update_model("another-model-name")
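
With the client migrated to the OpenAI-compatible routes (/v1/models, /v1/chat/completions), a typical round trip looks like the sketch below. It assumes a LocalAI server on localhost:8080; the model name is a placeholder for whatever the server actually reports:

    from llmpa.clients.localai import LocalAIClient

    client = LocalAIClient(base_url="http://localhost:8080")

    # /v1/models answers in the OpenAI schema {"object": "list", "data": [...]},
    # which is why list_available_models() now reads the "data" key.
    models = client.list_available_models()
    print(models)

    text = client.generate("Say hello.", model="llama-3.2-1b-instruct", max_tokens=32)
    if text:
        print(text)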
1 change: 1 addition & 0 deletions llmpa/fileparser/__init__.py
@@ -20,6 +20,7 @@
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": xlsx.XlsxFileParser,
"text/csv": csv.CsvFileParser,
"text/plain": text.TextFileParser,
"text/x-ini": text.TextFileParser,
"image/gif": image.ImageFileParser,
"image/jpeg": image.ImageFileParser,
"image/png": image.ImageFileParser,
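
The new "text/x-ini" entry routes .ini files through the plain-text parser. A hedged sketch of the dispatch this table enables; the table's real name is not visible in this diff, so a stand-in PARSERS dict is declared inline:

    from llmpa.fileparser import text
    from llmpa.fileparser.mimetype import detect

    PARSERS = {
        "text/plain": text.TextFileParser,
        "text/x-ini": text.TextFileParser,  # mapping added in this commit
    }

    def parser_for(filepath: str):
        mimetype = detect(filepath)
        if mimetype not in PARSERS:
            raise ValueError(f"no parser registered for {mimetype}")
        return PARSERS[mimetype](filepath)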
2 changes: 1 addition & 1 deletion llmpa/fileparser/__main__.py
@@ -5,7 +5,7 @@ def main():
import sys

if len(sys.argv) < 2:
print("Usage: document.py <filepath>")
print("Usage: fileparser <filepath>")
sys.exit(1)
filepath = sys.argv[1]
tokenize(filepath)
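
With the corrected usage string, invocation matches how the module is actually run (assuming the llmpa package is importable):

    python -m llmpa.fileparser path/to/document.txt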
9 changes: 8 additions & 1 deletion llmpa/fileparser/base.py
@@ -1,10 +1,17 @@
from .mimetype import detect, mimetypes_names


class BaseFileParser:
def __init__(self, file_path):
self.file_path = file_path
-        self.file_type = None
+        self.file_type = detect(file_path)
+        self.file_mimetype_name = mimetypes_names[self.file_type]

def parse(self):
raise NotImplementedError

def tokenize(self):
raise NotImplementedError

def prompt_for_tokenizing(self):
return f"this is a {self.file_mimetype_name} file, tokenize it, and return the embeddings"
