diff --git a/arcan/ai/agents/__init__.py b/arcan/ai/agents/__init__.py index 47f1456..94cad89 100644 --- a/arcan/ai/agents/__init__.py +++ b/arcan/ai/agents/__init__.py @@ -8,43 +8,60 @@ import pickle import weakref from datetime import datetime + # Ensure necessary imports for ArcanAgent from tempfile import TemporaryDirectory from typing import Any, AsyncIterator, Dict, List, Optional, cast from fastapi import Depends from fastapi.responses import StreamingResponse -from langchain.agents import (AgentExecutor, AgentType, - create_tool_calling_agent, initialize_agent, - load_tools) +from langchain.agents import ( + AgentExecutor, + AgentType, + create_tool_calling_agent, + initialize_agent, + load_tools, +) from langchain.agents.agent_types import AgentType -from langchain.agents.format_scratchpad.openai_tools import \ - format_to_openai_tool_messages +from langchain.agents.format_scratchpad.openai_tools import ( + format_to_openai_tool_messages, +) from langchain.agents.format_scratchpad.tools import format_to_tool_messages from langchain.agents.output_parsers.tools import ToolsAgentOutputParser from langchain.embeddings.openai import OpenAIEmbeddings from langchain.memory import ConversationBufferMemory from langchain.pydantic_v1 import BaseModel from langchain.sql_database import SQLDatabase -from langchain_community.agent_toolkits import (FileManagementToolkit, - SQLDatabaseToolkit) +from langchain_community.agent_toolkits import FileManagementToolkit, SQLDatabaseToolkit from langchain_core.callbacks import CallbackManagerForChainRun -from langchain_core.load.serializable import (Serializable, - SerializedConstructor, - SerializedNotImplemented) +from langchain_core.load.serializable import ( + Serializable, + SerializedConstructor, + SerializedNotImplemented, +) from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.prompts import ChatPromptTemplate + # from langchain_core.pydantic_v1 import BaseModel -from langchain_core.runnables import (ConfigurableField, ConfigurableFieldSpec, - Runnable, RunnableConfig, - RunnablePassthrough, - RunnableSerializable) +from langchain_core.runnables import ( + ConfigurableField, + ConfigurableFieldSpec, + Runnable, + RunnableConfig, + RunnablePassthrough, + RunnableSerializable, +) from langchain_core.runnables.base import Runnable, RunnableBindingBase -from langchain_core.runnables.utils import (AddableDict, AnyConfigurableField, - ConfigurableField, - ConfigurableFieldSpec, Input, - Output, create_model, - get_unique_config_specs) +from langchain_core.runnables.utils import ( + AddableDict, + AnyConfigurableField, + ConfigurableField, + ConfigurableFieldSpec, + Input, + Output, + create_model, + get_unique_config_specs, +) from langchain_openai import ChatOpenAI, OpenAIEmbeddings from pydantic import BaseModel, Field from sqlalchemy.dialects.postgresql import insert @@ -56,8 +73,10 @@ from arcan.ai.llm import LLM from arcan.ai.parser import ArcanOutputParser from arcan.ai.prompts import arcan_prompt, spells_agent_prompt + # from arcan.ai.router import semantic_layer from arcan.ai.tools import tools as spells + # from arcan.api.session import ArcanSession # from arcan.ai.agents import ArcanAgent from arcan.datamodel.chat_history import ChatHistory @@ -73,7 +92,7 @@ class ArcanAgent(RunnableSerializable): chat_history: List = Field(default_factory=list) user_id: Optional[str] = None verbose: bool = False - prompt: ChatPromptTemplate = spells_agent_prompt, + prompt: ChatPromptTemplate = (spells_agent_prompt,) configs: List[ConfigurableFieldSpec] = Field(default_factory=list) llm_with_tools: LLM = Field(default_factory=lambda: LLM().llm) agent: Runnable = Field(default_factory=RunnablePassthrough) @@ -82,22 +101,38 @@ class ArcanAgent(RunnableSerializable): class Config: arbitrary_types_allowed = True - extra = 'allow' # This allows additional fields not explicitly defined - - def __init__(self, llm=None, tools: list = spells, prompt: ChatPromptTemplate = spells_agent_prompt, - agent_type="arcan_spells_agent", chat_history: list = [], - user_id: str = None, verbose: bool = False, configs: list = [], - **kwargs): - super().__init__(tools=tools, agent_type=agent_type, chat_history=chat_history, - user_id=user_id, verbose=verbose, prompt=prompt, configs=configs, **kwargs) - object.__setattr__(self, '_llm', llm or LLM().llm) + extra = "allow" # This allows additional fields not explicitly defined + + def __init__( + self, + llm=None, + tools: list = spells, + prompt: ChatPromptTemplate = spells_agent_prompt, + agent_type="arcan_spells_agent", + chat_history: list = [], + user_id: str = None, + verbose: bool = False, + configs: list = [], + **kwargs, + ): + super().__init__( + tools=tools, + agent_type=agent_type, + chat_history=chat_history, + user_id=user_id, + verbose=verbose, + prompt=prompt, + configs=configs, + **kwargs, + ) + object.__setattr__(self, "_llm", llm or LLM().llm) # Initialize other fields after the main Pydantic initialization self.session: ArcanSession = ArcanSession() self.bare_tools = load_tools(["llm-math"], llm=self.llm) self.agent_tools = self.tools + self.bare_tools self.llm_with_tools = self.llm.bind_tools(self.agent_tools) self.agent, self.runnable = self.get_or_create_agent(self.user_id) - + @property def llm(self): return self._llm @@ -238,8 +273,12 @@ def invoke( ] ) try: - self.session.store_message(user_id=self.user_id, body=user_content, response=response['output']) - self.session.store_chat_history(user_id=self.user_id, agent_history=self.chat_history) + self.session.store_message( + user_id=self.user_id, body=user_content, response=response["output"] + ) + self.session.store_chat_history( + user_id=self.user_id, agent_history=self.chat_history + ) except SQLAlchemyError as e: self.session.database.rollback() print(f"Error storing conversation in database: {e}") diff --git a/arcan/ai/agents/session.py b/arcan/ai/agents/session.py index 79022a1..7e25fec 100644 --- a/arcan/ai/agents/session.py +++ b/arcan/ai/agents/session.py @@ -1,4 +1,4 @@ -#%% +# %% import ast import os @@ -59,7 +59,7 @@ def get_chat_history(self, user_id: str) -> list: with self._get_session() as db_session: history = ( db_session.query(ChatHistory) - # .options(joinedload(ChatHistory.history)) + # .options(joinedload(ChatHistory.history)) .filter(ChatHistory.sender == user_id) .order_by(ChatHistory.updated_at.asc()) .all() @@ -87,19 +87,19 @@ def rollback(self): # self.database_uri = os.environ.get("SQLALCHEMY_URL") # self.agents: Dict[str, weakref.ref] = weakref.WeakValueDictionary() - # def store_message(self, user_id: str, body: str, response: str): - # """ - # Stores a message in the database. - - # :param user_id: The unique identifier for the user. - # :param Body: The body of the message sent by the user. - # :param response: The response generated by the system. - # """ - # with self.database as db_session: - # conversation = Conversation(sender=user_id, message=body, response=response) - # db_session.add(conversation) - # db_session.commit() - # print(f"Conversation #{conversation.id} stored in database") +# def store_message(self, user_id: str, body: str, response: str): +# """ +# Stores a message in the database. + +# :param user_id: The unique identifier for the user. +# :param Body: The body of the message sent by the user. +# :param response: The response generated by the system. +# """ +# with self.database as db_session: +# conversation = Conversation(sender=user_id, message=body, response=response) +# db_session.add(conversation) +# db_session.commit() +# print(f"Conversation #{conversation.id} stored in database") # def store_chat_history(self, user_id, agent_history): # """ diff --git a/arcan/ai/llm/__init__.py b/arcan/ai/llm/__init__.py index bebd726..6ed71d3 100644 --- a/arcan/ai/llm/__init__.py +++ b/arcan/ai/llm/__init__.py @@ -70,9 +70,9 @@ class LLMFactory: os.getenv("TOGETHER_MODEL_NAME", "llama3-8b-8192"), ), ), - 'ChatOllama' : lambda **kwargs: ChatOllama( - model = kwargs.get("model", os.getenv("OLLAMA_MODEL", "phi3")), - ) + "ChatOllama": lambda **kwargs: ChatOllama( + model=kwargs.get("model", os.getenv("OLLAMA_MODEL", "phi3")), + ), } @staticmethod diff --git a/arcan/ai/router/__init__.py b/arcan/ai/router/__init__.py index c3dac79..2e26396 100644 --- a/arcan/ai/router/__init__.py +++ b/arcan/ai/router/__init__.py @@ -68,7 +68,7 @@ def get_response(self, query: str, user_id: str) -> str: return route_text, query else: print(f"No route found for query: {query}") - return 'No Router Matched', query + return "No Router Matched", query # Initialize RouteManager with an encoder diff --git a/arcan/ai/runnables/__init__.py b/arcan/ai/runnables/__init__.py index 8e643a0..1c90fcb 100644 --- a/arcan/ai/runnables/__init__.py +++ b/arcan/ai/runnables/__init__.py @@ -24,7 +24,7 @@ def get_runnable(self, runnable_name: str, cache: bool = True) -> RemoteRunnable class ArcanRunnables: def __init__(self, base_url: str = "http://localhost:8000/"): self.factory = RunnableFactory(base_url=base_url) - + def get_spells_runnable(self) -> AgentExecutor: return self.factory.get_runnable(runnable_name="spells") @@ -39,9 +39,9 @@ def get_ollama_runnable(self) -> AgentExecutor: def get_auth_spells_runnable(self) -> AgentExecutor: return self.factory.get_runnable(runnable_name="auth_spells") - + def get_chain_with_history_runnable(self) -> AgentExecutor: return self.factory.get_runnable(runnable_name="chain_with_history") -#%% \ No newline at end of file +# %% diff --git a/arcan/api/__init__.py b/arcan/api/__init__.py index 228c647..b07467d 100644 --- a/arcan/api/__init__.py +++ b/arcan/api/__init__.py @@ -6,13 +6,17 @@ from typing import Annotated, Any, Callable, Dict, List, Optional, Union from dotenv import load_dotenv -from fastapi import (Depends, FastAPI, Form, Header, HTTPException, Request, - status) +from fastapi import Depends, FastAPI, Form, Header, HTTPException, Request, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse + # %% -from fastapi.security import (HTTPAuthorizationCredentials, HTTPBearer, - OAuth2PasswordBearer, OAuth2PasswordRequestForm) +from fastapi.security import ( + HTTPAuthorizationCredentials, + HTTPBearer, + OAuth2PasswordBearer, + OAuth2PasswordRequestForm, +) from langchain_community.chat_message_histories import FileChatMessageHistory from langchain_core import __version__ from langchain_core.chat_history import BaseChatMessageHistory @@ -33,14 +37,20 @@ from arcan.ai.llm import LLM from arcan.api.auth import fetch_session_from_header from arcan.datamodel.engine import session_scope # , session_scope_context -from arcan.datamodel.user import (ACCESS_TOKEN_EXPIRE_MINUTES, TokenModel, - UserModel, UserRepository, UserService, - oauth2_scheme, pwd_context) +from arcan.datamodel.user import ( + ACCESS_TOKEN_EXPIRE_MINUTES, + TokenModel, + UserModel, + UserRepository, + UserService, + oauth2_scheme, + pwd_context, +) # from arcan.spells.vector_search import (get_per_user_retriever, # per_req_config_modifier, pgVectorStore) -#%% +# %% MIN_VERSION_LANGCHAIN_CORE = (0, 1, 0) # Split the version string by "." and convert to integers @@ -53,7 +63,7 @@ ) -#%% +# %% auth_scheme = HTTPBearer() load_dotenv() @@ -97,12 +107,12 @@ async def chat( # from arcan.api.session import ArcanSession, run_agent agent = ArcanAgent(user_id=user_id) # user = await get_current_active_user_from_request(request=Request) - response = agent.invoke({'input':query}) + response = agent.invoke({"input": query}) elif ENVIRONMENT == "local": agent = ArcanAgent( - user_id=user_id, - ) - response = agent.invoke({'input': query, 'chat_history': []}) + user_id=user_id, + ) + response = agent.invoke({"input": query, "chat_history": []}) return {"response": response} @@ -114,12 +124,17 @@ class Output(BaseModel): output: Any -dynamic_spells_model = ArcanAgent().configurable_fields( - user_id=ConfigurableField( - id="user_id", - name="Arcan AI User ID", - description=("user_id Key for Arcan AI interactions"), - )).with_types(input_type=Input, output_type=Output) +dynamic_spells_model = ( + ArcanAgent() + .configurable_fields( + user_id=ConfigurableField( + id="user_id", + name="Arcan AI User ID", + description=("user_id Key for Arcan AI interactions"), + ) + ) + .with_types(input_type=Input, output_type=Output) +) add_routes( app=app, @@ -200,7 +215,7 @@ async def get_current_active_user_from_request( headers={"WWW-Authenticate": "Bearer"}, ) # if user.disabled: - # raise HTTPException(status_code=400, detail="Inactive user") + # raise HTTPException(status_code=400, detail="Inactive user") return user @@ -264,7 +279,6 @@ async def get_current_active_user_from_request( # return get_chat_history - # def _per_request_config_modifier( # config: Dict[str, Any], request: Request # ) -> Dict[str, Any]: @@ -330,8 +344,6 @@ async def get_current_active_user_from_request( # ).with_types(input_type=InputChat) - - # add_routes( # app, # chain_with_history, @@ -348,7 +360,6 @@ async def get_current_active_user_from_request( # ) - # def _per_request_session_modifier( # config: Dict[str, Any], request: Request # ) -> Dict[str, Any]: @@ -363,7 +374,7 @@ async def get_current_active_user_from_request( # status_code=400, # detail="No user id found. Please set a cookie named 'user_id'.", # ) - + # agent = ArcanAgent(user_id=user_id) # configurable["user_id"] = user_id @@ -385,7 +396,7 @@ async def get_current_active_user_from_request( # disabled_endpoints=["playground", "batch"], # ) -#%% +# %% if __name__ == "__main__": import uvicorn diff --git a/arcan/api/auth.py b/arcan/api/auth.py index 0c01ee3..bcffc8c 100644 --- a/arcan/api/auth.py +++ b/arcan/api/auth.py @@ -19,7 +19,7 @@ def fetch_session_from_header(config: Dict[str, Any], req: Request) -> Dict[str, Any]: config = config.copy() configurable = config.get("configurable", {}) - + if "arcanai_api_key" in req.headers: if "user_id" in req.headers: configurable["user_id"] = req.headers["user_id"] @@ -31,10 +31,8 @@ def fetch_session_from_header(config: Dict[str, Any], req: Request) -> Dict[str, return config - - def _is_valid_identifier(value: str) -> bool: """Check if the value is a valid identifier.""" # Use a regular expression to match the allowed characters valid_characters = re.compile(r"^[a-zA-Z0-9-_]+$") - return bool(valid_characters.match(value)) \ No newline at end of file + return bool(valid_characters.match(value)) diff --git a/arcan/datamodel/engine.py b/arcan/datamodel/engine.py index 7b5e089..0aed7e8 100644 --- a/arcan/datamodel/engine.py +++ b/arcan/datamodel/engine.py @@ -8,6 +8,7 @@ load_dotenv() + class Config: DATABASE_URL = os.getenv("SQLALCHEMY_URL") ENVIRONMENT = os.getenv("ENVIRONMENT") @@ -15,33 +16,36 @@ class Config: class EngineFactory: def __init__(self): - self.engines = { - 'local': self.local_engine, - 'cloud': self.cloud_engine - } - + self.engines = {"local": self.local_engine, "cloud": self.cloud_engine} + def get_engine(self): # Fetch the appropriate engine creation method from the dictionary - engine_type = Config.ENVIRONMENT or 'cloud' # Default to 'cloud' if not specified - engine_creator = self.engines.get(engine_type, self.cloud_engine) # Fallback to cloud engine + engine_type = ( + Config.ENVIRONMENT or "cloud" + ) # Default to 'cloud' if not specified + engine_creator = self.engines.get( + engine_type, self.cloud_engine + ) # Fallback to cloud engine return engine_creator() - + def local_engine(self): - """ Create a local SQLite engine """ - return create_engine('sqlite:////arcan.db') - + """Create a local SQLite engine""" + return create_engine("sqlite:////arcan.db") + def cloud_engine(self): - """ Create a cloud engine from a URL in the config """ + """Create a cloud engine from a URL in the config""" if not Config.DATABASE_URL: raise ValueError("No database URL provided for cloud environment.") return create_engine(Config.DATABASE_URL) + factory = EngineFactory() engine = factory.get_engine() SessionLocal = sessionmaker(bind=engine) Base = declarative_base() + @contextmanager def session_scope(): """Provide a transactional scope around a series of operations.""" @@ -69,7 +73,7 @@ def session_scope(): # yield db # finally: # db.close() - + # @contextmanager # def get_db_context(): diff --git a/arcan/datamodel/user.py b/arcan/datamodel/user.py index e4f6391..f6b97d4 100644 --- a/arcan/datamodel/user.py +++ b/arcan/datamodel/user.py @@ -8,8 +8,7 @@ from jose import JWTError, jwt from passlib.context import CryptContext from pydantic import BaseModel -from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer, String, - Text) +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text from sqlalchemy.orm import Session, relationship from arcan.datamodel.engine import Base, engine diff --git a/tests/arcan/ai/runnables/test_runnables.py b/tests/arcan/ai/runnables/test_runnables.py index dd8d424..c12ebd7 100644 --- a/tests/arcan/ai/runnables/test_runnables.py +++ b/tests/arcan/ai/runnables/test_runnables.py @@ -12,6 +12,7 @@ def base_url(): return "http://localhost:8000/" + def test_get_spells_runnable(base_url): runnable_factory = MagicMock() arcan_runnables = ArcanRunnables(base_url=base_url) @@ -20,8 +21,11 @@ def test_get_spells_runnable(base_url): arcan_runnables.get_spells_runnable() runnable_factory.get_runnable.assert_called_once_with(runnable_name="spells") - - assert arcan_runnables.get_spells_runnable().invoke({'input': 'testinggggg$#@'}).json() == {"response": "test"} + + assert arcan_runnables.get_spells_runnable().invoke( + {"input": "testinggggg$#@"} + ).json() == {"response": "test"} + def test_get_openai_runnable(base_url): runnable_factory = MagicMock() @@ -32,6 +36,7 @@ def test_get_openai_runnable(base_url): runnable_factory.get_runnable.assert_called_once_with(runnable_name="openai") + def test_get_groq_runnable(base_url): runnable_factory = MagicMock() arcan_runnables = ArcanRunnables(base_url=base_url) @@ -41,6 +46,7 @@ def test_get_groq_runnable(base_url): runnable_factory.get_runnable.assert_called_once_with(runnable_name="groq") + # def test_get_ollama_runnable(base_url): # runnable_factory = MagicMock() # arcan_runnables = ArcanRunnables(base_url=base_url) @@ -48,4 +54,4 @@ def test_get_groq_runnable(base_url): # arcan_runnables.get_ollama_runnable() -# runnable_factory.get_runnable.assert_called_once_with(runnable_name="ollama") \ No newline at end of file +# runnable_factory.get_runnable.assert_called_once_with(runnable_name="ollama")