Skip to content

Commit

Permalink
feat: Enable usage token and use Converse API with all Bedrock models
Browse files Browse the repository at this point in the history
  • Loading branch information
charles-marion authored Sep 19, 2024
1 parent 08f899b commit e0d310e
Show file tree
Hide file tree
Showing 19 changed files with 453 additions and 935 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,18 @@ def on_llm_end(

class ModelAdapter:
def __init__(
    self,
    session_id,
    user_id,
    mode=ChatbotMode.CHAIN.value,
    disable_streaming=False,
    model_kwargs=None,
):
    """Bind a chat session to an LLM adapter.

    Args:
        session_id: Identifier of the chat session this adapter serves.
        user_id: Identifier of the user owning the session.
        mode: Chat mode; defaults to ``ChatbotMode.CHAIN.value``.
        disable_streaming: When True, streaming is suppressed even if
            ``model_kwargs`` requests it (checked by the run methods).
        model_kwargs: Optional dict of model parameters; treated as an
            empty dict when omitted.
    """
    self.session_id = session_id
    self.user_id = user_id
    self._mode = mode
    # Fix: the original used a mutable default (`model_kwargs={}`), which
    # is created once at definition time and shared by every instance that
    # omits the argument. Use a None sentinel instead.
    self.model_kwargs = model_kwargs if model_kwargs is not None else {}
    self.disable_streaming = disable_streaming

    self.callback_handler = LLMStartHandler()
    self.__bind_callbacks()
Expand Down Expand Up @@ -176,12 +182,12 @@ def run_with_chain_v2(self, user_prompt, workspace_id=None):

config = {"configurable": {"session_id": self.session_id}}
try:
if self.model_kwargs.get("streaming", False):
if not self.disable_streaming and self.model_kwargs.get("streaming", False):
answer = ""
for chunk in conversation.stream(
input={"input": user_prompt}, config=config
):
logger.info("chunk", chunk=chunk)
logger.debug("chunk", chunk=chunk)
if "answer" in chunk:
answer = answer + chunk["answer"]
elif isinstance(chunk, AIMessageChunk):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,2 @@
# flake8: noqa
from .claude import *
from .titan import *
from .ai21_j2 import *
from .cohere import *
from .llama2_chat import *
from .mistral import *
from .llama3 import *
from .base import *

This file was deleted.

Loading

0 comments on commit e0d310e

Please sign in to comment.