
chore: update backends
nuffin committed Oct 20, 2024
1 parent 90b3dac commit 715321c
Showing 5 changed files with 151 additions and 168 deletions.
3 changes: 2 additions & 1 deletion llmpa/app.py
@@ -11,7 +11,8 @@
import plotly.express as px

project_root = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, project_root)
if project_root not in sys.path:
sys.path.insert(0, project_root)


logging.basicConfig(level=logging.DEBUG)
123 changes: 5 additions & 118 deletions llmpa/backends/base.py
@@ -3,125 +3,12 @@
from typing import Optional, List


class ClientBase:
def __init__(
self,
base_url: str,
api_key: Optional[str] = None,
verify_ssl: Optional[bool] = True,
timeout: Optional[int] = 10,
):
self.base_url = base_url
if self.base_url.endswith("/"):
self.base_url = self.base_url[:-1]
self.api_key = api_key
self.verify_ssl = verify_ssl
self.timeout = timeout
self.headers = {
"Authorization": f"Bearer {self.api_key}" if self.api_key else None,
"Content-Type": "application/json",
}

def _make_request(
self,
method,
endpoint,
data=None,
json=None,
params=None,
extra_headers=None,
timeout: Optional[int] = None,
):
"""
Make an HTTP request to the LocalAI server.
:param method: HTTP method (GET, POST, PUT, DELETE).
:param endpoint: The API endpoint to hit.
:param data: The request body for POST/PUT requests.
:param params: Query parameters for GET/DELETE requests.
:param extra_headers: Additional headers for the request.
:return: Parsed JSON response or None if an error occurred.
"""
url = f"{self.base_url}{endpoint}"
headers = {**self.headers, **(extra_headers or {})}

print(
f"data={data}, json={json}, params={params}, extra_headers={extra_headers}"
)
try:
response = requests.request(
method=method.upper(),
url=url,
data=data,
json=json,
params=params,
headers=headers,
verify=self.verify_ssl,
timeout=timeout or self.timeout,
)
response.raise_for_status() # Raise HTTPError for bad responses
return response
except HTTPError as http_err:
print(f"HTTP error occurred: {http_err}")
except Timeout as timeout_err:
print(f"Timeout error: {timeout_err}")
except RequestException as req_err:
print(f"Request error: {req_err}")
except Exception as err:
print(f"An error occurred: {err}")
return None

def get(
self, endpoint, params=None, extra_headers=None, timeout: Optional[int] = None
):
return self._make_request(
"GET", endpoint, params=params, extra_headers=extra_headers, timeout=timeout
)

def post(
self,
endpoint,
data=None,
json=None,
extra_headers=None,
timeout: Optional[int] = None,
):
return self._make_request(
"POST",
endpoint,
data=data,
json=json,
extra_headers=extra_headers,
timeout=timeout,
)

def put(
self,
endpoint,
data=None,
json=None,
extra_headers=None,
timeout: Optional[int] = None,
):
return self._make_request(
"PUT",
endpoint,
data=data,
json=json,
extra_headers=extra_headers,
timeout=timeout,
)
class BaseBackend:
def get_model_info(self, model: str) -> Optional[dict]:
raise NotImplementedError

def delete(
self, endpoint, params=None, extra_headers=None, timeout: Optional[int] = None
):
return self._make_request(
"DELETE",
endpoint,
params=params,
extra_headers=extra_headers,
timeout=timeout,
)
def list_available_models(self) -> Optional[list]:
raise NotImplementedError

def generate(
self,
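With the HTTP plumbing moved out to llmpa/clients/http.py, base.py is reduced to the abstract BaseBackend interface. Below is a minimal sketch of what a concrete subclass might look like — the EchoBackend name, its canned return values, and any signature details beyond what this diff shows are assumptions, and the import assumes the llmpa directory is on sys.path the way the repo's own modules arrange it:

from typing import Optional, List

from backends.base import BaseBackend  # assumes the llmpa directory is on sys.path


class EchoBackend(BaseBackend):
    # Toy backend used only to illustrate the interface; not part of this commit.

    def get_model_info(self, model: str) -> Optional[dict]:
        # Hard-coded description instead of querying a real server.
        return {"id": model, "object": "model"}

    def list_available_models(self) -> Optional[list]:
        return ["echo-1"]

    def generate(
        self,
        prompt: str,
        model: Optional[str] = None,
        max_tokens: Optional[int] = 150,
        temperature: Optional[float] = 1.0,
    ) -> Optional[str]:
        # Echo the prompt back, truncated to max_tokens characters as a stand-in.
        return prompt[: max_tokens or 150]

    def embedding(self, text: str, model: str) -> Optional[List[float]]:
        # Fixed-size dummy vector; a real backend would call an embeddings endpoint.
        return [float(len(text)), 0.0, 0.0]


if __name__ == "__main__":
    backend = EchoBackend()
    print(backend.list_available_models())
    print(backend.generate("Hello from the base interface", max_tokens=5))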
68 changes: 19 additions & 49 deletions llmpa/backends/localai.py
@@ -1,30 +1,26 @@
import os
import sys

import json
from requests.exceptions import HTTPError, Timeout, RequestException
from typing import Optional, List

from .base import ClientBase
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)

from clients.http import HttpClient
from .base import BaseBackend

class LocalAIClient(ClientBase):
def __init__(self, base_url: str, api_key=None, verify_ssl=True, timeout=10):
"""
Initializes the LocalAI client with the specified server base URL.

Args:
base_url (str): The base URL of the LocalAI server (e.g., "http://localhost:8080").
timeout (int, optional): Timeout for the HTTP requests. Default is 10 seconds.
"""
super(LocalAIClient, self).__init__(base_url, api_key, verify_ssl, timeout)
class Backend(BaseBackend, HttpClient):
def __init__(self, base_url: str, api_key=None, verify_ssl=True, timeout=10):
super(Backend, self).__init__(base_url, api_key, verify_ssl, timeout)
self.client = HttpClient(base_url, api_key, verify_ssl, timeout)

def list_available_models(self) -> Optional[list]:
"""
Retrieves a list of available models from the LocalAI server.
Returns:
list: List of available models, or None if the request fails.
"""
try:
response = self.get("/v1/models", timeout=self.timeout)
response = self.client.get("/v1/models", timeout=self.timeout)
print(response.json())
if not response:
return None
@@ -41,18 +37,6 @@ def generate(
max_tokens: Optional[int] = 150,
temperature: Optional[float] = 1.0,
) -> Optional[str]:
"""
Sends a request to the LocalAI server to generate text based on the input prompt.
Args:
model (str): The name of the model to be used for inference and embeddings.
prompt (str): The prompt or question to send to the LocalAI model.
max_tokens (int, optional): The maximum number of tokens to generate. Default is 150.
temperature (float, optional): The sampling temperature to control the randomness of output. Default is 1.0.
Returns:
str: The generated text or None if the request fails.
"""
headers = {"Content-Type": "application/json"}

payload = {
@@ -63,7 +47,7 @@ def generate(
}

try:
response = self.post(
response = self.client.post(
"/v1/chat/completions",
json=payload,
timeout=self.timeout,
@@ -85,22 +69,12 @@ def generate(
return None

def embedding(self, text: str, model: str) -> Optional[List[float]]:
"""
Sends a request to the LocalAI server to generate embeddings from the input text.
Args:
model (str): The name of the model to be used for inference and embeddings.
text (str): The input text for which to generate embeddings.
Returns:
list: A list of floats representing the embedding vector, or None if the request fails.
"""
headers = {"Content-Type": "application/json"}

payload = {"model": model, "input": text}

try:
response = self.post(
response = self.client.post(
"/embeddings",
extra_headers=headers,
json=payload,
@@ -124,27 +98,23 @@ def embedding(self, text: str, model: str) -> Optional[List[float]]:

# Example usage:
if __name__ == "__main__":
# Initialize the client
client = LocalAIClient(base_url="http://localhost:58080")
backend = Backend(base_url="http://localhost:58080")

# Example 1: Generating text
prompt = "Tell me a story about a brave knight."
generated_text = client.generate(
generated_text = backend.generate(
prompt, max_tokens=100, temperature=0.7, model="text-embedding-ada-002"
)
if generated_text:
print(f"Generated Text: {generated_text}")

# Example 3: List available models
models = client.list_available_models()
models = backend.list_available_models()
if models:
print(f"Available Models: {models}")

# Example 4: Get embeddings for input text
input_text = "Artificial intelligence is transforming industries."
embedding = client.embedding(input_text, "text-embedding-ada-002")
embedding = backend.embedding(input_text, "text-embedding-ada-002")
if embedding:
print(f"Embedding: {embedding}")

# Example 5: Update the model used by the client
# client.update_model("another-model-name")
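Because the new Backend keeps its HTTP layer in self.client (a plain HttpClient), that attribute can be swapped for a stand-in when exercising the backend without a running LocalAI server. The sketch below is illustrative only — the _FakeClient/_FakeResponse classes and the OpenAI-style payload shape are assumptions, and the import again assumes the llmpa directory is on sys.path:

from backends.localai import Backend  # assumes the llmpa directory is on sys.path


class _FakeResponse:
    # Mimics the small slice of requests.Response that the backend touches.
    def __init__(self, payload):
        self._payload = payload

    def json(self):
        return self._payload


class _FakeClient:
    # Stands in for HttpClient; only get() is needed for list_available_models().
    def get(self, endpoint, **kwargs):
        # Payload shape mirrors an OpenAI-style /v1/models response (an assumption).
        return _FakeResponse({"object": "list", "data": [{"id": "stub-model"}]})


backend = Backend(base_url="http://localhost:58080")
backend.client = _FakeClient()  # swap the HTTP layer without touching backend logic
print(backend.list_available_models())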
Empty file added llmpa/clients/__init__.py
125 changes: 125 additions & 0 deletions llmpa/clients/http.py
@@ -0,0 +1,125 @@
import requests
from requests.exceptions import HTTPError, Timeout, RequestException
from typing import Optional, List

class HttpClient:
def __init__(
self,
base_url: str,
api_key: Optional[str] = None,
verify_ssl: Optional[bool] = True,
timeout: Optional[int] = 10,
):
self.base_url = base_url
if self.base_url.endswith("/"):
self.base_url = self.base_url[:-1]
self.api_key = api_key
self.verify_ssl = verify_ssl
self.timeout = timeout
self.headers = {
"Authorization": f"Bearer {self.api_key}" if self.api_key else None,
"Content-Type": "application/json",
}

def _make_request(
self,
method,
endpoint,
data=None,
json=None,
params=None,
extra_headers=None,
timeout: Optional[int] = None,
):
"""
Make an HTTP request to the LocalAI server.
:param method: HTTP method (GET, POST, PUT, DELETE).
:param endpoint: The API endpoint to hit.
:param data: The request body for POST/PUT requests.
:param params: Query parameters for GET/DELETE requests.
:param extra_headers: Additional headers for the request.
:return: Parsed JSON response or None if an error occurred.
"""
url = f"{self.base_url}{endpoint}"
headers = {**self.headers, **(extra_headers or {})}

print(
f"data={data}, json={json}, params={params}, extra_headers={extra_headers}"
)
try:
response = requests.request(
method=method.upper(),
url=url,
data=data,
json=json,
params=params,
headers=headers,
verify=self.verify_ssl,
timeout=timeout or self.timeout,
)
response.raise_for_status() # Raise HTTPError for bad responses
return response
except HTTPError as http_err:
print(f"HTTP error occurred: {http_err}")
except Timeout as timeout_err:
print(f"Timeout error: {timeout_err}")
except RequestException as req_err:
print(f"Request error: {req_err}")
except Exception as err:
print(f"An error occurred: {err}")
return None

def get(
self, endpoint, params=None, extra_headers=None, timeout: Optional[int] = None
):
return self._make_request(
"GET", endpoint, params=params, extra_headers=extra_headers, timeout=timeout
)

def post(
self,
endpoint,
data=None,
json=None,
extra_headers=None,
timeout: Optional[int] = None,
):
return self._make_request(
"POST",
endpoint,
data=data,
json=json,
extra_headers=extra_headers,
timeout=timeout,
)

def put(
self,
endpoint,
data=None,
json=None,
extra_headers=None,
timeout: Optional[int] = None,
):
return self._make_request(
"PUT",
endpoint,
data=data,
json=json,
extra_headers=extra_headers,
timeout=timeout,
)

def delete(
self, endpoint, params=None, extra_headers=None, timeout: Optional[int] = None
):
return self._make_request(
"DELETE",
endpoint,
params=params,
extra_headers=extra_headers,
timeout=timeout,
)
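The extracted HttpClient returns the raw requests.Response on success and None on any failure (HTTP error, timeout, or other request exception), so callers only need a None check before decoding. A minimal usage sketch follows; the base URL and endpoint are placeholders, not part of this commit:

from clients.http import HttpClient  # assumes the llmpa directory is on sys.path

# Placeholder server; any JSON-over-HTTP endpoint behaves the same way here.
client = HttpClient("http://localhost:58080", api_key=None, verify_ssl=True, timeout=5)

response = client.get("/v1/models")
if response is not None:
    # A non-None response already passed raise_for_status(), so it is safe to decode.
    print(response.json())
else:
    print("request failed; the error was already printed by _make_request()")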

