Skip to content

Commit

Permalink
Run mypy on CI, fix or ignore typing issues (jupyterlab#987)
Browse files Browse the repository at this point in the history
* Run mypy on CI

* Rename, add mypy to test deps

* Fix typing jupyter-ai codebase (mostly)

* Three more cases

* update deepmerge version specifier

---------

Co-authored-by: David L. Qiu <[email protected]>
  • Loading branch information
2 people authored and michaelchia committed Sep 12, 2024
1 parent f03c5af commit caa91de
Show file tree
Hide file tree
Showing 17 changed files with 84 additions and 49 deletions.
23 changes: 21 additions & 2 deletions .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Python Unit Tests
name: Python Tests

# suppress warning raised by https://github.com/jupyter/jupyter_core/pull/292
env:
Expand All @@ -12,7 +12,7 @@ on:

jobs:
unit-tests:
name: Linux
name: Unit tests
runs-on: ubuntu-latest
steps:
- name: Checkout
Expand All @@ -28,3 +28,22 @@ jobs:
run: |
set -eux
pytest -vv -r ap --cov jupyter_ai
typing-tests:
name: Typing test
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4

- name: Base Setup
uses: jupyterlab/maintainer-tools/.github/actions/base-setup@v1

- name: Install extension dependencies and build the extension
run: ./scripts/install.sh

- name: Run mypy
run: |
set -eux
mypy --version
mypy packages/jupyter-ai
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class InlineCompletionItem(BaseModel):

class CompletionError(BaseModel):
type: str
title: str
traceback: str


Expand Down
Empty file.
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ async def process_message(self, message: HumanChatMessage):

try:
with self.pending("Searching learned documents", message):
assert self.llm_chain
result = await self.llm_chain.acall({"question": query})
response = result["answer"]
self.reply(response, message)
Expand Down
27 changes: 14 additions & 13 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Optional,
Type,
Union,
cast,
)
from uuid import uuid4

Expand All @@ -28,6 +29,7 @@
)
from jupyter_ai_magics import Persona
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import LLMChain
from langchain.pydantic_v1 import BaseModel

if TYPE_CHECKING:
Expand All @@ -37,8 +39,8 @@
from langchain_core.chat_history import BaseChatMessageHistory


def get_preferred_dir(root_dir: str, preferred_dir: str) -> Optional[str]:
if preferred_dir != "":
def get_preferred_dir(root_dir: str, preferred_dir: Optional[str]) -> Optional[str]:
if preferred_dir is not None and preferred_dir != "":
preferred_dir = os.path.expanduser(preferred_dir)
if not preferred_dir.startswith(root_dir):
preferred_dir = os.path.join(root_dir, preferred_dir)
Expand All @@ -48,7 +50,7 @@ def get_preferred_dir(root_dir: str, preferred_dir: str) -> Optional[str]:

# Chat handler type, with specific attributes for each
class HandlerRoutingType(BaseModel):
routing_method: ClassVar[Union[Literal["slash_command"]]] = ...
routing_method: ClassVar[Union[Literal["slash_command"]]]
"""The routing method that sends commands to this handler."""


