Skip to content

Commit

Permalink
chore: formatted with black
Browse files Browse the repository at this point in the history
  • Loading branch information
broomva committed May 14, 2024
1 parent 2802548 commit d550e66
Show file tree
Hide file tree
Showing 10 changed files with 159 additions and 102 deletions.
103 changes: 71 additions & 32 deletions arcan/ai/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,43 +8,60 @@
import pickle
import weakref
from datetime import datetime

# Ensure necessary imports for ArcanAgent
from tempfile import TemporaryDirectory
from typing import Any, AsyncIterator, Dict, List, Optional, cast

from fastapi import Depends
from fastapi.responses import StreamingResponse
from langchain.agents import (AgentExecutor, AgentType,
create_tool_calling_agent, initialize_agent,
load_tools)
from langchain.agents import (
AgentExecutor,
AgentType,
create_tool_calling_agent,
initialize_agent,
load_tools,
)
from langchain.agents.agent_types import AgentType
from langchain.agents.format_scratchpad.openai_tools import \
format_to_openai_tool_messages
from langchain.agents.format_scratchpad.openai_tools import (
format_to_openai_tool_messages,
)
from langchain.agents.format_scratchpad.tools import format_to_tool_messages
from langchain.agents.output_parsers.tools import ToolsAgentOutputParser
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.pydantic_v1 import BaseModel
from langchain.sql_database import SQLDatabase
from langchain_community.agent_toolkits import (FileManagementToolkit,
SQLDatabaseToolkit)
from langchain_community.agent_toolkits import FileManagementToolkit, SQLDatabaseToolkit
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.load.serializable import (Serializable,
SerializedConstructor,
SerializedNotImplemented)
from langchain_core.load.serializable import (
Serializable,
SerializedConstructor,
SerializedNotImplemented,
)
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate

# from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import (ConfigurableField, ConfigurableFieldSpec,
Runnable, RunnableConfig,
RunnablePassthrough,
RunnableSerializable)
from langchain_core.runnables import (
ConfigurableField,
ConfigurableFieldSpec,
Runnable,
RunnableConfig,
RunnablePassthrough,
RunnableSerializable,
)
from langchain_core.runnables.base import Runnable, RunnableBindingBase
from langchain_core.runnables.utils import (AddableDict, AnyConfigurableField,
ConfigurableField,
ConfigurableFieldSpec, Input,
Output, create_model,
get_unique_config_specs)
from langchain_core.runnables.utils import (
AddableDict,
AnyConfigurableField,
ConfigurableField,
ConfigurableFieldSpec,
Input,
Output,
create_model,
get_unique_config_specs,
)
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from pydantic import BaseModel, Field
from sqlalchemy.dialects.postgresql import insert
Expand All @@ -56,8 +73,10 @@
from arcan.ai.llm import LLM
from arcan.ai.parser import ArcanOutputParser
from arcan.ai.prompts import arcan_prompt, spells_agent_prompt

# from arcan.ai.router import semantic_layer
from arcan.ai.tools import tools as spells

# from arcan.api.session import ArcanSession
# from arcan.ai.agents import ArcanAgent
from arcan.datamodel.chat_history import ChatHistory
Expand All @@ -73,7 +92,7 @@ class ArcanAgent(RunnableSerializable):
chat_history: List = Field(default_factory=list)
user_id: Optional[str] = None
verbose: bool = False
prompt: ChatPromptTemplate = spells_agent_prompt,
prompt: ChatPromptTemplate = (spells_agent_prompt,)
configs: List[ConfigurableFieldSpec] = Field(default_factory=list)
llm_with_tools: LLM = Field(default_factory=lambda: LLM().llm)
agent: Runnable = Field(default_factory=RunnablePassthrough)
Expand All @@ -82,22 +101,38 @@ class ArcanAgent(RunnableSerializable):

class Config:
arbitrary_types_allowed = True
extra = 'allow' # This allows additional fields not explicitly defined

def __init__(self, llm=None, tools: list = spells, prompt: ChatPromptTemplate = spells_agent_prompt,
agent_type="arcan_spells_agent", chat_history: list = [],
user_id: str = None, verbose: bool = False, configs: list = [],
**kwargs):
super().__init__(tools=tools, agent_type=agent_type, chat_history=chat_history,
user_id=user_id, verbose=verbose, prompt=prompt, configs=configs, **kwargs)
object.__setattr__(self, '_llm', llm or LLM().llm)
extra = "allow" # This allows additional fields not explicitly defined

