diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 518cd437c..05224180e 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -1,4 +1,4 @@ -name: Python Unit Tests +name: Python Tests # suppress warning raised by https://github.com/jupyter/jupyter_core/pull/292 env: @@ -12,7 +12,7 @@ on: jobs: unit-tests: - name: Linux + name: Unit tests runs-on: ubuntu-latest steps: - name: Checkout @@ -28,3 +28,22 @@ jobs: run: | set -eux pytest -vv -r ap --cov jupyter_ai + + typing-tests: + name: Typing test + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Base Setup + uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1 + + - name: Install extension dependencies and build the extension + run: ./scripts/install.sh + + - name: Run mypy + run: | + set -eux + mypy --version + mypy packages/jupyter-ai diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py b/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py index 147f6ceec..f2ee0cd54 100644 --- a/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py +++ b/packages/jupyter-ai-magics/jupyter_ai_magics/models/completion.py @@ -43,6 +43,7 @@ class InlineCompletionItem(BaseModel): class CompletionError(BaseModel): type: str + title: str traceback: str diff --git a/packages/jupyter-ai-magics/jupyter_ai_magics/py.typed b/packages/jupyter-ai-magics/jupyter_ai_magics/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py index 5c3026685..2da303e5a 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py @@ -72,6 +72,7 @@ async def process_message(self, message: HumanChatMessage): try: with self.pending("Searching learned documents", message): + assert self.llm_chain result = await self.llm_chain.acall({"question": query}) response = result["answer"] self.reply(response, message) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py index e524057be..faa357087 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/base.py @@ -13,6 +13,7 @@ Optional, Type, Union, + cast, ) from uuid import uuid4 @@ -28,6 +29,7 @@ ) from jupyter_ai_magics import Persona from jupyter_ai_magics.providers import BaseProvider +from langchain.chains import LLMChain from langchain.pydantic_v1 import BaseModel if TYPE_CHECKING: @@ -37,8 +39,8 @@ from langchain_core.chat_history import BaseChatMessageHistory -def get_preferred_dir(root_dir: str, preferred_dir: str) -> Optional[str]: - if preferred_dir != "": +def get_preferred_dir(root_dir: str, preferred_dir: Optional[str]) -> Optional[str]: + if preferred_dir is not None and preferred_dir != "": preferred_dir = os.path.expanduser(preferred_dir) if not preferred_dir.startswith(root_dir): preferred_dir = os.path.join(root_dir, preferred_dir) @@ -48,7 +50,7 @@ def get_preferred_dir(root_dir: str, preferred_dir: str) -> Optional[str]: # Chat handler type, with specific attributes for each class HandlerRoutingType(BaseModel): - routing_method: ClassVar[Union[Literal["slash_command"]]] = ... + routing_method: ClassVar[Union[Literal["slash_command"]]] """The routing method that sends commands to this handler.""" @@ -84,17 +86,17 @@ class BaseChatHandler: multiple chat handler classes.""" # Class attributes - id: ClassVar[str] = ... + id: ClassVar[str] """ID for this chat handler; should be unique""" - name: ClassVar[str] = ... + name: ClassVar[str] """User-facing name of this handler""" - help: ClassVar[str] = ... + help: ClassVar[str] """What this chat handler does, which third-party models it contacts, the data it returns to the user, and so on, for display in the UI.""" - routing_type: ClassVar[HandlerRoutingType] = ... + routing_type: ClassVar[HandlerRoutingType] uses_llm: ClassVar[bool] = True """Class attribute specifying whether this chat handler uses the LLM @@ -160,9 +162,9 @@ def __init__( self.chat_handlers = chat_handlers self.context_providers = context_providers - self.llm = None - self.llm_params = None - self.llm_chain = None + self.llm: Optional[BaseProvider] = None + self.llm_params: Optional[dict] = None + self.llm_chain: Optional[LLMChain] = None async def on_message(self, message: HumanChatMessage): """ @@ -175,9 +177,8 @@ async def on_message(self, message: HumanChatMessage): # ensure the current slash command is supported if self.routing_type.routing_method == "slash_command": - slash_command = ( - "/" + self.routing_type.slash_id if self.routing_type.slash_id else "" - ) + routing_type = cast(SlashCommandRoutingType, self.routing_type) + slash_command = "/" + routing_type.slash_id if routing_type.slash_id else "" if slash_command in lm_provider_klass.unsupported_slash_commands: self.reply( "Sorry, the selected language model does not support this slash command." diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 3f1677dbb..e178c3e4f 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -45,10 +45,10 @@ def create_llm_chain( self.llm = llm self.prompt_template = prompt_template - runnable = prompt_template | llm + runnable = prompt_template | llm # type:ignore if not llm.manages_history: runnable = RunnableWithMessageHistory( - runnable=runnable, + runnable=runnable, # type:ignore[arg-type] get_session_history=self.get_llm_chat_memory, input_messages_key="input", history_messages_key="history", @@ -121,6 +121,7 @@ async def process_message(self, message: HumanChatMessage): # stream response in chunks. this works even if a provider does not # implement streaming, as `astream()` defaults to yielding `_call()` # when `_stream()` is not implemented on the LLM class. + assert self.llm_chain async for chunk in self.llm_chain.astream( inputs, config={"configurable": {"last_human_msg": message}}, @@ -132,7 +133,7 @@ async def process_message(self, message: HumanChatMessage): stream_id = self._start_stream(human_msg=message) received_first_chunk = True - if isinstance(chunk, AIMessageChunk): + if isinstance(chunk, AIMessageChunk) and isinstance(chunk.content, str): self._send_stream_chunk(stream_id, chunk.content) elif isinstance(chunk, str): self._send_stream_chunk(stream_id, chunk) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py index d6ecc6d81..1056e592c 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py @@ -93,6 +93,7 @@ async def process_message(self, message: HumanChatMessage): self.get_llm_chain() with self.pending("Analyzing error", message): + assert self.llm_chain response = await self.llm_chain.apredict( extra_instructions=extra_instructions, stop=["\nHuman:"], diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py index 52398eabe..a69b5ed28 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py @@ -226,7 +226,7 @@ class GenerateChatHandler(BaseChatHandler): def __init__(self, log_dir: Optional[str], *args, **kwargs): super().__init__(*args, **kwargs) self.log_dir = Path(log_dir) if log_dir else None - self.llm = None + self.llm: Optional[BaseProvider] = None def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] @@ -248,6 +248,7 @@ async def _generate_notebook(self, prompt: str): # Save the user input prompt, the description property is now LLM generated. outline["prompt"] = prompt + assert self.llm if self.llm.allows_concurrency: # fill the outline concurrently await afill_outline(outline, llm=self.llm, verbose=True) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py index 29e147f22..3d1c46661 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py @@ -223,7 +223,9 @@ async def learn_dir( } splitter = ExtensionSplitter( splitters=splitters, - default_splitter=RecursiveCharacterTextSplitter(**splitter_kwargs), + default_splitter=RecursiveCharacterTextSplitter( + **splitter_kwargs # type:ignore[arg-type] + ), ) delayed = split(path, all_files, splitter=splitter) @@ -352,7 +354,7 @@ async def aget_relevant_documents( self, query: str ) -> Coroutine[Any, Any, List[Document]]: if not self.index: - return [] + return [] # type:ignore[return-value] await self.delete_and_relearn() docs = self.index.similarity_search(query) @@ -370,12 +372,14 @@ def get_embedding_model(self): class Retriever(BaseRetriever): - learn_chat_handler: LearnChatHandler = None + learn_chat_handler: LearnChatHandler = None # type:ignore[assignment] - def _get_relevant_documents(self, query: str) -> List[Document]: + def _get_relevant_documents( # type:ignore[override] + self, query: str + ) -> List[Document]: raise NotImplementedError() - async def _aget_relevant_documents( + async def _aget_relevant_documents( # type:ignore[override] self, query: str ) -> Coroutine[Any, Any, List[Document]]: docs = await self.learn_chat_handler.aget_relevant_documents(query) diff --git a/packages/jupyter-ai/jupyter_ai/config_manager.py b/packages/jupyter-ai/jupyter_ai/config_manager.py index d7674c5a2..2f38ca992 100644 --- a/packages/jupyter-ai/jupyter_ai/config_manager.py +++ b/packages/jupyter-ai/jupyter_ai/config_manager.py @@ -3,7 +3,7 @@ import os import shutil import time -from typing import List, Optional, Union +from typing import List, Optional, Type, Union from deepmerge import always_merger as Merger from jsonschema import Draft202012Validator as Validator @@ -60,7 +60,7 @@ class BlockedModelError(Exception): pass -def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider): +def _validate_provider_authn(config: GlobalConfig, provider: Type[AnyProvider]): # TODO: handle non-env auth strategies if not provider.auth_strategy or provider.auth_strategy.type != "env": return @@ -147,7 +147,7 @@ def _init_config_schema(self): os.makedirs(os.path.dirname(self.schema_path), exist_ok=True) shutil.copy(OUR_SCHEMA_PATH, self.schema_path) - def _init_validator(self) -> Validator: + def _init_validator(self) -> None: with open(OUR_SCHEMA_PATH, encoding="utf-8") as f: schema = json.loads(f.read()) Validator.check_schema(schema) @@ -364,7 +364,7 @@ def delete_api_key(self, key_name: str): config_dict["api_keys"].pop(key_name, None) self._write_config(GlobalConfig(**config_dict)) - def update_config(self, config_update: UpdateConfigRequest): + def update_config(self, config_update: UpdateConfigRequest): # type:ignore last_write = os.stat(self.config_path).st_mtime_ns if config_update.last_read and config_update.last_read < last_write: raise WriteConflictError( diff --git a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py index 7b9b28328..c8af71d84 100644 --- a/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py +++ b/packages/jupyter-ai/jupyter_ai/document_loaders/directory.py @@ -30,7 +30,7 @@ def arxiv_to_text(id: str, output_dir: str) -> str: output path to the downloaded TeX file """ - import arxiv + import arxiv # type:ignore[import-not-found,import-untyped] outfile = f"{id}-{datetime.now():%Y-%m-%d-%H-%M}.tex" download_filename = "downloaded-paper.tar.gz" diff --git a/packages/jupyter-ai/jupyter_ai/extension.py b/packages/jupyter-ai/jupyter_ai/extension.py index fa1c66890..7f3e46cd6 100644 --- a/packages/jupyter-ai/jupyter_ai/extension.py +++ b/packages/jupyter-ai/jupyter_ai/extension.py @@ -54,7 +54,7 @@ class AiExtension(ExtensionApp): name = "jupyter_ai" - handlers = [ + handlers = [ # type:ignore[assignment] (r"api/ai/api_keys/(?P\w+)", ApiKeysHandler), (r"api/ai/config/?", GlobalConfigHandler), (r"api/ai/chats/?", RootChatHandler), diff --git a/packages/jupyter-ai/jupyter_ai/handlers.py b/packages/jupyter-ai/jupyter_ai/handlers.py index 77ff43110..0b0cfb152 100644 --- a/packages/jupyter-ai/jupyter_ai/handlers.py +++ b/packages/jupyter-ai/jupyter_ai/handlers.py @@ -4,7 +4,7 @@ import uuid from asyncio import AbstractEventLoop from dataclasses import asdict -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, cast import tornado from jupyter_ai.chat_handlers import BaseChatHandler, SlashCommandRoutingType @@ -45,15 +45,13 @@ from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider from jupyter_ai_magics.providers import BaseProvider - from .context_providers import BaseContextProvider from .history import BoundedChatHistory + from .context_providers import BaseContextProvider class ChatHistoryHandler(BaseAPIHandler): """Handler to return message history""" - _messages = [] - @property def chat_history(self) -> List[ChatMessage]: return self.settings["chat_history"] @@ -149,6 +147,7 @@ def get_chat_user(self) -> ChatUser: environment.""" # Get a dictionary of all loaded extensions. # (`serverapp` is a property on all `JupyterHandler` subclasses) + assert self.serverapp extensions = self.serverapp.extension_manager.extensions collaborative = ( "jupyter_collaboration" in extensions @@ -405,7 +404,7 @@ def filter_predicate(local_model_id: str): if self.blocked_models: return model_id not in self.blocked_models else: - return model_id in self.allowed_models + return model_id in cast(List, self.allowed_models) # filter out every model w/ model ID according to allow/blocklist for provider in providers: @@ -518,7 +517,7 @@ def post(self): class ApiKeysHandler(BaseAPIHandler): @property - def config_manager(self) -> ConfigManager: + def config_manager(self) -> ConfigManager: # type:ignore[override] return self.settings["jai_config_manager"] @web.authenticated @@ -533,7 +532,7 @@ class SlashCommandsInfoHandler(BaseAPIHandler): """List slash commands that are currently available to the user.""" @property - def config_manager(self) -> ConfigManager: + def config_manager(self) -> ConfigManager: # type:ignore[override] return self.settings["jai_config_manager"] @property diff --git a/packages/jupyter-ai/jupyter_ai/history.py b/packages/jupyter-ai/jupyter_ai/history.py index c857b4486..0f1ba7dc0 100644 --- a/packages/jupyter-ai/jupyter_ai/history.py +++ b/packages/jupyter-ai/jupyter_ai/history.py @@ -25,7 +25,7 @@ class BoundedChatHistory(BaseChatMessageHistory, BaseModel): _all_messages: List[BaseMessage] = PrivateAttr(default_factory=list) @property - def messages(self) -> List[BaseMessage]: + def messages(self) -> List[BaseMessage]: # type:ignore[override] if self.k is None: return self._all_messages return self._all_messages[-self.k * 2 :] @@ -92,7 +92,7 @@ class WrappedBoundedChatHistory(BaseChatMessageHistory, BaseModel): last_human_msg: HumanChatMessage @property - def messages(self) -> List[BaseMessage]: + def messages(self) -> List[BaseMessage]: # type:ignore[override] return self.history.messages def add_message(self, message: BaseMessage) -> None: diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 1ce3a04e2..0a3476bb4 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -70,8 +70,7 @@ class ChatClient(ChatUser): id: str -class AgentChatMessage(BaseModel): - type: Literal["agent"] = "agent" +class BaseAgentMessage(BaseModel): id: str time: float body: str @@ -89,7 +88,11 @@ class AgentChatMessage(BaseModel): """ -class AgentStreamMessage(AgentChatMessage): +class AgentChatMessage(BaseAgentMessage): + type: Literal["agent"] = "agent" + + +class AgentStreamMessage(BaseAgentMessage): type: Literal["agent-stream"] = "agent-stream" complete: bool # other attrs inherited from `AgentChatMessage` @@ -138,15 +141,13 @@ class PendingMessage(BaseModel): class ClosePendingMessage(BaseModel): - type: Literal["pending"] = "close-pending" + type: Literal["close-pending"] = "close-pending" id: str # the type of messages being broadcast to clients ChatMessage = Union[ - AgentChatMessage, - HumanChatMessage, - AgentStreamMessage, + AgentChatMessage, HumanChatMessage, AgentStreamMessage, AgentStreamChunkMessage ] @@ -164,8 +165,7 @@ class ConnectionMessage(BaseModel): Message = Union[ - AgentChatMessage, - HumanChatMessage, + ChatMessage, ConnectionMessage, ClearMessage, PendingMessage, diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index e8deeb133..88f5fc55f 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ "faiss-cpu<=1.8.0", # Not distributed by official repo "typing_extensions>=4.5.0", "traitlets>=5.0", - "deepmerge>=1.0", + "deepmerge>=2.0,<3", ] dynamic = ["version", "description", "authors", "urls", "keywords"] @@ -50,6 +50,8 @@ test = [ "pytest-tornasync", "pytest-jupyter", "syrupy~=4.0.8", + "types-jsonschema", + "mypy", ] dev = ["jupyter_ai_magics[dev]"] diff --git a/pyproject.toml b/pyproject.toml index cc45dd977..d10ce4594 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,3 +43,8 @@ ignore_imports = ["jupyter_ai_magics.providers -> pydantic"] [tool.pytest.ini_options] addopts = "--ignore packages/jupyter-ai-module-cookiecutter" + +[tool.mypy] +exclude = [ + "tests" +]