diff --git a/.gitignore b/.gitignore index b4c6676..dfbd4c3 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ __pycache__/ # C extensions *.so +arcan.db + # Distribution / packaging .Python build/ diff --git a/arcan/ai/agents/__init__.py b/arcan/ai/agents/__init__.py index ddb95f0..94cad89 100644 --- a/arcan/ai/agents/__init__.py +++ b/arcan/ai/agents/__init__.py @@ -1,138 +1,876 @@ # %% +# %% +from __future__ import annotations + +import ast import asyncio import os +import pickle +import weakref +from datetime import datetime + +# Ensure necessary imports for ArcanAgent from tempfile import TemporaryDirectory +from typing import Any, AsyncIterator, Dict, List, Optional, cast +from fastapi import Depends from fastapi.responses import StreamingResponse -from langchain.agents import AgentExecutor, create_tool_calling_agent, load_tools +from langchain.agents import ( + AgentExecutor, + AgentType, + create_tool_calling_agent, + initialize_agent, + load_tools, +) +from langchain.agents.agent_types import AgentType from langchain.agents.format_scratchpad.openai_tools import ( format_to_openai_tool_messages, ) -from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser +from langchain.agents.format_scratchpad.tools import format_to_tool_messages +from langchain.agents.output_parsers.tools import ToolsAgentOutputParser +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.memory import ConversationBufferMemory +from langchain.pydantic_v1 import BaseModel from langchain.sql_database import SQLDatabase from langchain_community.agent_toolkits import FileManagementToolkit, SQLDatabaseToolkit -from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.callbacks import CallbackManagerForChainRun +from langchain_core.load.serializable import ( + Serializable, + SerializedConstructor, + SerializedNotImplemented, +) +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from 
langchain_core.prompts import ChatPromptTemplate + +# from langchain_core.pydantic_v1 import BaseModel +from langchain_core.runnables import ( + ConfigurableField, + ConfigurableFieldSpec, + Runnable, + RunnableConfig, + RunnablePassthrough, + RunnableSerializable, +) +from langchain_core.runnables.base import Runnable, RunnableBindingBase +from langchain_core.runnables.utils import ( + AddableDict, + AnyConfigurableField, + ConfigurableField, + ConfigurableFieldSpec, + Input, + Output, + create_model, + get_unique_config_specs, +) +from langchain_openai import ChatOpenAI, OpenAIEmbeddings +from pydantic import BaseModel, Field +from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session from arcan.ai.agents.helpers import AsyncIteratorCallbackHandler +from arcan.ai.agents.session import ArcanSession from arcan.ai.llm import LLM +from arcan.ai.parser import ArcanOutputParser from arcan.ai.prompts import arcan_prompt, spells_agent_prompt # from arcan.ai.router import semantic_layer from arcan.ai.tools import tools as spells - -class ArcanAgent: - """ - Represents a Arcan Agent that interacts with the user and provides responses using OpenAI tools. - - Attributes: - llm (LLM): The Language Model Manager used by the agent. - tools (list): The list of tools used by the agent. - hub_prompt (str): The prompt for the OpenAI tools agent. - agent_type (str): The type of the agent. - chat_history (list): The chat history of the agent. - llm_with_tools: The Language Model Manager with the tools bound. - prompt: The chat prompt template for the agent. - agent: The agent pipeline. - agent_executor: The executor for the agent. - user_id: The unique identifier for the user. - verbose: A boolean indicating whether to print verbose output. - - Methods: - get_response: Gets the response from the agent given user input. 
class ArcanAgent(RunnableSerializable):
    """Runnable wrapper around a tool-calling agent with persisted chat history.

    On construction the agent binds ``tools`` plus the ``llm-math`` tool to the
    LLM, builds a ``prompt | llm_with_tools | ToolsAgentOutputParser`` pipeline
    and wraps it in an ``AgentExecutor``.  :meth:`invoke` appends every
    exchange to ``chat_history`` and stores it through :class:`ArcanSession`.
    """

    # Tools supplied by the caller.
    tools: List = Field(default_factory=list)
    # Tools loaded through ``load_tools`` in ``__init__`` (currently "llm-math").
    bare_tools: List = Field(default_factory=list)
    # ``tools + bare_tools``: what is bound to the LLM and given to the executor.
    agent_tools: List = Field(default_factory=list)
    agent_type: str = "arcan_spells_agent"
    # Human/AI messages exchanged so far; extended on every ``invoke``.
    chat_history: List = Field(default_factory=list)
    user_id: Optional[str] = None
    verbose: bool = False
    # Was ``(spells_agent_prompt,)``: a one-element tuple, not a prompt.
    prompt: ChatPromptTemplate = spells_agent_prompt
    configs: List[ConfigurableFieldSpec] = Field(default_factory=list)
    llm_with_tools: LLM = Field(default_factory=lambda: LLM().llm)
    agent: Runnable = Field(default_factory=RunnablePassthrough)
    runnable: Runnable = Field(default_factory=RunnablePassthrough)
    session: ArcanSession = Field(default_factory=ArcanSession)

    class Config:
        arbitrary_types_allowed = True
        extra = "allow"  # This allows additional fields not explicitly defined

    def __init__(
        self,
        llm=None,
        tools: list = spells,
        prompt: ChatPromptTemplate = spells_agent_prompt,
        agent_type="arcan_spells_agent",
        chat_history: Optional[list] = None,
        user_id: str = None,
        verbose: bool = False,
        configs: Optional[list] = None,
        **kwargs,
    ):
        """Build the agent pipeline and its executor.

        :param llm: Chat model exposing ``bind_tools``; ``LLM().llm`` when falsy.
        :param tools: Tools made available to the agent.
        :param prompt: Prompt template placed at the head of the pipeline.
        :param agent_type: Free-form label stored on the instance.
        :param chat_history: Initial message history; a new empty list when
            omitted.
        :param user_id: Key used for the session cache and for persistence.
        :param verbose: Forwarded to ``AgentExecutor``.
        :param configs: Configurable field specs; a new empty list when omitted.
        :param kwargs: Extra fields forwarded to the base-class initializer.
        """
        # ``None`` sentinels replace the former ``[]`` defaults so that
        # instances never share one list object created at definition time.
        super().__init__(
            tools=tools,
            agent_type=agent_type,
            chat_history=[] if chat_history is None else chat_history,
            user_id=user_id,
            verbose=verbose,
            prompt=prompt,
            configs=[] if configs is None else configs,
            **kwargs,
        )
        # Stored via ``object.__setattr__`` so the private attribute is set
        # without going through the model's own ``__setattr__``.
        object.__setattr__(self, "_llm", llm or LLM().llm)
        # Initialize other fields after the main Pydantic initialization
        self.session: ArcanSession = ArcanSession()
        self.bare_tools = load_tools(["llm-math"], llm=self.llm)
        self.agent_tools = self.tools + self.bare_tools
        self.llm_with_tools = self.llm.bind_tools(self.agent_tools)
        self.agent, self.runnable = self.get_or_create_agent(self.user_id)

    @property
    def llm(self):
        """Return the chat model stored at construction time."""
        return self._llm

    @property
    def default_configs(self):
        """Return the ``user_id`` / ``conversation_id`` configurable specs."""
        return [
            ConfigurableFieldSpec(
                id="user_id",
                annotation=str,
                name="User ID",
                description="Unique identifier for the user.",
                default="",
                is_shared=True,
            ),
            ConfigurableFieldSpec(
                id="conversation_id",
                annotation=str,
                name="Conversation ID",
                description="Unique identifier for the conversation.",
                default="",
                is_shared=True,
            ),
        ]

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
        """Expose the config specs of the underlying agent pipeline."""
        return self.agent.config_specs

    async def astream(
        self,
        input: Input,
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> AsyncIterator[Output]:
        """Stream the agent's output.

        :param input: Input forwarded to the executor.
        :param config: Optional run config.  Its ``"configurable"`` entry is
            stripped before the config reaches the executor, as before.
        :param kwargs: Extra keyword arguments forwarded to the executor.
        :return: An async iterator over the executor's output chunks.
        """
        # Copy first: the caller's dict must not be mutated, and ``None`` must
        # not raise AttributeError on ``.pop``.
        run_config = dict(config) if config else {}
        # NOTE(review): the "configurable" values are dropped here, exactly as
        # in the previous implementation -- confirm whether they should be
        # forwarded instead.
        run_config.pop("configurable", None)

        # ``with_config`` returns a new runnable; the previous code discarded
        # it, so the run name was never applied.
        executor = self.runnable.with_config({"run_name": "executor"})

        async for output in executor.astream(input, config=run_config, **kwargs):
            yield output

    def get_or_create_agent(
        self, user_id: str, provided_agent: ArcanAgent = None
    ) -> tuple[Runnable, Runnable]:
        """Register an agent for ``user_id`` and return it with its executor.

        :param user_id: The unique identifier for the user.
        :param provided_agent: Optional ready-made agent to register instead of
            building a new pipeline.
        :return: ``(agent, runnable)``: the agent pipeline (or the provided
            agent) and the runnable used to execute it.
        """
        if provided_agent is not None:
            provided_agent.user_id = user_id
            self.session.agents[user_id] = provided_agent
            return provided_agent, provided_agent.runnable

        cached_agent = self.session.agents.get(user_id)
        chat_history = []

        # Best effort: a failed history lookup must not prevent agent creation.
        try:
            chat_history = self.session.get_chat_history(user_id)
        except Exception as e:
            print(f"Error getting chat history for {user_id}: {e}")

        if cached_agent is not None and chat_history:
            print(f"Using existing agent {cached_agent}")
        elif cached_agent is None and chat_history:
            print(f"Using reloaded agent with history {chat_history}")
            self.chat_history = chat_history
        elif cached_agent is None and not chat_history:
            print("Using a new agent")

        # The session cache holds only the pipeline, not its executor, so the
        # pair is rebuilt on every call and the cache entry refreshed.
        agent, runnable = self.get_agent()
        self.session.agents[user_id] = agent
        return agent, runnable

    def get_agent(self):
        """Build the tool-calling pipeline and the executor that runs it.

        :return: ``(agent, runnable)`` where ``agent`` is the
            ``prompt | llm_with_tools | ToolsAgentOutputParser`` pipeline and
            ``runnable`` is an ``AgentExecutor`` wrapping it.
        :raises ValueError: If ``self.session`` is ``None``.
        """
        if self.session is None:
            raise ValueError("Session is not initialized.")
        agent = (
            RunnablePassthrough.assign(
                # Turn the executor's intermediate steps into tool messages
                # for the prompt's ``agent_scratchpad`` slot.
                agent_scratchpad=lambda x: format_to_tool_messages(
                    x["intermediate_steps"]
                )
            )
            | self.prompt
            | self.llm_with_tools
            | ToolsAgentOutputParser()
        )
        runnable = AgentExecutor(
            agent=agent, tools=self.agent_tools, verbose=self.verbose
        )
        return agent, runnable

    def invoke(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run one conversational turn and persist it.

        :param inputs: Mapping that must contain a non-empty ``"input"`` entry.
        :param run_manager: Accepted for signature compatibility; unused.
        :return: The executor's response dict (its ``"output"`` entry holds the
            reply text).
        :raises ValueError: If ``inputs`` has no usable ``"input"`` value.
        """
        user_content = inputs.get("input")
        if not user_content:
            raise ValueError("Input must contain 'input' key with user content.")

        # route_text, routed_content = semantic_layer(
        #     query=user_content, user_id=self.user_id
        # )
        self.chat_history.extend(
            [
                # SystemMessage(content=route_text),
                HumanMessage(content=user_content),
            ]
        )
        response = self.runnable.invoke(
            {"input": user_content, "chat_history": self.chat_history}
        )
        self.chat_history.extend(
            [
                AIMessage(content=response["output"]),
            ]
        )
        try:
            self.session.store_message(
                user_id=self.user_id, body=user_content, response=response["output"]
            )
            self.session.store_chat_history(
                user_id=self.user_id, agent_history=self.chat_history
            )
        except SQLAlchemyError as e:
            # ``session.database`` is the session *factory* (ArcanSession calls
            # ``self.database()``); go through ArcanSession.rollback() rather
            # than calling ``rollback`` on the factory.
            self.session.rollback()
            print(f"Error storing conversation in database: {e}")
        return response
+ # user_id: The unique identifier for the user. + # verbose: A boolean indicating whether to print verbose output. + # """ + # llm: LLM = LLM().llm + # tools: List = spells + # agent_type: str = 'arcan_spells_agent' + # chat_history: List = Field(default_factory=list) + # user_id: Optional[str] = None + # verbose: bool = False + # # Assuming session and prompt types are defined somewhere + # session: ArcanSession + # prompt: str = spells_agent_prompt + # configs: List[ConfigurableFieldSpec] = Field(default_factory=list) + + # class Config: + # arbitrary_types_allowed = True + + # def __init__(self, llm: LLM = LLM().llm, tools: list = spells, prompt: str = spells_agent_prompt, + # agent_type="arcan_spells_agent", chat_history: list = [], user_id: str = None, + # verbose: bool = False, configs: list = None, **kwargs): + # super().__init__(**kwargs) # Initialize BaseModel with kwargs + # self.llm = llm + # self.tools = tools + # self.agent_type = agent_type + # self.chat_history = chat_history + # self.user_id = user_id + # self.verbose = verbose + # self.session = ArcanSession() + # self.prompt = prompt + # self.working_directory = TemporaryDirectory() + # self.file_system_tools = FileManagementToolkit( + # root_dir=str(self.working_directory.name) + # ).get_tools() + # self.bare_tools = load_tools( + # [ + # "llm-math", + # ], + # llm=self.llm, + # ) + # self.agent_tools = self.tools + self.bare_tools + # self.llm_with_tools = self.llm.bind_tools(self.agent_tools) + # missing_vars = {"agent_scratchpad"}.difference( + # prompt.input_variables + list(prompt.partial_variables) + # ) + # if missing_vars: + # raise ValueError(f"Prompt missing required variables: {missing_vars}") + + # if not hasattr(llm, "bind_tools"): + # raise ValueError( + # "This function requires a .bind_tools method be implemented on the LLM.", + # ) + # self.llm_with_tools = llm.bind_tools(tools) + + # self.agent, self.runnable = self.get_or_create_agent(self.user_id) + + # self.configs = configs 
or [ + # ConfigurableFieldSpec( + # id="user_id", + # annotation=str, + # name="User ID", + # description="Unique identifier for the user.", + # default="", + # is_shared=True, + # ), + # ConfigurableFieldSpec( + # id="conversation_id", + # annotation=str, + # name="Conversation ID", + # description="Unique identifier for the conversation.", + # default="", + # is_shared=True, + # ) + # ] + + # def __init__( + # self, + # llm: LLM = LLM().llm, + # tools: list = spells, + # prompt: str = spells_agent_prompt, + # agent_type="arcan_spells_agent", + # chat_history: list = [], # represents the chat history, can be pulled from a db + # user_id: str = None, + # verbose: bool = False, + # session_factory: callable = session_scope, + # configs: list = [], + # **kwargs + # ): + # """Initialize the runnable.""" + # super().__init__(**kwargs) + # self.llm: LLM = llm + # self.tools: list = tools + # self.agent_type: str = agent_type + # self.chat_history: list = chat_history + # self.user_id: str = kwargs.get('user_id', user_id) + # self.verbose: bool = verbose + # self.session: ArcanSession = ArcanSession() + # self.prompt = prompt + # self.working_directory = TemporaryDirectory() + # self.file_system_tools = FileManagementToolkit( + # root_dir=str(self.working_directory.name) + # ).get_tools() + # self.bare_tools = load_tools( + # [ + # "llm-math", + # ], + # llm=self.llm, + # ) + # self.agent_tools = self.tools + self.bare_tools + # self.llm_with_tools = self.llm.bind_tools(self.agent_tools) + # missing_vars = {"agent_scratchpad"}.difference( + # prompt.input_variables + list(prompt.partial_variables) + # ) + # if missing_vars: + # raise ValueError(f"Prompt missing required variables: {missing_vars}") + + # if not hasattr(llm, "bind_tools"): + # raise ValueError( + # "This function requires a .bind_tools method be implemented on the LLM.", + # ) + # self.llm_with_tools = llm.bind_tools(tools) + + # self.agent, self.runnable = self.get_or_create_agent(self.user_id) + + # 
self.configs = configs if configs is not None else [ + # ConfigurableFieldSpec( + # id="user_id", + # annotation=str, + # name="User ID", + # description="Unique identifier for the user.", + # default="", + # is_shared=True, + # ), + # ConfigurableFieldSpec( + # id="conversation_id", + # annotation=str, + # name="Conversation ID", + # description="Unique identifier for the conversation.", + # default="", + # is_shared=True, + # ) + # ] + + # @property + # def config_specs(self) -> List[ConfigurableFieldSpec]: + # return self.agent.config_specs + + # async def astream( + # self, + # input: Input, + # config: Optional[RunnableConfig] = None, + # **kwargs: Optional[Any], + # ) -> AsyncIterator[Output]: + # """Stream the agent's output.""" + # configurable = cast(Dict[str, Any], config.pop("configurable", {})) + + # if configurable: + # configured_agent = self.agent.with_config( + # { + # "configurable": configurable, + # } + # ) + # else: + # configured_agent = self.agent + + # self.runnable.with_config({"run_name": "executor"}) + + # async for output in self.runnable.astream(input, config=config, **kwargs): + # yield output + + # def get_or_create_agent( + # self, user_id: str, provided_agent: ArcanAgent = None + # ) -> ArcanAgent: + # """ + # Retrieves or creates a ArcanAgent for a given user_id. + + # :param user_id: The unique identifier for the user. + # :return: An instance of ArcanAgent. 
+ # """ + # if provided_agent is None: + # agent = self.session.agents.get(user_id) + # chat_history = [] + + # # Obtain a new database session + # try: + # chat_history = self.session.get_chat_history(user_id) + # except Exception as e: + # print(f"Error getting chat history for {user_id}: {e}") + + # if agent is not None and chat_history: + # print(f"Using existing agent {agent}") + # elif agent is None and chat_history: + # print(f"Using reloaded agent with history {chat_history}") + # self.chat_history = chat_history + # elif agent is None and not chat_history: + # print("Using a new agent") + # agent, runnable = self.get_agent() + # self.session.agents[user_id] = agent + # return agent, runnable + # else: + # provided_agent.user_id = user_id + # self.session.agents[user_id] = provided_agent + # return provided_agent, provided_agent.runnable + + # def get_agent(self): + # """ + # Retrieves or creates a ArcanAgent for a given user_id. + + # :param user_id: The unique identifier for the user. + # :return: An instance of ArcanAgent. + # """ + # if self.session is None: + # raise ValueError("Session is not initialized.") + # agent = ( + # RunnablePassthrough.assign( + # agent_scratchpad=lambda x: format_to_tool_messages( + # x["intermediate_steps"] + # ) + # ) + # | self.prompt + # | self.llm_with_tools + # | ToolsAgentOutputParser() + # ) + # runnable = AgentExecutor( + # agent=agent, tools=self.agent_tools, verbose=self.verbose + # ) + # return agent, runnable + + # def invoke( + # self, + # inputs: Dict[str, Any], + # run_manager: Optional[CallbackManagerForChainRun] = None, + # ) -> Dict[str, Any]: + # """ + # Override the invoke method to include custom logic. 
+ # """ + # user_content = inputs.get("input") + # if not user_content: + # raise ValueError("Input must contain 'input' key with user content.") + + # # route_text, routed_content = semantic_layer( + # # query=user_content, user_id=self.user_id + # # ) + # self.chat_history.extend( + # [ + # # SystemMessage(content=route_text), + # HumanMessage(content=user_content), + # ] + # ) + # response = self.runnable.invoke( + # {"input": user_content, "chat_history": self.chat_history} + # ) + # self.chat_history.extend( + # [ + # AIMessage(content=response["output"]), + # ] + # ) + # # try: + # # self.session.store_message(user_id=self.user_id, body=user_content, response=response) + # # self.session.store_chat_history(user_id=self.user_id, agent_history=self.chat_history) + # # except SQLAlchemyError as e: + # # self.session.database.rollback() + # # print(f"Error storing conversation in database: {e}") + # return response + + # def configurable_fields(self, **kwargs: AnyConfigurableField): + # """Configure particular runnable fields at runtime. + + # .. 
code-block:: python + + # from langchain_core.runnables import ConfigurableField + # from langchain_openai import ChatOpenAI + + # model = ChatOpenAI(max_tokens=20).configurable_fields( + # max_tokens=ConfigurableField( + # id="output_token_number", + # name="Max tokens in the output", + # description="The maximum number of tokens in the output", + # ) + # ) + + # # max_tokens = 20 + # print( + # "max_tokens_20: ", + # model.invoke("tell me something about chess").content + # ) + + # # max_tokens = 200 + # print("max_tokens_200: ", model.with_config( + # configurable={"output_token_number": 200} + # ).invoke("tell me something about chess").content + # ) + # """ + # from langchain_core.runnables.configurable import \ + # RunnableConfigurableFields + + # for key in kwargs: + # # print(f"Checking key {key} in {self}") + # # print(f"Available keys are {vars(self).keys()}") + # if key not in vars(self).keys(): + # raise ValueError( + # f"Configuration key {key} not found in {self}: " + # f"available keys are {vars(self).keys()}" + # ) + # # updated the self class arguments with the new values + # setattr(self, key, kwargs[key]) + # return self # %% -# -from langchain.agents import AgentType, initialize_agent, load_tools -from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.memory import ConversationBufferMemory -from pydantic import BaseModel +# %% -from arcan.ai.parser import ArcanOutputParser + +# %% + + +# class ArcanAgent: +# """ +# Represents a Arcan Agent that interacts with the user and provides responses using OpenAI tools. + +# Attributes: +# llm (LLM): The Language Model Manager used by the agent. +# tools (list): The list of tools used by the agent. +# hub_prompt (str): The prompt for the OpenAI tools agent. +# agent_type (str): The type of the agent. +# chat_history (list): The chat history of the agent. +# llm_with_tools: The Language Model Manager with the tools bound. +# prompt: The chat prompt template for the agent. 
+# agent: The agent pipeline. +# agent_executor: The executor for the agent. +# user_id: The unique identifier for the user. +# verbose: A boolean indicating whether to print verbose output. + +# Methods: +# get_response: Gets the response from the agent given user input. + +# """ + +# def __init__( +# self, +# # database: SQLDatabase, +# llm: LLM = LLM().llm, +# tools: list = spells, +# hub_prompt: str = "broomva/arcan", +# agent_type="arcan_spells_agent", +# context: list = [], # represents the chat history, can be pulled from a db +# user_id: str = None, +# verbose: bool = False, +# ): +# self.llm: LLM = llm +# self.tools: list = tools +# self.hub_prompt: str = hub_prompt +# self.agent_type: str = agent_type +# self.chat_history: list = context +# self.user_id: str = user_id +# self.verbose: bool = verbose + +# # self.db = database +# # self.toolkit = SQLDatabaseToolkit(db=self.db, llm=self.llm) +# # self.context = self.toolkit.get_context() +# self.prompt = arcan_prompt # .partial(**self.context) +# # self.sql_tools = self.toolkit.get_tools() +# self.working_directory = TemporaryDirectory() +# self.file_system_tools = FileManagementToolkit( +# root_dir=str(self.working_directory.name) +# ).get_tools() +# self.parser = ArcanOutputParser() +# self.bare_tools = load_tools( +# [ +# "llm-math", +# # "human", +# # "wolfram-alpha" +# ], +# llm=self.llm, +# ) +# self.agent_tools = ( +# self.tools + self.bare_tools # + self.sql_tools + self.file_system_tools +# ) +# self.llm_with_tools = self.llm.bind_tools(self.agent_tools) +# self.agent = ( +# { +# "input": lambda x: x["input"], +# "agent_scratchpad": lambda x: format_to_openai_tool_messages( +# x["intermediate_steps"] +# ), +# "chat_history": lambda x: x["chat_history"], +# } +# | self.prompt +# | self.llm_with_tools +# | self.parser +# ) +# self.agent_executor = AgentExecutor( +# agent=self.agent, tools=self.agent_tools, verbose=self.verbose +# ) + +# def get_response(self, user_content: str): +# """ +# Gets the 
response from the agent given user input. + +# Args: +# user_content (str): The user input. + +# Returns: +# str: The response from the agent. + +# """ +# # routed_content = semantic_layer(query=user_content, user_id=self.user_id) +# response = self.agent_executor.invoke( +# {"input": user_content, "chat_history": self.chat_history} +# ) +# self.chat_history.extend( +# [ +# HumanMessage(content=user_content), +# AIMessage(content=response["output"]), +# ] +# ) +# return response["output"] + + +# class ArcanSpellsAgent(ArcanAgent): +# """ +# Represents a Arcan Agent that interacts with the user and provides responses using OpenAI tools. + +# Attributes: +# llm (LLM): The Language Model Manager used by the agent. +# tools (list): The list of tools used by the agent. +# hub_prompt (str): The prompt for the OpenAI tools agent. +# agent_type (str): The type of the agent. +# chat_history (list): The chat history of the agent. +# llm_with_tools: The Language Model Manager with the tools bound. +# prompt: The chat prompt template for the agent. +# agent: The agent pipeline. +# agent_executor: The executor for the agent. +# user_id: The unique identifier for the user. +# verbose: A boolean indicating whether to print verbose output. + +# Methods: +# get_response: Gets the response from the agent given user input. 
+ +# """ + +# def __init__( +# self, +# # database: SQLDatabase, +# llm: LLM = LLM().llm, +# tools: list = spells, +# prompt: str = spells_agent_prompt, +# agent_type="arcan_spells_agent", +# context: list = [], # represents the chat history, can be pulled from a db +# user_id: str = None, +# verbose: bool = False, +# ): +# self.llm: LLM = llm +# self.tools: list = tools +# self.agent_type: str = agent_type +# self.chat_history: list = context +# self.user_id: str = user_id +# self.verbose: bool = verbose +# # self.database = database +# # self.toolkit = SQLDatabaseToolkit(db=database, llm=self.llm) +# # self.context = self.toolkit.get_context() +# # self.sql_tools = self.toolkit.get_tools() +# self.prompt = prompt # arcan_prompt.partial(**self.context) +# self.working_directory = TemporaryDirectory() +# self.file_system_tools = FileManagementToolkit( +# root_dir=str(self.working_directory.name) +# ).get_tools() +# self.parser = ToolsAgentOutputParser() +# self.bare_tools = load_tools( +# [ +# "llm-math", +# # "human", +# # "wolfram-alpha" +# ], +# llm=self.llm, +# ) +# self.agent_tools = ( +# self.tools + self.bare_tools # + self.sql_tools + self.file_system_tools +# ) +# self.llm_with_tools = self.llm.bind_tools(self.agent_tools) +# # Construct the Tools agent +# # self.agent = create_tool_calling_agent(self.llm, self.agent_tools, self.prompt) +# self.agent = ( +# { +# "input": lambda x: x["input"], +# "agent_scratchpad": lambda x: format_to_openai_tool_messages( +# x["intermediate_steps"] +# ), +# "chat_history": lambda x: x["chat_history"], +# } +# | self.prompt +# | self.llm_with_tools +# | self.parser +# ) +# self.agent_executor = AgentExecutor( +# agent=self.agent, tools=self.agent_tools, verbose=self.verbose +# ) + +# def get_response(self, user_content: str): +# """ +# Gets the response from the agent given user input. + +# Args: +# user_content (str): The user input. + +# Returns: +# str: The response from the agent. 
+ +# """ +# routed_content, route_text = semantic_layer( +# query=user_content, user_id=self.user_id +# ) +# response = self.agent_executor.invoke( +# {"input": routed_content, "chat_history": self.chat_history} +# ) +# self.chat_history.extend( +# [ +# AIMessage(content=route_text), +# HumanMessage(content=user_content), +# AIMessage(content=response["output"]), +# ] +# ) +# return response["output"] + + +# %% class ArcanConversationAgent: @@ -196,98 +934,6 @@ async def agent_chat(text: str, agent): # query: Query = Body(...),): # %% -from langchain.agents import AgentExecutor, create_tool_calling_agent - - -class ArcanSpellsAgent(ArcanAgent): - """ - Represents a Arcan Agent that interacts with the user and provides responses using OpenAI tools. - - Attributes: - llm (LLM): The Language Model Manager used by the agent. - tools (list): The list of tools used by the agent. - hub_prompt (str): The prompt for the OpenAI tools agent. - agent_type (str): The type of the agent. - chat_history (list): The chat history of the agent. - llm_with_tools: The Language Model Manager with the tools bound. - prompt: The chat prompt template for the agent. - agent: The agent pipeline. - agent_executor: The executor for the agent. - user_id: The unique identifier for the user. - verbose: A boolean indicating whether to print verbose output. - - Methods: - get_response: Gets the response from the agent given user input. 
- - """ - - def __init__( - self, - # database: SQLDatabase, - llm: LLM = LLM().llm, - tools: list = spells, - prompt: str = spells_agent_prompt, - agent_type="arcan_spells_agent", - context: list = [], # represents the chat history, can be pulled from a db - user_id: str = None, - verbose: bool = False, - ): - self.llm: LLM = llm - self.tools: list = tools - self.agent_type: str = agent_type - self.chat_history: list = context - self.user_id: str = user_id - self.verbose: bool = verbose - # self.database = database - # self.toolkit = SQLDatabaseToolkit(db=database, llm=self.llm) - # self.context = self.toolkit.get_context() - # self.sql_tools = self.toolkit.get_tools() - self.prompt = prompt # arcan_prompt.partial(**self.context) - self.working_directory = TemporaryDirectory() - self.file_system_tools = FileManagementToolkit( - root_dir=str(self.working_directory.name) - ).get_tools() - self.parser = OpenAIToolsAgentOutputParser() - self.bare_tools = load_tools( - [ - "llm-math", - # "human", - # "wolfram-alpha" - ], - llm=self.llm, - ) - self.agent_tools = ( - self.tools + self.bare_tools # + self.sql_tools + self.file_system_tools - ) - self.llm_with_tools = self.llm.bind_tools(self.agent_tools) - # Construct the Tools agent - self.agent = create_tool_calling_agent(self.llm, self.agent_tools, self.prompt) - self.agent_executor = AgentExecutor( - agent=self.agent, tools=self.agent_tools, verbose=self.verbose - ) - - def get_response(self, user_content: str): - """ - Gets the response from the agent given user input. - - Args: - user_content (str): The user input. - - Returns: - str: The response from the agent. 
class ArcanSession:
    """Persistence helper for conversations and per-user chat history.

    ``database`` is a factory: every operation calls it to obtain a session
    and uses that session as a context manager.
    """

    def __init__(self, database: callable = SessionLocal):
        """Store the session factory and create the agent cache.

        :param database: Zero-argument callable returning a session.
        """
        self.database = database
        self.database_uri = os.environ.get("SQLALCHEMY_URL")
        # Weak values: an entry disappears once nothing else references the
        # cached agent.
        self.agents: weakref.WeakValueDictionary = weakref.WeakValueDictionary()

    def _get_session(self) -> Session:
        """Return a session obtained from the configured factory.

        :raises ValueError: If no factory was configured.
        """
        if self.database is None:
            raise ValueError("Database factory is not initialized.")
        return self.database()

    def store_message(self, user_id: str, body: str, response: str):
        """Insert one ``Conversation`` row for a user message and its reply.

        :param user_id: Stored as the conversation's ``sender``.
        :param body: The user's message.
        :param response: The generated reply.
        """
        with self._get_session() as db_session:
            conversation = Conversation(sender=user_id, message=body, response=response)
            db_session.add(conversation)
            db_session.commit()
            print(f"Conversation #{conversation.id} stored in database")

    def store_chat_history(self, user_id, agent_history):
        """Upsert the pickled chat history for ``user_id``.

        :param user_id: Stored as ``sender``; also the upsert conflict target.
        :param agent_history: Picklable history object to persist.
        """
        # The pickle bytes are stored as their ``repr`` text;
        # ``get_chat_history`` reverses this with ``ast.literal_eval``.
        serialized = str(pickle.dumps(agent_history))
        # One timestamp for both branches so insert and update agree.
        now = datetime.utcnow()
        stmt = (
            insert(ChatHistory)
            .values(
                sender=user_id,
                history=serialized,
                updated_at=now,
            )
            .on_conflict_do_update(
                index_elements=["sender"],
                set_={
                    "history": serialized,
                    "updated_at": now,
                },
            )
        )
        with self._get_session() as db:
            db.execute(stmt)
            db.commit()
            print(f"Upsert chat history for user {user_id} with statement {stmt}")

    def get_chat_history(self, user_id: str) -> list:
        """Return the stored chat history for ``user_id``.

        :param user_id: The ``sender`` value to look up.
        :return: The unpickled history, or ``[]`` when no row exists.
        """
        with self._get_session() as db_session:
            # Newest row first and fetch only one: the upsert keys on
            # ``sender``, so a single row is expected, but this guards against
            # reading a stale entry if duplicates exist.
            record = (
                db_session.query(ChatHistory)
                # .options(joinedload(ChatHistory.history))
                .filter(ChatHistory.sender == user_id)
                .order_by(ChatHistory.updated_at.desc())
                .first()
            )
            if record is None:
                return []
            serialized = record.history
        # NOTE(security): ``pickle.loads`` executes arbitrary code embedded in
        # the payload.  This is only safe while the chat-history table is
        # written exclusively by ``store_chat_history``.
        return pickle.loads(ast.literal_eval(serialized))

    def rollback(self):
        """Roll back a session obtained from the factory.

        NOTE(review): the rollback runs on a session newly obtained from the
        factory, not on the one used by the failed operation -- confirm this
        has the intended effect.
        """
        with self._get_session() as db:
            db.rollback()
            print("Rollback transaction")
+# """ +# history = pickle.dumps(agent_history) +# # Upsert statement +# stmt = ( +# insert(ChatHistory) +# .values( +# sender=user_id, +# history=str(history), +# updated_at=datetime.utcnow(), # Explicitly set updated_at on insert +# ) +# .on_conflict_do_update( +# index_elements=["sender"], # Specify the conflict target +# set_={ +# "history": str(history), # Update the history field upon conflict +# "updated_at": datetime.utcnow(), # Update the updated_at field upon conflict +# }, +# ) +# ) +# # Execute the upsert +# with self.database as db: +# db.execute(stmt) +# db.commit() +# print(f"Upsert chat history for user {user_id} with statement {stmt}") + +# def get_chat_history(self, user_id: str) -> list: +# """ +# Retrieves the chat history for a user from the database. + +# :param db_session: The SQLAlchemy Session instance. +# :param user_id: The unique identifier for the user. +# :return: A list representing the chat history. +# """ +# with self.database as db_session: +# history = ( +# db_session.query(ChatHistory) +# .filter(ChatHistory.sender == user_id) +# .order_by(ChatHistory.updated_at.asc()) +# .all() +# ) or [] +# if not history: +# return [] +# chat_history = history[0].history +# loaded = pickle.loads(ast.literal_eval(chat_history)) +# return loaded +# %% diff --git a/arcan/ai/llm/__init__.py b/arcan/ai/llm/__init__.py index 581a627..6ed71d3 100644 --- a/arcan/ai/llm/__init__.py +++ b/arcan/ai/llm/__init__.py @@ -3,6 +3,7 @@ import os from typing import Any, Callable, Dict, List, Optional, Union +from langchain_community.chat_models import ChatOllama from langchain_groq import ChatGroq from langchain_openai import ChatOpenAI, OpenAI from pydantic import BaseModel @@ -69,6 +70,9 @@ class LLMFactory: os.getenv("TOGETHER_MODEL_NAME", "llama3-8b-8192"), ), ), + "ChatOllama": lambda **kwargs: ChatOllama( + model=kwargs.get("model", os.getenv("OLLAMA_MODEL", "phi3")), + ), } @staticmethod diff --git a/arcan/ai/router/__init__.py 
b/arcan/ai/router/__init__.py index dcd3442..2e26396 100644 --- a/arcan/ai/router/__init__.py +++ b/arcan/ai/router/__init__.py @@ -63,11 +63,12 @@ def get_response(self, query: str, user_id: str) -> str: f"Matched route: {route.name}, using strategy: {strategy.__class__.__name__}" ) route_text = strategy.execute(query=query, user_id=user_id) - query += f" (SYSTEM NOTE: {route_text})" - return query + route_template = f" (SYSTEM NOTE: {route_text})" + query += route_template + return route_text, query else: print(f"No route found for query: {query}") - return query + return "No Router Matched", query # Initialize RouteManager with an encoder @@ -89,5 +90,5 @@ def semantic_layer(query: str, user_id: str = None): Returns: str: The response generated by the route manager. """ - response = route_manager.get_response(query=query, user_id=user_id) - return response + response, route_text = route_manager.get_response(query=query, user_id=user_id) + return response, route_text diff --git a/arcan/ai/runnables/__init__.py b/arcan/ai/runnables/__init__.py index cd45168..1c90fcb 100644 --- a/arcan/ai/runnables/__init__.py +++ b/arcan/ai/runnables/__init__.py @@ -1,6 +1,8 @@ # %% -from langchain.prompts import ChatPromptTemplate +from langchain.agents import AgentExecutor from langchain_core.runnables import Runnable +from langchain_groq import ChatGroq +from langchain_openai import ChatOpenAI from langserve import RemoteRunnable @@ -9,7 +11,7 @@ def __init__(self, base_url: str = "http://localhost:8000/"): self.base_url = base_url self.runnable_cache = {} - def get_runnable(self, runnable_name: str, cache: bool = True) -> Runnable: + def get_runnable(self, runnable_name: str, cache: bool = True) -> RemoteRunnable: if cache and runnable_name in self.runnable_cache: return self.runnable_cache[runnable_name] @@ -23,40 +25,23 @@ class ArcanRunnables: def __init__(self, base_url: str = "http://localhost:8000/"): self.factory = RunnableFactory(base_url=base_url) - def 
get_chat_spells_agent_runnable(self): - return self.factory.get_runnable(runnable_name="spells_agent") + def get_spells_runnable(self) -> AgentExecutor: + return self.factory.get_runnable(runnable_name="spells") - def get_openai_runnable(self): + def get_openai_runnable(self) -> ChatOpenAI: return self.factory.get_runnable(runnable_name="openai") - def get_groq_runnable(self): + def get_groq_runnable(self) -> ChatGroq: return self.factory.get_runnable(runnable_name="groq") + def get_ollama_runnable(self) -> AgentExecutor: + return self.factory.get_runnable(runnable_name="ollama") -# %% - - -# from langchain.schema import HumanMessage, SystemMessage -# from langchain.schema.runnable import RunnableMap - -# arcan_runnables = ArcanRunnables(base_url="http://localhost:8000/") -# chat_spells_agent = arcan_runnables.get_chat_spells_agent_runnable() -# openai_runnable = arcan_runnables.get_openai_runnable() -# groq_runnable = arcan_runnables.get_groq_runnable() - - -# prompt = ChatPromptTemplate.from_messages( -# [("system", "Tell me a long story about {topic}")] -# ) - -# # Can define custom chains -# chain = prompt | RunnableMap({ -# "openai": openai_runnable, -# "groq": groq_runnable, -# }) -# # %% + def get_auth_spells_runnable(self) -> AgentExecutor: + return self.factory.get_runnable(runnable_name="auth_spells") -# chain.batch([{"topic": "parrots"}, {"topic": "cats"}]) + def get_chain_with_history_runnable(self) -> AgentExecutor: + return self.factory.get_runnable(runnable_name="chain_with_history") # %% diff --git a/arcan/api/__init__.py b/arcan/api/__init__.py index 0efc89b..b07467d 100644 --- a/arcan/api/__init__.py +++ b/arcan/api/__init__.py @@ -1,9 +1,12 @@ # %% +import os +import re from datetime import datetime, timedelta, timezone -from typing import Annotated, Any, Dict, List, Optional, Union +from pathlib import Path +from typing import Annotated, Any, Callable, Dict, List, Optional, Union from dotenv import load_dotenv -from fastapi import Depends, 
FastAPI, Form, HTTPException, Request, status +from fastapi import Depends, FastAPI, Form, Header, HTTPException, Request, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse @@ -14,43 +17,59 @@ OAuth2PasswordBearer, OAuth2PasswordRequestForm, ) +from langchain_community.chat_message_histories import FileChatMessageHistory +from langchain_core import __version__ +from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder +from langchain_core.runnables import ConfigurableField, ConfigurableFieldSpec +from langchain_core.runnables.history import RunnableWithMessageHistory +from langchain_openai import ChatOpenAI from langserve import add_routes from langserve.pydantic_v1 import BaseModel, Field from pydantic import BaseModel from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import Session -from typing_extensions import Annotated +from typing_extensions import Annotated, TypedDict -from arcan.ai.agents import ArcanSpellsAgent +from arcan.ai.agents import ArcanAgent from arcan.ai.llm import LLM -from arcan.api.datamodel import get_db, get_db_context -from arcan.api.datamodel.chat_history import ChatHistory -from arcan.api.datamodel.conversation import Conversation -from arcan.api.datamodel.user import ( +from arcan.api.auth import fetch_session_from_header +from arcan.datamodel.engine import session_scope # , session_scope_context +from arcan.datamodel.user import ( ACCESS_TOKEN_EXPIRE_MINUTES, TokenModel, - User, - UserInDB, UserModel, UserRepository, UserService, oauth2_scheme, pwd_context, ) -from arcan.api.session import ArcanSession, run_agent -# from arcan.api.session.auth import requires_auth -from arcan.spells.vector_search import ( - get_per_user_retriever, - 
per_req_config_modifier, - pgVectorStore, -) +# from arcan.spells.vector_search import (get_per_user_retriever, +# per_req_config_modifier, pgVectorStore) + +# %% +MIN_VERSION_LANGCHAIN_CORE = (0, 1, 0) + +# Split the version string by "." and convert to integers +LANGCHAIN_CORE_VERSION = tuple(map(int, __version__.split("."))) + +if LANGCHAIN_CORE_VERSION < MIN_VERSION_LANGCHAIN_CORE: + raise RuntimeError( + f"Minimum required version of langchain-core is {MIN_VERSION_LANGCHAIN_CORE}, " + f"but found {LANGCHAIN_CORE_VERSION}" + ) + +# %% auth_scheme = HTTPBearer() load_dotenv() +ENVIRONMENT = os.environ.get("ENVIRONMENT") +ARCANAI_API_TOKEN = os.environ.get("ARCANAI_API_TOKEN") app = FastAPI() @@ -76,57 +95,88 @@ async def index(): return {"message": "Arcan is Running!"} -# %% +# @requires_auth +@app.get("/api/chat") +async def chat( + user_id: str, + query: str, + # current_user: Annotated[UserModel, Depends(get_current_active_user_from_request)], + db: Session = Depends(session_scope), +): + if ENVIRONMENT == "cloud": + # from arcan.api.session import ArcanSession, run_agent + agent = ArcanAgent(user_id=user_id) + # user = await get_current_active_user_from_request(request=Request) + response = agent.invoke({"input": query}) + elif ENVIRONMENT == "local": + agent = ArcanAgent( + user_id=user_id, + ) + response = agent.invoke({"input": query, "chat_history": []}) + return {"response": response} class Input(BaseModel): input: str - chat_history: List[Union[HumanMessage, AIMessage, FunctionMessage]] = Field( - ..., - extra={"widget": {"type": "chat", "input": "input", "output": "output"}}, - ) class Output(BaseModel): output: Any +dynamic_spells_model = ( + ArcanAgent() + .configurable_fields( + user_id=ConfigurableField( + id="user_id", + name="Arcan AI User ID", + description=("user_id Key for Arcan AI interactions"), + ) + ) + .with_types(input_type=Input, output_type=Output) +) + add_routes( app=app, - runnable=ArcanSpellsAgent( - # 
database=SQLDatabase.from_uri(os.environ.get("SQLALCHEMY_URL")) - ) - .agent_executor.with_types(input_type=Input, output_type=Output) - .with_config({"run_name": "agent"}), - path="/spells_agent", - enable_feedback_endpoint=True, + runnable=dynamic_spells_model, + per_req_config_modifier=fetch_session_from_header, + path="/spells", ) add_routes( app, LLM(provider="ChatOpenAI").llm, - per_req_config_modifier=per_req_config_modifier, path="/openai", + per_req_config_modifier=fetch_session_from_header, ) + add_routes( app, LLM(provider="ChatGroq").llm, - per_req_config_modifier=per_req_config_modifier, + per_req_config_modifier=fetch_session_from_header, path="/groq", ) add_routes( app, LLM(provider="ChatTogetherAI").llm, + per_req_config_modifier=fetch_session_from_header, path="/together", ) +add_routes( + app, + runnable=LLM(provider="ChatOllama").llm, + per_req_config_modifier=fetch_session_from_header, + path="/ollama", +) + @app.post("/token") async def login_for_access_token( form_data: Annotated[OAuth2PasswordRequestForm, Depends()], - session: Session = Depends(get_db), + session: Session = Depends(session_scope), ) -> TokenModel: user_repo = UserRepository(session) user_interface = UserService(user_repository=user_repo, pwd_context=pwd_context) @@ -142,7 +192,6 @@ async def login_for_access_token( data={"sub": user.username}, expires_delta=access_token_expires ) return TokenModel( - id=1, access_token=access_token, token_type="bearer", user_id=user.username, @@ -151,12 +200,13 @@ async def login_for_access_token( async def get_current_active_user_from_request( - request: Request, session: Session = Depends(get_db) + request: Request, session: Session = Depends(session_scope) ) -> UserModel: """Get the current active user from the request.""" user_repo = UserRepository(session) user_interface = UserService(user_repository=user_repo, pwd_context=pwd_context) token = await oauth2_scheme(request) + print(token) user = 
user_interface.get_current_user(token=token) if not user: raise HTTPException( @@ -164,40 +214,189 @@ async def get_current_active_user_from_request( detail="Invalid authentication credentials", headers={"WWW-Authenticate": "Bearer"}, ) - if user.disabled: - raise HTTPException(status_code=400, detail="Inactive user") + # if user.disabled: + # raise HTTPException(status_code=400, detail="Inactive user") return user -@app.get("/users/me/", response_model=UserModel) -async def read_users_me( - current_user: Annotated[UserModel, Depends(get_current_active_user_from_request)], -): - return current_user +# @app.get("/users/me/", response_model=UserModel) +# async def read_users_me( +# current_user: Annotated[UserModel, Depends(get_current_active_user_from_request)], +# ): +# return current_user -add_routes( - app, - get_per_user_retriever(vectorstore=pgVectorStore().get_vector_store()), - per_req_config_modifier=per_req_config_modifier, - enabled_endpoints=["invoke"], -) +# add_routes( +# app, +# get_per_user_retriever(vectorstore=pgVectorStore().get_vector_store()), +# per_req_config_modifier=per_req_config_modifier, +# enabled_endpoints=["invoke"], +# ) # %% +# def create_session_factory( +# base_dir: Union[str, Path], +# ) -> Callable[[str], BaseChatMessageHistory]: +# """Create a factory that can retrieve chat histories. + +# The chat histories are keyed by user ID and conversation ID. + +# Args: +# base_dir: Base directory to use for storing the chat histories. + +# Returns: +# A factory that can retrieve chat histories keyed by user ID and conversation ID. +# """ +# base_dir_ = Path(base_dir) if isinstance(base_dir, str) else base_dir +# if not base_dir_.exists(): +# base_dir_.mkdir(parents=True) + +# def get_chat_history(user_id: str, conversation_id: str) -> FileChatMessageHistory: +# """Get a chat history from a user id and conversation id.""" +# if not _is_valid_identifier(user_id): +# raise ValueError( +# f"User ID {user_id} is not in a valid format. 
" +# "User ID must only contain alphanumeric characters, " +# "hyphens, and underscores." +# "Please include a valid cookie in the request headers called 'user-id'." +# ) +# if not _is_valid_identifier(conversation_id): +# raise ValueError( +# f"Conversation ID {conversation_id} is not in a valid format. " +# "Conversation ID must only contain alphanumeric characters, " +# "hyphens, and underscores. Please provide a valid conversation id " +# "via config. For example, " +# "chain.invoke(.., {'configurable': {'conversation_id': '123'}})" +# ) + +# user_dir = base_dir_ / user_id +# if not user_dir.exists(): +# user_dir.mkdir(parents=True) +# file_path = user_dir / f"{conversation_id}.json" +# return FileChatMessageHistory(str(file_path)) + +# return get_chat_history + + +# def _per_request_config_modifier( +# config: Dict[str, Any], request: Request +# ) -> Dict[str, Any]: +# """Update the config""" +# config = config.copy() +# configurable = config.get("configurable", {}) +# # Look for a cookie named "user_id" +# user_id = request.cookies.get("user_id", None) + +# if user_id is None: +# raise HTTPException( +# status_code=400, +# detail="No user id found. 
Please set a cookie named 'user_id'.", +# ) + +# configurable["user_id"] = user_id +# config["configurable"] = configurable +# return config + + +# # Declare a chain +# prompt = ChatPromptTemplate.from_messages( +# [ +# ("system", "You're an assistant by the name of Bob."), +# MessagesPlaceholder(variable_name="history"), +# ("human", "{human_input}"), +# ] +# ) + +# chain = prompt | ChatOpenAI() + + +# class InputChat(TypedDict): +# """Input for the chat endpoint.""" + +# human_input: str +# """Human input""" + + +# chain_with_history = RunnableWithMessageHistory( +# chain, +# create_session_factory("chat_histories"), +# input_messages_key="human_input", +# history_messages_key="history", +# history_factory_config=[ +# ConfigurableFieldSpec( +# id="user_id", +# annotation=str, +# name="User ID", +# description="Unique identifier for the user.", +# default="", +# is_shared=True, +# ), +# ConfigurableFieldSpec( +# id="conversation_id", +# annotation=str, +# name="Conversation ID", +# description="Unique identifier for the conversation.", +# default="", +# is_shared=True, +# ), +# ], +# ).with_types(input_type=InputChat) + + +# add_routes( +# app, +# chain_with_history, +# per_req_config_modifier=_per_request_config_modifier, +# # Disable playground and batch +# # 1) Playground we're passing information via headers, which is not supported via +# # the playground right now. +# # 2) Disable batch to avoid users being confused. Batch will work fine +# # as long as users invoke it with multiple configs appropriately, but +# # without validation users are likely going to forget to do that. +# # In addition, there's likely little sense in support batch for a chatbot. 
+# disabled_endpoints=["playground", "batch"], +# path="/chain_with_history", +# ) + + +# def _per_request_session_modifier( +# config: Dict[str, Any], request: Request +# ) -> Dict[str, Any]: +# """Update the config""" +# config = config.copy() +# configurable = config.get("configurable", {}) +# # Look for a cookie named "user_id" +# user_id = request.cookies.get("user_id", None) + +# if user_id is None: +# raise HTTPException( +# status_code=400, +# detail="No user id found. Please set a cookie named 'user_id'.", +# ) + +# agent = ArcanAgent(user_id=user_id) + +# configurable["user_id"] = user_id +# config["configurable"] = configurable +# return config, agent + +# add_routes( +# app, +# ArcanAgent(), +# path="/auth_spells", +# per_req_config_modifier=_per_request_session_modifier, +# # Disable playground and batch +# # 1) Playground we're passing information via headers, which is not supported via +# # the playground right now. +# # 2) Disable batch to avoid users being confused. Batch will work fine +# # as long as users invoke it with multiple configs appropriately, but +# # without validation users are likely going to forget to do that. +# # In addition, there's likely little sense in support batch for a chatbot. 
+# disabled_endpoints=["playground", "batch"], +# ) -# @requires_auth -@app.get("/api/chat") -async def chat( - user_id: str, - query: str, - current_user: Annotated[UserModel, Depends(get_current_active_user_from_request)], - db: Session = Depends(get_db), -): - arcan_session = ArcanSession(db) - response = run_agent(session=arcan_session, user_id=current_user, query=query) - return {"response": response} - +# %% if __name__ == "__main__": import uvicorn diff --git a/arcan/api/auth.py b/arcan/api/auth.py new file mode 100644 index 0000000..bcffc8c --- /dev/null +++ b/arcan/api/auth.py @@ -0,0 +1,38 @@ +import re +from typing import Any, Dict + +from fastapi import HTTPException, Request + +# def fetch_api_key_from_header(config: Dict[str, Any], req: Request) -> Dict[str, Any]: +# if "x-api-key" in req.headers: +# config["configurable"]["openai_api_key"] = req.headers["x-api-key"] +# if "user_id" in req.headers: +# config["configurable"]["user_id"] = req.headers["user_id"] +# else: +# raise HTTPException(401, "No User ID provided") +# else: +# raise HTTPException(401, "No API key provided") + +# return config + + +def fetch_session_from_header(config: Dict[str, Any], req: Request) -> Dict[str, Any]: + config = config.copy() + configurable = config.get("configurable", {}) + + if "arcanai_api_key" in req.headers: + if "user_id" in req.headers: + configurable["user_id"] = req.headers["user_id"] + config["configurable"] = configurable + # config["configurable"]["user_id"] = req.headers["user_id"] + # config["configurable"] = configurable + else: + raise HTTPException(401, "No Arcan AI API key provided") + return config + + +def _is_valid_identifier(value: str) -> bool: + """Check if the value is a valid identifier.""" + # Use a regular expression to match the allowed characters + valid_characters = re.compile(r"^[a-zA-Z0-9-_]+$") + return bool(valid_characters.match(value)) diff --git a/arcan/api/datamodel/__init__.py b/arcan/api/datamodel/__init__.py deleted file mode 
100644 index 67d3f41..0000000 --- a/arcan/api/datamodel/__init__.py +++ /dev/null @@ -1,53 +0,0 @@ -# %% -import os -from contextlib import contextmanager - -from dotenv import load_dotenv -from sqlalchemy import create_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker - -load_dotenv() - -# %% - - -DATABASE_URL = str(os.environ.get("SQLALCHEMY_URL")) -print(DATABASE_URL) -assert DATABASE_URL is not None, "SQLALCHEMY_URL environment variable not found" - -engine = create_engine( - DATABASE_URL -) # Oddly requires the hard coded string or else fails to connect -SessionLocal = sessionmaker(bind=engine) -Base = declarative_base() - - -def get_db(): - """ - Returns a database session. - - Yields: - SessionLocal: The database session. - - """ - try: - db = SessionLocal() - yield db - finally: - db.close() - - -@contextmanager -def get_db_context(): - """ - Context manager wrapper for the get_db generator. - """ - try: - db = next(get_db()) # Get the session from the generator - yield db - finally: - db.close() - - -# %% diff --git a/arcan/api/session/__init__.py b/arcan/api/session/__init__.py deleted file mode 100644 index 956c79a..0000000 --- a/arcan/api/session/__init__.py +++ /dev/null @@ -1,151 +0,0 @@ -import ast -import os -import pickle -import weakref -from datetime import datetime -from typing import Dict - -from fastapi import Depends -from langchain.sql_database import SQLDatabase -from sqlalchemy.dialects.postgresql import insert -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import Session - -from arcan.ai.agents import ArcanAgent, ArcanSpellsAgent -from arcan.api.datamodel.chat_history import ChatHistory -from arcan.api.datamodel.conversation import Conversation - - -class ArcanSession: - def __init__(self, database: Session): - """ - Initializes a new instance of the ArcanSession class. - - :param database: A callable that returns a new SQLAlchemy Session instance when called. 
- """ - self.database = database - self.database_uri = os.environ.get("SQLALCHEMY_URL") - self.agents: Dict[str, weakref.ref] = weakref.WeakValueDictionary() - - def get_or_create_agent( - self, user_id: str, provided_agent: ArcanAgent = None - ) -> ArcanAgent: - """ - Retrieves or creates a ArcanAgent for a given user_id. - - :param user_id: The unique identifier for the user. - :return: An instance of ArcanAgent. - """ - if provided_agent is None: - agent = self.agents.get(user_id) - chat_history = [] - - # Obtain a new database session - try: - chat_history = self.get_chat_history(user_id) - except Exception as e: - print(f"Error getting chat history for {user_id}: {e}") - - if agent is not None and chat_history: - print(f"Using existing agent {agent}") - elif agent is None and chat_history: - print(f"Using reloaded agent with history {chat_history}") - agent = ArcanSpellsAgent( - context=chat_history, - user_id=user_id, - # database=SQLDatabase.from_uri(self.database_uri) - ) # Initialize with chat history - elif agent is None and not chat_history: - print("Using a new agent") - agent = ArcanSpellsAgent( - user_id=user_id, - ) - # database=SQLDatabase.from_uri(self.database_uri)) # Initialize without chat history - - self.agents[user_id] = agent - return agent - - else: - provided_agent.user_id = user_id - self.agents[user_id] = provided_agent - return provided_agent - - def store_message(self, user_id: str, body: str, response: str): - """ - Stores a message in the database. - - :param user_id: The unique identifier for the user. - :param Body: The body of the message sent by the user. - :param response: The response generated by the system. 
- """ - with self.database as db_session: - conversation = Conversation(sender=user_id, message=body, response=response) - db_session.add(conversation) - db_session.commit() - print(f"Conversation #{conversation.id} stored in database") - - def store_chat_history(self, user_id, agent_history): - """ - Stores or updates the chat history for a user in the database. - - :param user_id: The unique identifier for the user. - :param agent_history: The chat history to be stored. - """ - history = pickle.dumps(agent_history) - # Upsert statement - stmt = ( - insert(ChatHistory) - .values( - sender=user_id, - history=str(history), - updated_at=datetime.utcnow(), # Explicitly set updated_at on insert - ) - .on_conflict_do_update( - index_elements=["sender"], # Specify the conflict target - set_={ - "history": str(history), # Update the history field upon conflict - "updated_at": datetime.utcnow(), # Update the updated_at field upon conflict - }, - ) - ) - # Execute the upsert - with self.database as db: - db.execute(stmt) - db.commit() - print(f"Upsert chat history for user {user_id} with statement {stmt}") - - def get_chat_history(self, user_id: str) -> list: - """ - Retrieves the chat history for a user from the database. - - :param db_session: The SQLAlchemy Session instance. - :param user_id: The unique identifier for the user. - :return: A list representing the chat history. 
- """ - with self.database as db_session: - history = ( - db_session.query(ChatHistory) - .filter(ChatHistory.sender == user_id) - .order_by(ChatHistory.updated_at.asc()) - .all() - ) or [] - if not history: - return [] - chat_history = history[0].history - loaded = pickle.loads(ast.literal_eval(chat_history)) - return loaded - - -def run_agent(session: ArcanSession, query: str, user_id: str) -> Dict[str, str]: - print(f"Sending the LangChain response to user: {user_id}") - agent = session.get_or_create_agent(user_id) - # Get the generated text from the LangChain agent - response = agent.get_response(user_content=query) - # Store the conversation in the database - try: - session.store_message(user_id=user_id, body=query, response=response) - session.store_chat_history(user_id=user_id, agent_history=agent.chat_history) - except SQLAlchemyError as e: - session.database.rollback() - print(f"Error storing conversation in database: {e}") - return response diff --git a/arcan/api/session/auth.py b/arcan/api/session/auth.py deleted file mode 100644 index 5b2b795..0000000 --- a/arcan/api/session/auth.py +++ /dev/null @@ -1,50 +0,0 @@ -import os -from functools import wraps - -from fastapi import HTTPException, Request, status -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer - -security = HTTPBearer() - - -def requires_auth(func): - @wraps(func) - def wrapper(*args, token: HTTPAuthorizationCredentials = security, **kwargs): - if token.credentials != os.environ["ARCAN_API_KEY"]: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect bearer token", - headers={"WWW-Authenticate": "Bearer"}, - ) - - return func(*args, **kwargs) - - return wrapper - - -def aio_requires_auth(func): - @wraps(func) - async def wrapper(*args, token: HTTPAuthorizationCredentials = None, **kwargs): - if token is None or token.credentials != os.environ["ARCAN_API_KEY"]: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - 
detail="Incorrect bearer token", - headers={"WWW-Authenticate": "Bearer"}, - ) - - return await func(*args, **kwargs) - - return wrapper - - -def log_endpoint(func): - @wraps(func) - def wrapper(request: Request, *args, **kwargs): - client_host = request.client.host - client_user_agent = request.headers.get("user-agent") - print( - f"Endpoint hit with query: {kwargs['query']}, context_url: {kwargs['context_url']}, client_host: {client_host}, client_user_agent: {client_user_agent}" - ) - return func(request, *args, **kwargs) - - return wrapper diff --git a/arcan/datamodel/__init__.py b/arcan/datamodel/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/arcan/api/datamodel/chat_history.py b/arcan/datamodel/chat_history.py similarity index 92% rename from arcan/api/datamodel/chat_history.py rename to arcan/datamodel/chat_history.py index dcd43c0..ebf8c3c 100644 --- a/arcan/api/datamodel/chat_history.py +++ b/arcan/datamodel/chat_history.py @@ -2,7 +2,7 @@ from sqlalchemy import Column, DateTime, String, Text -from arcan.api.datamodel import Base, engine +from arcan.datamodel.engine import Base, engine Base.metadata.create_all(engine) diff --git a/arcan/api/datamodel/conversation.py b/arcan/datamodel/conversation.py similarity index 93% rename from arcan/api/datamodel/conversation.py rename to arcan/datamodel/conversation.py index face889..34e97d7 100644 --- a/arcan/api/datamodel/conversation.py +++ b/arcan/datamodel/conversation.py @@ -2,7 +2,7 @@ from sqlalchemy import Column, DateTime, Integer, String -from arcan.api.datamodel import Base, engine +from arcan.datamodel.engine import Base, engine Base.metadata.create_all(engine) diff --git a/arcan/datamodel/engine.py b/arcan/datamodel/engine.py new file mode 100644 index 0000000..0aed7e8 --- /dev/null +++ b/arcan/datamodel/engine.py @@ -0,0 +1,87 @@ +import os +from contextlib import contextmanager + +from dotenv import load_dotenv +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative 
import declarative_base
+from sqlalchemy.orm import sessionmaker
+
+load_dotenv()
+
+
+class Config:
+    """Database settings read from the environment when this module loads."""
+
+    # Connection URL handed to create_engine() by the cloud engine.
+    DATABASE_URL = os.getenv("SQLALCHEMY_URL")
+    # Engine selector looked up in EngineFactory.engines ("local" or "cloud").
+    ENVIRONMENT = os.getenv("ENVIRONMENT")
+
+
+class EngineFactory:
+    """Build the SQLAlchemy engine that matches ``Config.ENVIRONMENT``."""
+
+    def __init__(self):
+        # Dispatch table: environment name -> bound engine constructor.
+        self.engines = {"local": self.local_engine, "cloud": self.cloud_engine}
+
+    def get_engine(self):
+        """Return an engine for the configured environment.
+
+        An unset, empty or unrecognised ``ENVIRONMENT`` falls back to the
+        cloud engine.
+
+        Raises:
+            ValueError: If the cloud engine is selected and no database URL
+                is configured.
+        """
+        # Default to 'cloud' if not specified
+        engine_type = Config.ENVIRONMENT or "cloud"
+        # Unknown names also fall back to the cloud engine
+        engine_creator = self.engines.get(engine_type, self.cloud_engine)
+        return engine_creator()
+
+    def local_engine(self):
+        """Create a local SQLite engine"""
+        # Three slashes make the path relative to the working directory; a
+        # fourth slash would make it absolute (/arcan.db at the filesystem
+        # root).
+        return create_engine("sqlite:///arcan.db")
+
+    def cloud_engine(self):
+        """Create a cloud engine from a URL in the config
+
+        Raises:
+            ValueError: If ``Config.DATABASE_URL`` is unset or empty.
+        """
+        if not Config.DATABASE_URL:
+            raise ValueError("No database URL provided for cloud environment.")
+        return create_engine(Config.DATABASE_URL)
+
+
+# Built once at import time and shared by every module importing from here.
+factory = EngineFactory()
+engine = factory.get_engine()
+
+SessionLocal = sessionmaker(bind=engine)
+Base = declarative_base()
+
+
+@contextmanager
+def session_scope():
+    """Provide a transactional scope around a series of operations.
+
+    Yields a new ``SessionLocal`` session, commits when the ``with`` body
+    finishes normally, rolls back and re-raises when it raises, and closes
+    the session in every case.
+    """
+    session = SessionLocal()
+    try:
+        yield session
+        session.commit()
+    except BaseException:
+        # BaseException on purpose: roll back on KeyboardInterrupt and
+        # SystemExit as well, then let the error propagate.
+        session.rollback()
+        raise
+    finally:
+        session.close()
+
+
+# def get_db():
+#     """
+#     Returns a database session.
+
+#     Yields:
+#         SessionLocal: The database session.
+
+#     """
+#     try:
+#         db = SessionLocal()
+#         yield db
+#     finally:
+#         db.close()
+
+
+# @contextmanager
+# def get_db_context():
+#     """
+#     Context manager wrapper for the get_db generator.
+# """ +# try: +# db = next(get_db()) # Get the session from the generator +# yield db +# finally: +# db.close() diff --git a/arcan/api/datamodel/user.py b/arcan/datamodel/user.py similarity index 98% rename from arcan/api/datamodel/user.py rename to arcan/datamodel/user.py index 9a14753..f6b97d4 100644 --- a/arcan/api/datamodel/user.py +++ b/arcan/datamodel/user.py @@ -11,7 +11,7 @@ from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text from sqlalchemy.orm import Session, relationship -from arcan.api.datamodel import Base, engine +from arcan.datamodel.engine import Base, engine pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -23,7 +23,7 @@ # to get a string like this run: # openssl rand -hex 32 -SECRET_KEY = os.environ.get("ARCAN_API_KEY") +SECRET_KEY = os.environ.get("ARCANAI_API_KEY") ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 30 diff --git a/notebooks/runnables.ipynb b/notebooks/runnables.ipynb new file mode 100644 index 0000000..f5e6956 --- /dev/null +++ b/notebooks/runnables.ipynb @@ -0,0 +1,254 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from arcan.ai.runnables import ArcanRunnables\n", + "\n", + "arcan_runnables = ArcanRunnables(base_url=\"http://localhost:8000/\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Client error '401 Unauthorized' for url 'http://localhost:8000/spells/invoke'\n", + "For more information check: https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/401 for {\"detail\":\"No Arcan AI API key provided\"}\n" + ] + } + ], + "source": [ + "# Requires Arcan AI API key header\n", + "import pytest\n", + "from httpx import HTTPStatusError\n", + "# Requires Arcan AI API key header\n", + "spells_runnable = arcan_runnables.get_spells_runnable()\n", + "\n", + "# Assert that spells_runnable.invoke 
raises HTTPStatusError\n", + "with pytest.raises(HTTPStatusError) as exc_info:\n", + " spells_runnable.invoke({'input': 'hi'})\n", + " assert str(exc_info.value) == \"Client error '401 Unauthorized'\"\n", + "print(exc_info.value)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'output': 'Your name is Carlos. How can I assist you today, Carlos?'}" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langserve import RemoteRunnable\n", + "\n", + "spells_runnable = RemoteRunnable(\"http://localhost:8000/spells/\", headers={\"arcanai_api_key\": '1234', \"user_id\": \"broomva\"})\n", + "response = spells_runnable.invoke({\"input\": \"hi there, whats my name?\"},config={\n", + " \"configurable\": {\"user_id\": \"broomva\"},\n", + " })\n", + "response" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[{'openai': AIMessage(content='Parrots are known for their ability to mimic human speech and other sounds, making them popular pets and performers in circuses and shows.', response_metadata={'token_usage': {'completion_tokens': 27, 'prompt_tokens': 17, 'total_tokens': 44}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-05e66dd2-9e7f-4fab-b879-121b7a0daf7b-0'),\n", + " 'groq': AIMessage(content='Here\\'s something quick and interesting about parrots:\\n\\nDid you know that parrots have a special type of feather on their beaks called \"beak feathers\"? These feathers help to keep their beaks clean and free of debris, and they\\'re also thought to play a role in the parrot\\'s ability to communicate and express emotions through facial expressions! 
Some parrot species even have over 4,000 beak feathers, which is a lot considering they\\'re only about 1-2 inches long!', response_metadata={'token_usage': {'completion_time': 0.115, 'completion_tokens': 101, 'prompt_time': 0.006, 'prompt_tokens': 26, 'queue_time': None, 'total_time': 0.12100000000000001, 'total_tokens': 127}, 'model_name': 'llama3-8b-8192', 'system_fingerprint': 'fp_af05557ca2', 'finish_reason': 'stop', 'logprobs': None}, id='run-f73fa995-e8f3-426d-bc26-36d59fa06653-0')},\n", + " {'openai': AIMessage(content='Cats have a unique grooming behavior called \"allogrooming,\" where they will groom each other as a form of social bonding and to maintain cleanliness.', response_metadata={'token_usage': {'completion_tokens': 31, 'prompt_tokens': 16, 'total_tokens': 47}, 'model_name': 'gpt-3.5-turbo-0125', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-a3681db9-6567-4aeb-949e-3b97385e5fe6-0'),\n", + " 'groq': AIMessage(content=\"Here's something quick and interesting about cats:\\n\\nDid you know that cats have a special talent for recognizing and remembering sounds? 
They can distinguish between over 50 different sounds, including the sound of their owner's voice, the sound of a can opener, and even the sound of a bag of treats being opened!\", response_metadata={'token_usage': {'completion_time': 0.074, 'completion_tokens': 62, 'prompt_time': 0.01, 'prompt_tokens': 25, 'queue_time': None, 'total_time': 0.08399999999999999, 'total_tokens': 87}, 'model_name': 'llama3-8b-8192', 'system_fingerprint': 'fp_6a6771ae9c', 'finish_reason': 'stop', 'logprobs': None}, id='run-6e4c8006-c5aa-498c-b7a5-06c863d13ca7-0')}]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\n", + "from langchain.prompts import ChatPromptTemplate\n", + "from langchain.schema import HumanMessage, SystemMessage\n", + "from langchain.schema.runnable import RunnableMap\n", + "\n", + "openai_runnable = arcan_runnables.get_openai_runnable()\n", + "groq_runnable = arcan_runnables.get_groq_runnable()\n", + "\n", + "prompt = ChatPromptTemplate.from_messages(\n", + " [(\"system\", \"Tell soemthing quick and interesting about {topic}\")]\n", + ")\n", + "\n", + "# Can define custom chains\n", + "chain = prompt | RunnableMap({\n", + " \"openai\": openai_runnable,\n", + " \"groq\": groq_runnable,\n", + "})\n", + "chain.batch([{\"topic\": \"parrots\"}, {\"topic\": \"cats\"}])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ollama_runnable = arcan_runnables.get_ollama_runnable()\n", + "# ollama_runnable.invoke('hi')" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'output': 'test'}" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langserve import RemoteRunnable\n", + "\n", + "gpt4o_runnable = RemoteRunnable(\"http://localhost:8000/spells/\", headers={\"arcanai_api_key\": '1234', \"user_id\": 
\"test\"})\n", + "response = spells_runnable.invoke({\"input\": \"testinggggg$#@\"},)\n", + "response" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "ename": "JSONDecodeError", + "evalue": "Expecting value: line 1 column 1 (char 0)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mJSONDecodeError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m~/miniconda3/envs/arcan/lib/python3.11/site-packages/requests/models.py:971\u001b[0m, in \u001b[0;36mResponse.json\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 970\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 971\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcomplexjson\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloads\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 972\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m JSONDecodeError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 973\u001b[0m \u001b[38;5;66;03m# Catch JSON-related errors and raise as requests.JSONDecodeError\u001b[39;00m\n\u001b[1;32m 974\u001b[0m \u001b[38;5;66;03m# This aliases json.JSONDecodeError and simplejson.JSONDecodeError\u001b[39;00m\n", + "File \u001b[0;32m~/miniconda3/envs/arcan/lib/python3.11/json/__init__.py:346\u001b[0m, in \u001b[0;36mloads\u001b[0;34m(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)\u001b[0m\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_hook 
\u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 344\u001b[0m parse_int \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m parse_float \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 345\u001b[0m parse_constant \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_pairs_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m kw):\n\u001b[0;32m--> 346\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_default_decoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", + "File \u001b[0;32m~/miniconda3/envs/arcan/lib/python3.11/json/decoder.py:337\u001b[0m, in \u001b[0;36mJSONDecoder.decode\u001b[0;34m(self, s, _w)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Return the Python representation of ``s`` (a ``str`` instance\u001b[39;00m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;124;03mcontaining a JSON document).\u001b[39;00m\n\u001b[1;32m 335\u001b[0m \n\u001b[1;32m 336\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m--> 337\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mraw_decode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_w\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mend\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 338\u001b[0m end \u001b[38;5;241m=\u001b[39m _w(s, end)\u001b[38;5;241m.\u001b[39mend()\n", + "File \u001b[0;32m~/miniconda3/envs/arcan/lib/python3.11/json/decoder.py:355\u001b[0m, in \u001b[0;36mJSONDecoder.raw_decode\u001b[0;34m(self, s, idx)\u001b[0m\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[0;32m--> 355\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m JSONDecodeError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExpecting value\u001b[39m\u001b[38;5;124m\"\u001b[39m, s, err\u001b[38;5;241m.\u001b[39mvalue) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 356\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m obj, end\n", + "\u001b[0;31mJSONDecodeError\u001b[0m: Expecting value: line 1 column 1 (char 0)", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mJSONDecodeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[20], line 10\u001b[0m\n\u001b[1;32m 4\u001b[0m arcanai_api_key \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39menviron\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mARCANAI_API_KEY\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 5\u001b[0m authenticated_response \u001b[38;5;241m=\u001b[39m requests\u001b[38;5;241m.\u001b[39mpost(\n\u001b[1;32m 6\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhttp://localhost:8000/spells/invoke\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 7\u001b[0m json\u001b[38;5;241m=\u001b[39m{\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput\u001b[39m\u001b[38;5;124m\"\u001b[39m: 
\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhello\u001b[39m\u001b[38;5;124m\"\u001b[39m},\n\u001b[1;32m 8\u001b[0m headers\u001b[38;5;241m=\u001b[39m{\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124marcanai_api_key\u001b[39m\u001b[38;5;124m\"\u001b[39m: arcanai_api_key, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muser_id\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbroomva\u001b[39m\u001b[38;5;124m\"\u001b[39m},\n\u001b[1;32m 9\u001b[0m )\n\u001b[0;32m---> 10\u001b[0m \u001b[43mauthenticated_response\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjson\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/miniconda3/envs/arcan/lib/python3.11/site-packages/requests/models.py:975\u001b[0m, in \u001b[0;36mResponse.json\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 971\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m complexjson\u001b[38;5;241m.\u001b[39mloads(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtext, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 972\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m JSONDecodeError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 973\u001b[0m \u001b[38;5;66;03m# Catch JSON-related errors and raise as requests.JSONDecodeError\u001b[39;00m\n\u001b[1;32m 974\u001b[0m \u001b[38;5;66;03m# This aliases json.JSONDecodeError and simplejson.JSONDecodeError\u001b[39;00m\n\u001b[0;32m--> 975\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m RequestsJSONDecodeError(e\u001b[38;5;241m.\u001b[39mmsg, e\u001b[38;5;241m.\u001b[39mdoc, e\u001b[38;5;241m.\u001b[39mpos)\n", + "\u001b[0;31mJSONDecodeError\u001b[0m: Expecting value: line 1 column 1 (char 0)" + ] + } + ], + "source": [ + "# import os\n", + "# import requests\n", + "\n", + "# arcanai_api_key = os.environ.get(\"ARCANAI_API_KEY\")\n", + "# authenticated_response = requests.post(\n", + "# \"http://localhost:8000/spells/invoke\",\n", + "# json={\"input\": \"hello\"},\n", + "# 
headers={\"arcanai_api_key\": arcanai_api_key, \"user_id\": \"broomva\"},\n", + "# )\n", + "# authenticated_response.json()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'output': {'content': 'Hello! How can I assist you today?',\n", + " 'additional_kwargs': {},\n", + " 'response_metadata': {'token_usage': {'completion_tokens': 9,\n", + " 'prompt_tokens': 8,\n", + " 'total_tokens': 17},\n", + " 'model_name': 'gpt-3.5-turbo-0125',\n", + " 'system_fingerprint': None,\n", + " 'finish_reason': 'stop',\n", + " 'logprobs': None},\n", + " 'type': 'ai',\n", + " 'name': None,\n", + " 'id': 'run-5da008b9-e665-470f-99c0-b495f2884b2d-0',\n", + " 'example': False,\n", + " 'tool_calls': [],\n", + " 'invalid_tool_calls': []},\n", + " 'metadata': {'run_id': '5da008b9-e665-470f-99c0-b495f2884b2d',\n", + " 'feedback_tokens': []}}" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import os\n", + "import requests\n", + "\n", + "test_key = os.environ[\"OPENAI_API_KEY\"]\n", + "authenticated_response = requests.post(\n", + " \"http://localhost:8000/openai/invoke\",\n", + " json={\"input\": \"hello\"},\n", + " headers={\"arcanai_api_key\": test_key, \"user_id\": \"broomva\"},\n", + ")\n", + "authenticated_response.json()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "arcan", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/poetry.lock b/poetry.lock index cfc9089..d9f787c 100644 --- a/poetry.lock +++ 
b/poetry.lock @@ -2447,13 +2447,13 @@ sympy = "*" [[package]] name = "openai" -version = "1.28.1" +version = "1.29.0" description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-1.28.1-py3-none-any.whl", hash = "sha256:943e0d0d587b9a62f99bd3acbaf479ae5362986e5fff013f57b5b7bde85cce93"}, - {file = "openai-1.28.1.tar.gz", hash = "sha256:8a3adbba16882434768d76fd3129fcc9b40ace98f8d55a6ddacfc05c4096ac30"}, + {file = "openai-1.29.0-py3-none-any.whl", hash = "sha256:c61cd12376c84362d406341f9e2f9a9d6b81c082b133b44484dc0f43954496b1"}, + {file = "openai-1.29.0.tar.gz", hash = "sha256:d5a769f485610cff8bae14343fa45a8b1d346be3d541fa5b28ccd040dbc8baf8"}, ] [package.dependencies] @@ -4245,47 +4245,47 @@ test = ["pytest", "tornado (>=4.5)", "typeguard"] [[package]] name = "tiktoken" -version = "0.6.0" +version = "0.7.0" description = "tiktoken is a fast BPE tokeniser for use with OpenAI's models" optional = false python-versions = ">=3.8" files = [ - {file = "tiktoken-0.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:277de84ccd8fa12730a6b4067456e5cf72fef6300bea61d506c09e45658d41ac"}, - {file = "tiktoken-0.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9c44433f658064463650d61387623735641dcc4b6c999ca30bc0f8ba3fccaf5c"}, - {file = "tiktoken-0.6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afb9a2a866ae6eef1995ab656744287a5ac95acc7e0491c33fad54d053288ad3"}, - {file = "tiktoken-0.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c62c05b3109fefca26fedb2820452a050074ad8e5ad9803f4652977778177d9f"}, - {file = "tiktoken-0.6.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0ef917fad0bccda07bfbad835525bbed5f3ab97a8a3e66526e48cdc3e7beacf7"}, - {file = "tiktoken-0.6.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e095131ab6092d0769a2fda85aa260c7c383072daec599ba9d8b149d2a3f4d8b"}, - {file = 
"tiktoken-0.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:05b344c61779f815038292a19a0c6eb7098b63c8f865ff205abb9ea1b656030e"}, - {file = "tiktoken-0.6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cefb9870fb55dca9e450e54dbf61f904aab9180ff6fe568b61f4db9564e78871"}, - {file = "tiktoken-0.6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:702950d33d8cabc039845674107d2e6dcabbbb0990ef350f640661368df481bb"}, - {file = "tiktoken-0.6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8d49d076058f23254f2aff9af603863c5c5f9ab095bc896bceed04f8f0b013a"}, - {file = "tiktoken-0.6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:430bc4e650a2d23a789dc2cdca3b9e5e7eb3cd3935168d97d43518cbb1f9a911"}, - {file = "tiktoken-0.6.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:293cb8669757301a3019a12d6770bd55bec38a4d3ee9978ddbe599d68976aca7"}, - {file = "tiktoken-0.6.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7bd1a288b7903aadc054b0e16ea78e3171f70b670e7372432298c686ebf9dd47"}, - {file = "tiktoken-0.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:ac76e000183e3b749634968a45c7169b351e99936ef46f0d2353cd0d46c3118d"}, - {file = "tiktoken-0.6.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:17cc8a4a3245ab7d935c83a2db6bb71619099d7284b884f4b2aea4c74f2f83e3"}, - {file = "tiktoken-0.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:284aebcccffe1bba0d6571651317df6a5b376ff6cfed5aeb800c55df44c78177"}, - {file = "tiktoken-0.6.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0c1a3a5d33846f8cd9dd3b7897c1d45722f48625a587f8e6f3d3e85080559be8"}, - {file = "tiktoken-0.6.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6318b2bb2337f38ee954fd5efa82632c6e5ced1d52a671370fa4b2eff1355e91"}, - {file = "tiktoken-0.6.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:1f5f0f2ed67ba16373f9a6013b68da298096b27cd4e1cf276d2d3868b5c7efd1"}, - {file = 
"tiktoken-0.6.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:75af4c0b16609c2ad02581f3cdcd1fb698c7565091370bf6c0cf8624ffaba6dc"}, - {file = "tiktoken-0.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:45577faf9a9d383b8fd683e313cf6df88b6076c034f0a16da243bb1c139340c3"}, - {file = "tiktoken-0.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7c1492ab90c21ca4d11cef3a236ee31a3e279bb21b3fc5b0e2210588c4209e68"}, - {file = "tiktoken-0.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e2b380c5b7751272015400b26144a2bab4066ebb8daae9c3cd2a92c3b508fe5a"}, - {file = "tiktoken-0.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9f497598b9f58c99cbc0eb764b4a92272c14d5203fc713dd650b896a03a50ad"}, - {file = "tiktoken-0.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e65e8bd6f3f279d80f1e1fbd5f588f036b9a5fa27690b7f0cc07021f1dfa0839"}, - {file = "tiktoken-0.6.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5f1495450a54e564d236769d25bfefbf77727e232d7a8a378f97acddee08c1ae"}, - {file = "tiktoken-0.6.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6c4e4857d99f6fb4670e928250835b21b68c59250520a1941618b5b4194e20c3"}, - {file = "tiktoken-0.6.0-cp38-cp38-win_amd64.whl", hash = "sha256:168d718f07a39b013032741867e789971346df8e89983fe3c0ef3fbd5a0b1cb9"}, - {file = "tiktoken-0.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:47fdcfe11bd55376785a6aea8ad1db967db7f66ea81aed5c43fad497521819a4"}, - {file = "tiktoken-0.6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fb7d2ccbf1a7784810aff6b80b4012fb42c6fc37eaa68cb3b553801a5cc2d1fc"}, - {file = "tiktoken-0.6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1ccb7a111ee76af5d876a729a347f8747d5ad548e1487eeea90eaf58894b3138"}, - {file = "tiktoken-0.6.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2048e1086b48e3c8c6e2ceeac866561374cd57a84622fa49a6b245ffecb7744"}, - {file = 
"tiktoken-0.6.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:07f229a5eb250b6403a61200199cecf0aac4aa23c3ecc1c11c1ca002cbb8f159"}, - {file = "tiktoken-0.6.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:432aa3be8436177b0db5a2b3e7cc28fd6c693f783b2f8722539ba16a867d0c6a"}, - {file = "tiktoken-0.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:8bfe8a19c8b5c40d121ee7938cd9c6a278e5b97dc035fd61714b4f0399d2f7a1"}, - {file = "tiktoken-0.6.0.tar.gz", hash = "sha256:ace62a4ede83c75b0374a2ddfa4b76903cf483e9cb06247f566be3bf14e6beed"}, + {file = "tiktoken-0.7.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:485f3cc6aba7c6b6ce388ba634fbba656d9ee27f766216f45146beb4ac18b25f"}, + {file = "tiktoken-0.7.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e54be9a2cd2f6d6ffa3517b064983fb695c9a9d8aa7d574d1ef3c3f931a99225"}, + {file = "tiktoken-0.7.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79383a6e2c654c6040e5f8506f3750db9ddd71b550c724e673203b4f6b4b4590"}, + {file = "tiktoken-0.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d4511c52caacf3c4981d1ae2df85908bd31853f33d30b345c8b6830763f769c"}, + {file = "tiktoken-0.7.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:13c94efacdd3de9aff824a788353aa5749c0faee1fbe3816df365ea450b82311"}, + {file = "tiktoken-0.7.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8e58c7eb29d2ab35a7a8929cbeea60216a4ccdf42efa8974d8e176d50c9a3df5"}, + {file = "tiktoken-0.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:21a20c3bd1dd3e55b91c1331bf25f4af522c525e771691adbc9a69336fa7f702"}, + {file = "tiktoken-0.7.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:10c7674f81e6e350fcbed7c09a65bca9356eaab27fb2dac65a1e440f2bcfe30f"}, + {file = "tiktoken-0.7.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:084cec29713bc9d4189a937f8a35dbdfa785bd1235a34c1124fe2323821ee93f"}, + {file = "tiktoken-0.7.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:811229fde1652fedcca7c6dfe76724d0908775b353556d8a71ed74d866f73f7b"}, + {file = "tiktoken-0.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86b6e7dc2e7ad1b3757e8a24597415bafcfb454cebf9a33a01f2e6ba2e663992"}, + {file = "tiktoken-0.7.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1063c5748be36344c7e18c7913c53e2cca116764c2080177e57d62c7ad4576d1"}, + {file = "tiktoken-0.7.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:20295d21419bfcca092644f7e2f2138ff947a6eb8cfc732c09cc7d76988d4a89"}, + {file = "tiktoken-0.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:959d993749b083acc57a317cbc643fb85c014d055b2119b739487288f4e5d1cb"}, + {file = "tiktoken-0.7.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:71c55d066388c55a9c00f61d2c456a6086673ab7dec22dd739c23f77195b1908"}, + {file = "tiktoken-0.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:09ed925bccaa8043e34c519fbb2f99110bd07c6fd67714793c21ac298e449410"}, + {file = "tiktoken-0.7.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03c6c40ff1db0f48a7b4d2dafeae73a5607aacb472fa11f125e7baf9dce73704"}, + {file = "tiktoken-0.7.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d20b5c6af30e621b4aca094ee61777a44118f52d886dbe4f02b70dfe05c15350"}, + {file = "tiktoken-0.7.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d427614c3e074004efa2f2411e16c826f9df427d3c70a54725cae860f09e4bf4"}, + {file = "tiktoken-0.7.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8c46d7af7b8c6987fac9b9f61041b452afe92eb087d29c9ce54951280f899a97"}, + {file = "tiktoken-0.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:0bc603c30b9e371e7c4c7935aba02af5994a909fc3c0fe66e7004070858d3f8f"}, + {file = "tiktoken-0.7.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2398fecd38c921bcd68418675a6d155fad5f5e14c2e92fcf5fe566fa5485a858"}, + {file = "tiktoken-0.7.0-cp38-cp38-macosx_11_0_arm64.whl", hash = 
"sha256:8f5f6afb52fb8a7ea1c811e435e4188f2bef81b5e0f7a8635cc79b0eef0193d6"}, + {file = "tiktoken-0.7.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:861f9ee616766d736be4147abac500732b505bf7013cfaf019b85892637f235e"}, + {file = "tiktoken-0.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54031f95c6939f6b78122c0aa03a93273a96365103793a22e1793ee86da31685"}, + {file = "tiktoken-0.7.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:fffdcb319b614cf14f04d02a52e26b1d1ae14a570f90e9b55461a72672f7b13d"}, + {file = "tiktoken-0.7.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c72baaeaefa03ff9ba9688624143c858d1f6b755bb85d456d59e529e17234769"}, + {file = "tiktoken-0.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:131b8aeb043a8f112aad9f46011dced25d62629091e51d9dc1adbf4a1cc6aa98"}, + {file = "tiktoken-0.7.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cabc6dc77460df44ec5b879e68692c63551ae4fae7460dd4ff17181df75f1db7"}, + {file = "tiktoken-0.7.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8d57f29171255f74c0aeacd0651e29aa47dff6f070cb9f35ebc14c82278f3b25"}, + {file = "tiktoken-0.7.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ee92776fdbb3efa02a83f968c19d4997a55c8e9ce7be821ceee04a1d1ee149c"}, + {file = "tiktoken-0.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e215292e99cb41fbc96988ef62ea63bb0ce1e15f2c147a61acc319f8b4cbe5bf"}, + {file = "tiktoken-0.7.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:8a81bac94769cab437dd3ab0b8a4bc4e0f9cf6835bcaa88de71f39af1791727a"}, + {file = "tiktoken-0.7.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:d6d73ea93e91d5ca771256dfc9d1d29f5a554b83821a1dc0891987636e0ae226"}, + {file = "tiktoken-0.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:2bcb28ddf79ffa424f171dfeef9a4daff61a94c631ca6813f43967cb263b83b9"}, + {file = "tiktoken-0.7.0.tar.gz", hash = 
"sha256:1077266e949c24e0291f6c350433c6f0971365ece2b173a23bc3b9f9defef6b6"}, ] [package.dependencies] @@ -5149,4 +5149,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "8276a46985ecc1e75a2d55f5e630bb2d26073e7339025854ee30a67828a242d3" +content-hash = "d5ab97314080ae943b1c6487e9691fb0974ac3d7d819659db631695a60ff3d82" diff --git a/pyproject.toml b/pyproject.toml index 8b4f64d..c4bc0ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ openai = "^1.14.0" langchain = "^0.1.16" langchain-openai = "^0.0.8" uvicorn = "^0.28.0" -pydantic = "<2" #">=1.10.13,<3.0.0" +pydantic = ">=1.10.13,<2" python-dotenv = "^1.0.1" typing-extensions = "^4.9.0" pandas = "^2.2.1" @@ -37,7 +37,6 @@ langchain-experimental = "^0.0.54" langchain-community = "^0.0.32" # langchain-together = "^0.0.2.post1" # langchain-fireworks = "^0.1.1" -# semantic-router = "^0.0.28" modal = "^0.61.54" typer = "^0.9.0" langserve = {extras = ["all"], version = ">=0.0.30"} #"^0.1.1" diff --git a/tests/arcan/ai/runnables/test_runnables.py b/tests/arcan/ai/runnables/test_runnables.py new file mode 100644 index 0000000..c12ebd7 --- /dev/null +++ b/tests/arcan/ai/runnables/test_runnables.py @@ -0,0 +1,57 @@ +import os +from unittest.mock import MagicMock + +import pytest +from httpx import AsyncClient + +from arcan.ai.runnables import ArcanRunnables +from arcan.api import app + + +@pytest.fixture +def base_url(): + return "http://localhost:8000/" + + +def test_get_spells_runnable(base_url): + runnable_factory = MagicMock() + arcan_runnables = ArcanRunnables(base_url=base_url) + arcan_runnables.factory = runnable_factory + + arcan_runnables.get_spells_runnable() + + runnable_factory.get_runnable.assert_called_once_with(runnable_name="spells") + + assert arcan_runnables.get_spells_runnable().invoke( + {"input": "testinggggg$#@"} + ).json() == {"response": "test"} + + +def 
test_get_openai_runnable(base_url): + runnable_factory = MagicMock() + arcan_runnables = ArcanRunnables(base_url=base_url) + arcan_runnables.factory = runnable_factory + + arcan_runnables.get_openai_runnable() + + runnable_factory.get_runnable.assert_called_once_with(runnable_name="openai") + + +def test_get_groq_runnable(base_url): + runnable_factory = MagicMock() + arcan_runnables = ArcanRunnables(base_url=base_url) + arcan_runnables.factory = runnable_factory + + arcan_runnables.get_groq_runnable() + + runnable_factory.get_runnable.assert_called_once_with(runnable_name="groq") + + +# def test_get_ollama_runnable(base_url): +# runnable_factory = MagicMock() +# arcan_runnables = ArcanRunnables(base_url=base_url) +# arcan_runnables.factory = runnable_factory + +# arcan_runnables.get_ollama_runnable() + +# runnable_factory.get_runnable.assert_called_once_with(runnable_name="ollama") diff --git a/tests/arcan/api/test_api.py b/tests/arcan/api/test_api.py index 72f6c98..949f51c 100644 --- a/tests/arcan/api/test_api.py +++ b/tests/arcan/api/test_api.py @@ -6,8 +6,7 @@ from sqlalchemy.orm import Session from arcan.api import app # Adjust this import based on your project structure -from arcan.api.datamodel import get_db -from arcan.api.session import ArcanSession +from arcan.datamodel.engine import session_scope @pytest.mark.asyncio @@ -30,11 +29,11 @@ async def test_index(): @pytest.mark.asyncio -@patch("arcan.api.datamodel.get_db") # Correct the import path as necessary -async def test_chat(mock_get_db): +@patch("arcan.datamodel.engine") # Correct the import path as necessary +async def test_chat(mock_session_scope): # Create a mock session mock_session = MagicMock() - mock_get_db.return_value = mock_session + mock_session_scope.return_value = mock_session # Mock specific behaviors, e.g., query handling # mock_session.query.return_value.filter.return_value.one.return_value = YourUserModel(id="1", name="Test User") @@ -44,7 +43,7 @@ async def test_chat(mock_get_db): # 
mock_run_agent.return_value = "Test Response" mock_token = MagicMock() - mock_token.credentials = os.getenv("ARCAN_API_KEY") + mock_token.credentials = os.getenv("ARCANAI_API_KEY") async with AsyncClient(app=app, base_url="http://test") as ac: response = await ac.get(