def __init__(
self,
llm=None,
tools: list = spells,
prompt: ChatPromptTemplate = spells_agent_prompt,
agent_type="arcan_spells_agent",
chat_history: list = [],
user_id: str = None,
verbose: bool = False,
configs: list = [],
**kwargs,
):
super().__init__(
tools=tools,
agent_type=agent_type,
chat_history=chat_history,
user_id=user_id,
verbose=verbose,
prompt=prompt,
configs=configs,
**kwargs,
)
object.__setattr__(self, "_llm", llm or LLM().llm)
# Initialize other fields after the main Pydantic initialization
self.session: ArcanSession = ArcanSession()
self.bare_tools = load_tools(["llm-math"], llm=self.llm)
self.agent_tools = self.tools + self.bare_tools
self.llm_with_tools = self.llm.bind_tools(self.agent_tools)
self.agent, self.runnable = self.get_or_create_agent(self.user_id)

@property
def llm(self):
return self._llm
Expand Down Expand Up @@ -238,8 +273,12 @@ def invoke(
]
)
try:
self.session.store_message(user_id=self.user_id, body=user_content, response=response['output'])
self.session.store_chat_history(user_id=self.user_id, agent_history=self.chat_history)
self.session.store_message(
user_id=self.user_id, body=user_content, response=response["output"]
)
self.session.store_chat_history(
user_id=self.user_id, agent_history=self.chat_history
)
except SQLAlchemyError as e:
self.session.database.rollback()
print(f"Error storing conversation in database: {e}")
Expand Down
30 changes: 15 additions & 15 deletions arcan/ai/agents/session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#%%
# %%

import ast
import os
Expand Down Expand Up @@ -59,7 +59,7 @@ def get_chat_history(self, user_id: str) -> list:
with self._get_session() as db_session:
history = (
db_session.query(ChatHistory)
# .options(joinedload(ChatHistory.history))
# .options(joinedload(ChatHistory.history))
.filter(ChatHistory.sender == user_id)
.order_by(ChatHistory.updated_at.asc())
.all()
Expand Down Expand Up @@ -87,19 +87,19 @@ def rollback(self):
# self.database_uri = os.environ.get("SQLALCHEMY_URL")
# self.agents: Dict[str, weakref.ref] = weakref.WeakValueDictionary()

# def store_message(self, user_id: str, body: str, response: str):
# """
# Stores a message in the database.

# :param user_id: The unique identifier for the user.
# :param Body: The body of the message sent by the user.
# :param response: The response generated by the system.
# """
# with self.database as db_session:
# conversation = Conversation(sender=user_id, message=body, response=response)
# db_session.add(conversation)
# db_session.commit()
# print(f"Conversation #{conversation.id} stored in database")
# def store_message(self, user_id: str, body: str, response: str):
# """
# Stores a message in the database.

# :param user_id: The unique identifier for the user.
# :param Body: The body of the message sent by the user.
# :param response: The response generated by the system.
# """
# with self.database as db_session:
# conversation = Conversation(sender=user_id, message=body, response=response)
# db_session.add(conversation)
# db_session.commit()
# print(f"Conversation #{conversation.id} stored in database")

# def store_chat_history(self, user_id, agent_history):
# """
Expand Down
6 changes: 3 additions & 3 deletions arcan/ai/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ class LLMFactory:
os.getenv("TOGETHER_MODEL_NAME", "llama3-8b-8192"),
),
),
'ChatOllama' : lambda **kwargs: ChatOllama(
model = kwargs.get("model", os.getenv("OLLAMA_MODEL", "phi3")),
)
"ChatOllama": lambda **kwargs: ChatOllama(
model=kwargs.get("model", os.getenv("OLLAMA_MODEL", "phi3")),
),
}

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion arcan/ai/router/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_response(self, query: str, user_id: str) -> str:
return route_text, query
else:
print(f"No route found for query: {query}")
return 'No Router Matched', query
return "No Router Matched", query


# Initialize RouteManager with an encoder
Expand Down
6 changes: 3 additions & 3 deletions arcan/ai/runnables/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def get_runnable(self, runnable_name: str, cache: bool = True) -> RemoteRunnable
class ArcanRunnables:
def __init__(self, base_url: str = "http://localhost:8000/"):
self.factory = RunnableFactory(base_url=base_url)

def get_spells_runnable(self) -> AgentExecutor:
return self.factory.get_runnable(runnable_name="spells")

Expand All @@ -39,9 +39,9 @@ def get_ollama_runnable(self) -> AgentExecutor:

def get_auth_spells_runnable(self) -> AgentExecutor:
return self.factory.get_runnable(runnable_name="auth_spells")

def get_chain_with_history_runnable(self) -> AgentExecutor:
return self.factory.get_runnable(runnable_name="chain_with_history")


#%%
# %%
Loading

0 comments on commit d550e66

Please sign in to comment.