Expand Down Expand Up @@ -84,17 +86,17 @@ class BaseChatHandler:
multiple chat handler classes."""

# Class attributes
id: ClassVar[str] = ...
id: ClassVar[str]
"""ID for this chat handler; should be unique"""

name: ClassVar[str] = ...
name: ClassVar[str]
"""User-facing name of this handler"""

help: ClassVar[str] = ...
help: ClassVar[str]
"""What this chat handler does, which third-party models it contacts,
the data it returns to the user, and so on, for display in the UI."""

routing_type: ClassVar[HandlerRoutingType] = ...
routing_type: ClassVar[HandlerRoutingType]

uses_llm: ClassVar[bool] = True
"""Class attribute specifying whether this chat handler uses the LLM
Expand Down Expand Up @@ -160,9 +162,9 @@ def __init__(
self.chat_handlers = chat_handlers
self.context_providers = context_providers

self.llm = None
self.llm_params = None
self.llm_chain = None
self.llm: Optional[BaseProvider] = None
self.llm_params: Optional[dict] = None
self.llm_chain: Optional[LLMChain] = None

async def on_message(self, message: HumanChatMessage):
"""
Expand All @@ -175,9 +177,8 @@ async def on_message(self, message: HumanChatMessage):

# ensure the current slash command is supported
if self.routing_type.routing_method == "slash_command":
slash_command = (
"/" + self.routing_type.slash_id if self.routing_type.slash_id else ""
)
routing_type = cast(SlashCommandRoutingType, self.routing_type)
slash_command = "/" + routing_type.slash_id if routing_type.slash_id else ""
if slash_command in lm_provider_klass.unsupported_slash_commands:
self.reply(
"Sorry, the selected language model does not support this slash command."
Expand Down
7 changes: 4 additions & 3 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ def create_llm_chain(
self.llm = llm
self.prompt_template = prompt_template

runnable = prompt_template | llm
runnable = prompt_template | llm # type:ignore
if not llm.manages_history:
runnable = RunnableWithMessageHistory(
runnable=runnable,
runnable=runnable, # type:ignore[arg-type]
get_session_history=self.get_llm_chat_memory,
input_messages_key="input",
history_messages_key="history",
Expand Down Expand Up @@ -121,6 +121,7 @@ async def process_message(self, message: HumanChatMessage):
# stream response in chunks. this works even if a provider does not
# implement streaming, as `astream()` defaults to yielding `_call()`
# when `_stream()` is not implemented on the LLM class.
assert self.llm_chain
async for chunk in self.llm_chain.astream(
inputs,
config={"configurable": {"last_human_msg": message}},
Expand All @@ -132,7 +133,7 @@ async def process_message(self, message: HumanChatMessage):
stream_id = self._start_stream(human_msg=message)
received_first_chunk = True

if isinstance(chunk, AIMessageChunk):
if isinstance(chunk, AIMessageChunk) and isinstance(chunk.content, str):
self._send_stream_chunk(stream_id, chunk.content)
elif isinstance(chunk, str):
self._send_stream_chunk(stream_id, chunk)
Expand Down
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ async def process_message(self, message: HumanChatMessage):

self.get_llm_chain()
with self.pending("Analyzing error", message):
assert self.llm_chain
response = await self.llm_chain.apredict(
extra_instructions=extra_instructions,
stop=["\nHuman:"],
Expand Down
3 changes: 2 additions & 1 deletion packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ class GenerateChatHandler(BaseChatHandler):
def __init__(self, log_dir: Optional[str], *args, **kwargs):
super().__init__(*args, **kwargs)
self.log_dir = Path(log_dir) if log_dir else None
self.llm = None
self.llm: Optional[BaseProvider] = None

def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
Expand All @@ -248,6 +248,7 @@ async def _generate_notebook(self, prompt: str):
# Save the user input prompt, the description property is now LLM generated.
outline["prompt"] = prompt

assert self.llm
if self.llm.allows_concurrency:
# fill the outline concurrently
await afill_outline(outline, llm=self.llm, verbose=True)
Expand Down
14 changes: 9 additions & 5 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,9 @@ async def learn_dir(
}
splitter = ExtensionSplitter(
splitters=splitters,
default_splitter=RecursiveCharacterTextSplitter(**splitter_kwargs),
default_splitter=RecursiveCharacterTextSplitter(
**splitter_kwargs # type:ignore[arg-type]
),
)

delayed = split(path, all_files, splitter=splitter)
Expand Down Expand Up @@ -352,7 +354,7 @@ async def aget_relevant_documents(
self, query: str
) -> Coroutine[Any, Any, List[Document]]:
if not self.index:
return []
return [] # type:ignore[return-value]

await self.delete_and_relearn()
docs = self.index.similarity_search(query)
Expand All @@ -370,12 +372,14 @@ def get_embedding_model(self):


class Retriever(BaseRetriever):
learn_chat_handler: LearnChatHandler = None
learn_chat_handler: LearnChatHandler = None # type:ignore[assignment]

def _get_relevant_documents(self, query: str) -> List[Document]:
def _get_relevant_documents( # type:ignore[override]
self, query: str
) -> List[Document]:
raise NotImplementedError()

async def _aget_relevant_documents(
async def _aget_relevant_documents( # type:ignore[override]
self, query: str
) -> Coroutine[Any, Any, List[Document]]:
docs = await self.learn_chat_handler.aget_relevant_documents(query)
Expand Down
8 changes: 4 additions & 4 deletions packages/jupyter-ai/jupyter_ai/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import shutil
import time
from typing import List, Optional, Union
from typing import List, Optional, Type, Union

from deepmerge import always_merger as Merger
from jsonschema import Draft202012Validator as Validator
Expand Down Expand Up @@ -60,7 +60,7 @@ class BlockedModelError(Exception):
pass


def _validate_provider_authn(config: GlobalConfig, provider: AnyProvider):
def _validate_provider_authn(config: GlobalConfig, provider: Type[AnyProvider]):
# TODO: handle non-env auth strategies
if not provider.auth_strategy or provider.auth_strategy.type != "env":
return
Expand Down Expand Up @@ -147,7 +147,7 @@ def _init_config_schema(self):
os.makedirs(os.path.dirname(self.schema_path), exist_ok=True)
shutil.copy(OUR_SCHEMA_PATH, self.schema_path)

def _init_validator(self) -> Validator:
def _init_validator(self) -> None:
with open(OUR_SCHEMA_PATH, encoding="utf-8") as f:
schema = json.loads(f.read())
Validator.check_schema(schema)
Expand Down Expand Up @@ -364,7 +364,7 @@ def delete_api_key(self, key_name: str):
config_dict["api_keys"].pop(key_name, None)
self._write_config(GlobalConfig(**config_dict))

def update_config(self, config_update: UpdateConfigRequest):
def update_config(self, config_update: UpdateConfigRequest): # type:ignore
last_write = os.stat(self.config_path).st_mtime_ns
if config_update.last_read and config_update.last_read < last_write:
raise WriteConflictError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def arxiv_to_text(id: str, output_dir: str) -> str:
output path to the downloaded TeX file
"""

import arxiv
import arxiv # type:ignore[import-not-found,import-untyped]

outfile = f"{id}-{datetime.now():%Y-%m-%d-%H-%M}.tex"
download_filename = "downloaded-paper.tar.gz"
Expand Down
2 changes: 1 addition & 1 deletion packages/jupyter-ai/jupyter_ai/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@

class AiExtension(ExtensionApp):
name = "jupyter_ai"
handlers = [
handlers = [ # type:ignore[assignment]
(r"api/ai/api_keys/(?P<api_key_name>\w+)", ApiKeysHandler),
(r"api/ai/config/?", GlobalConfigHandler),
(r"api/ai/chats/?", RootChatHandler),
Expand Down
13 changes: 6 additions & 7 deletions packages/jupyter-ai/jupyter_ai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid
from asyncio import AbstractEventLoop
from dataclasses import asdict
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, cast

import tornado
from jupyter_ai.chat_handlers import BaseChatHandler, SlashCommandRoutingType
Expand Down Expand Up @@ -45,15 +45,13 @@
from jupyter_ai_magics.embedding_providers import BaseEmbeddingsProvider
from jupyter_ai_magics.providers import BaseProvider

from .context_providers import BaseContextProvider
from .history import BoundedChatHistory
from .context_providers import BaseContextProvider


class ChatHistoryHandler(BaseAPIHandler):
"""Handler to return message history"""

_messages = []

@property
def chat_history(self) -> List[ChatMessage]:
return self.settings["chat_history"]
Expand Down Expand Up @@ -149,6 +147,7 @@ def get_chat_user(self) -> ChatUser:
environment."""
# Get a dictionary of all loaded extensions.
# (`serverapp` is a property on all `JupyterHandler` subclasses)
assert self.serverapp
extensions = self.serverapp.extension_manager.extensions
collaborative = (
"jupyter_collaboration" in extensions
Expand Down Expand Up @@ -405,7 +404,7 @@ def filter_predicate(local_model_id: str):
if self.blocked_models:
return model_id not in self.blocked_models
else:
return model_id in self.allowed_models
return model_id in cast(List, self.allowed_models)

# filter out every model w/ model ID according to allow/blocklist
for provider in providers:
Expand Down Expand Up @@ -518,7 +517,7 @@ def post(self):

class ApiKeysHandler(BaseAPIHandler):
@property
def config_manager(self) -> ConfigManager:
def config_manager(self) -> ConfigManager: # type:ignore[override]
return self.settings["jai_config_manager"]

@web.authenticated
Expand All @@ -533,7 +532,7 @@ class SlashCommandsInfoHandler(BaseAPIHandler):
"""List slash commands that are currently available to the user."""

@property
def config_manager(self) -> ConfigManager:
def config_manager(self) -> ConfigManager: # type:ignore[override]
return self.settings["jai_config_manager"]

@property
Expand Down
4 changes: 2 additions & 2 deletions packages/jupyter-ai/jupyter_ai/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class BoundedChatHistory(BaseChatMessageHistory, BaseModel):
_all_messages: List[BaseMessage] = PrivateAttr(default_factory=list)

@property
def messages(self) -> List[BaseMessage]:
def messages(self) -> List[BaseMessage]: # type:ignore[override]
if self.k is None:
return self._all_messages
return self._all_messages[-self.k * 2 :]
Expand Down Expand Up @@ -92,7 +92,7 @@ class WrappedBoundedChatHistory(BaseChatMessageHistory, BaseModel):
last_human_msg: HumanChatMessage

@property
def messages(self) -> List[BaseMessage]:
def messages(self) -> List[BaseMessage]: # type:ignore[override]
return self.history.messages

def add_message(self, message: BaseMessage) -> None:
Expand Down
Loading

0 comments on commit caa91de

Please sign in to comment.