forked from continuum-llms/chatgpt-memory
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rest_api.py
48 lines (33 loc) · 1.79 KB
/
rest_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from typing import Optional
from fastapi import FastAPI
from pydantic import BaseModel
from chatgpt_memory.datastore import RedisDataStore, RedisDataStoreConfig
from chatgpt_memory.environment import OPENAI_API_KEY, REDIS_HOST, REDIS_PASSWORD, REDIS_PORT
from chatgpt_memory.llm_client import ChatGPTClient, ChatGPTConfig, ChatGPTResponse, EmbeddingClient, EmbeddingConfig
from chatgpt_memory.memory import MemoryManager
# Instantiate an EmbeddingConfig object with the OpenAI API key
embedding_config = EmbeddingConfig(api_key=OPENAI_API_KEY)
# Instantiate an EmbeddingClient object with the EmbeddingConfig object
embed_client = EmbeddingClient(config=embedding_config)
# Instantiate a RedisDataStoreConfig object with the Redis connection details
redis_datastore_config = RedisDataStoreConfig(
host=REDIS_HOST,
port=REDIS_PORT,
password=REDIS_PASSWORD,
)
# Instantiate a RedisDataStore object with the RedisDataStoreConfig object
redis_datastore = RedisDataStore(config=redis_datastore_config)
# Instantiate a MemoryManager object with the RedisDataStore object and EmbeddingClient object
memory_manager = MemoryManager(datastore=redis_datastore, embed_client=embed_client, topk=1)
# Instantiate a ChatGPTConfig object with the OpenAI API key and verbose set to True
chat_gpt_config = ChatGPTConfig(api_key=OPENAI_API_KEY, verbose=False)
# Instantiate a ChatGPTClient object with the ChatGPTConfig object and MemoryManager object
chat_gpt_client = ChatGPTClient(config=chat_gpt_config, memory_manager=memory_manager)
class MessagePayload(BaseModel):
conversation_id: Optional[str]
message: str
app = FastAPI()
@app.post("/converse/")
async def converse(message_payload: MessagePayload) -> ChatGPTResponse:
response = chat_gpt_client.converse(**message_payload.dict())
return response