diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 1ed8c8e..1c3d4b4 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -40,6 +40,7 @@
         "files.autoSave": "onFocusChange",
         "jupyter.kernels.excludePythonEnvironments": ["/usr/local/bin/python"],
         "mypy-type-checker.importStrategy": "fromEnvironment",
+        "mypy-type-checker.preferDaemon": true,
         "notebook.codeActionsOnSave": {
           "notebook.source.fixAll": "explicit",
           "notebook.source.organizeImports": "explicit"
diff --git a/pyproject.toml b/pyproject.toml
index 8e88810..6a1d397 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -123,56 +123,7 @@ src = ["src", "tests"]
 target-version = "py310"
 
 [tool.ruff.lint]
-ignore-init-module-imports = true
-select = [
-    "A",
-    "ASYNC",
-    "B",
-    "BLE",
-    "C4",
-    "C90",
-    "D",
-    "DTZ",
-    "E",
-    "EM",
-    "ERA",
-    "F",
-    "FBT",
-    "FLY",
-    "FURB",
-    "G",
-    "I",
-    "ICN",
-    "INP",
-    "INT",
-    "ISC",
-    "LOG",
-    "N",
-    "NPY",
-    "PERF",
-    "PGH",
-    "PIE",
-    "PL",
-    "PT",
-    "PTH",
-    "PYI",
-    "Q",
-    "RET",
-    "RSE",
-    "RUF",
-    "S",
-    "SIM",
-    "SLF",
-    "SLOT",
-    "T10",
-    "T20",
-    "TCH",
-    "TID",
-    "TRY",
-    "UP",
-    "W",
-    "YTT",
-]
+select = ["A", "ASYNC", "B", "BLE", "C4", "C90", "D", "DTZ", "E", "EM", "ERA", "F", "FBT", "FLY", "FURB", "G", "I", "ICN", "INP", "INT", "ISC", "LOG", "N", "NPY", "PERF", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "Q", "RET", "RSE", "RUF", "S", "SIM", "SLF", "SLOT", "T10", "T20", "TCH", "TID", "TRY", "UP", "W", "YTT"]
 ignore = ["D203", "D213", "E501", "RET504", "RUF002", "S101", "S307"]
 unfixable = ["ERA001", "F401", "F841", "T201", "T203"]
 
diff --git a/src/raglite/_config.py b/src/raglite/_config.py
index 8f7bb42..0c8ee9d 100644
--- a/src/raglite/_config.py
+++ b/src/raglite/_config.py
@@ -5,7 +5,7 @@
 import numpy as np
 import numpy.typing as npt
-from llama_cpp import Llama, LlamaRAMCache
+from llama_cpp import Llama, LlamaRAMCache  # type: ignore[attr-defined]
 from sqlalchemy.engine import URL
 
 
diff --git a/src/raglite/_database.py b/src/raglite/_database.py
index f87ee58..7965a9b 100644
--- a/src/raglite/_database.py
+++ b/src/raglite/_database.py
@@ -14,18 +14,22 @@
 from sqlalchemy.types import LargeBinary, TypeDecorator
 from sqlmodel import JSON, Column, Field, Relationship, Session, SQLModel, create_engine, text
 
+from raglite._typing import FloatMatrix
+
 
 def hash_bytes(data: bytes, max_len: int = 16) -> str:
     """Hash bytes to a hexadecimal string."""
     return sha256(data, usedforsecurity=False).hexdigest()[:max_len]
 
 
-class NumpyArray(TypeDecorator):
+class NumpyArray(TypeDecorator[np.ndarray[Any, np.dtype[np.floating[Any]]]]):
     """A NumPy array column type for SQLAlchemy."""
 
     impl = LargeBinary
 
-    def process_bind_param(self, value: np.ndarray | None, dialect: Dialect) -> bytes | None:
+    def process_bind_param(
+        self, value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, dialect: Dialect
+    ) -> bytes | None:
         """Convert a NumPy array to bytes."""
         if value is None:
             return None
@@ -33,14 +37,16 @@ def process_bind_param(self, value: np.ndarray | None, dialect: Dialect) -> byte
         np.save(buffer, value, allow_pickle=False, fix_imports=False)
         return buffer.getvalue()
 
-    def process_result_value(self, value: bytes | None, dialect: Dialect) -> np.ndarray | None:
+    def process_result_value(
+        self, value: bytes | None, dialect: Dialect
+    ) -> np.ndarray[Any, np.dtype[np.floating[Any]]] | None:
         """Convert bytes to a NumPy array."""
         if value is None:
             return None
-        return np.load(io.BytesIO(value), allow_pickle=False, fix_imports=False)
+        return np.load(io.BytesIO(value), allow_pickle=False, fix_imports=False)  # type: ignore[no-any-return]
 
 
-class PickledObject(TypeDecorator):
+class PickledObject(TypeDecorator[object]):
     """A pickled object column type for SQLAlchemy."""
 
     impl = LargeBinary
 
@@ -55,7 +61,7 @@ def process_result_value(self, value: bytes | None, dialect: Dialect) -> object
         """Convert bytes to a Python object."""
         if value is None:
             return None
-        return pickle.loads(value, fix_imports=False)  # noqa: S301
+        return pickle.loads(value, fix_imports=False)  # type: ignore[no-any-return]  # noqa: S301
 
 
 class Document(SQLModel, table=True):
@@ -99,7 +105,7 @@ class Chunk(SQLModel, table=True):
     index: int = Field(..., index=True)
     headings: str
     body: str
-    multi_vector_embedding: np.ndarray = Field(..., sa_column=Column(NumpyArray))
+    multi_vector_embedding: FloatMatrix = Field(..., sa_column=Column(NumpyArray))
     metadata_: dict[str, Any] = Field(default={}, sa_column=Column("metadata", JSON))
 
     # Add relationship so we can access chunk.document.
@@ -111,7 +117,7 @@ def from_body(
         index: int,
         body: str,
         headings: str = "",
-        multi_vector_embedding: np.ndarray | None = None,
+        multi_vector_embedding: FloatMatrix | None = None,
         **kwargs: Any,
     ) -> "Chunk":
         """Create a chunk from Markdown."""
@@ -167,7 +173,7 @@ class ChunkANNIndex(SQLModel, table=True):
     id: str = Field(..., primary_key=True)
     chunk_sizes: list[int] = Field(default=[], sa_column=Column(JSON))
     index: NNDescent | None = Field(default=None, sa_column=Column(PickledObject))
-    query_adapter: np.ndarray | None = Field(default=None, sa_column=Column(NumpyArray))
+    query_adapter: FloatMatrix | None = Field(default=None, sa_column=Column(NumpyArray))
     metadata_: dict[str, Any] = Field(default={}, sa_column=Column("metadata", JSON))
 
     # Enable support for JSON, PickledObject, and NumpyArray columns.
@@ -199,7 +205,7 @@ def from_chunks(
         contexts: list[Chunk],
         ground_truth: str,
         **kwargs: Any,
-    ) -> "Chunk":
+    ) -> "Eval":
         """Create a chunk from Markdown."""
         document_id = contexts[0].document_id
         chunk_ids = [context.id for context in contexts]
@@ -237,26 +243,26 @@ def create_database_engine(db_url: str | URL = "sqlite:///raglite.sqlite") -> En
     # We use the chunk table as an external content table [1] to avoid duplicating the data.
     # [1] https://www.sqlite.org/fts5.html#external_content_tables
     with Session(engine) as session:
-        session.exec(
+        session.execute(
             text("""
                 CREATE VIRTUAL TABLE IF NOT EXISTS chunk_fts USING fts5(body, content='chunk', content_rowid='rowid');
                 """)
         )
-        session.exec(
+        session.execute(
             text("""
                 CREATE TRIGGER IF NOT EXISTS chunk_fts_auto_insert AFTER INSERT ON chunk BEGIN
                     INSERT INTO chunk_fts(rowid, body) VALUES (new.rowid, new.body);
                 END;
                 """)
         )
-        session.exec(
+        session.execute(
             text("""
                 CREATE TRIGGER IF NOT EXISTS chunk_fts_auto_delete AFTER DELETE ON chunk BEGIN
                     INSERT INTO chunk_fts(chunk_fts, rowid, body) VALUES('delete', old.rowid, old.body);
                 END;
                 """)
         )
-        session.exec(
+        session.execute(
             text("""
                 CREATE TRIGGER IF NOT EXISTS chunk_fts_auto_update AFTER UPDATE ON chunk BEGIN
                     INSERT INTO chunk_fts(chunk_fts, rowid, body) VALUES('delete', old.rowid, old.body);
diff --git a/src/raglite/_embed.py b/src/raglite/_embed.py
index c06eacb..7aa4a93 100644
--- a/src/raglite/_embed.py
+++ b/src/raglite/_embed.py
@@ -6,18 +6,19 @@
 from tqdm.auto import trange
 
 from raglite._config import RAGLiteConfig
+from raglite._typing import FloatMatrix
 
 
 @lru_cache(maxsize=128)
 def _embed_string_batch(
-    string_batch: tuple[str], *, config: RAGLiteConfig | None = None
-) -> np.ndarray:
+    string_batch: tuple[str, ...], *, config: RAGLiteConfig | None = None
+) -> FloatMatrix:
     # Embed a batch of strings.
     config = config or RAGLiteConfig()
     if len(string_batch) == 0:
         embeddings = np.zeros((0, config.embedder.n_embd()))
     else:
-        embeddings = np.asarray(config.embedder.embed(string_batch))
+        embeddings = np.asarray(config.embedder.embed(string_batch))  # type: ignore[arg-type]
     # Normalise embeddings to unit norm.
     if config.embedder_normalize:
         embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
@@ -26,7 +27,7 @@ def _embed_string_batch(
     return embeddings
 
 
-def embed_strings(strings: list[str], *, config: RAGLiteConfig | None = None) -> np.ndarray:
+def embed_strings(strings: list[str], *, config: RAGLiteConfig | None = None) -> FloatMatrix:
     """Embed a list of strings as a NumPy array of row vectors."""
     assert isinstance(strings, list), "Input must be a list of strings"
     config = config or RAGLiteConfig()
diff --git a/src/raglite/_extract.py b/src/raglite/_extract.py
index 0988c05..8ce7294 100644
--- a/src/raglite/_extract.py
+++ b/src/raglite/_extract.py
@@ -33,7 +33,7 @@ class MyNameResponse(BaseModel):
     config = config or RAGLiteConfig()
     # Update the system prompt with the JSON schema of the return type to help the LLM.
     system_prompt = (
-        return_type.system_prompt.strip() + "\n",
+        return_type.system_prompt.strip() + "\n",  # type: ignore[attr-defined]
         "Format your response according to this JSON schema:\n",
         return_type.model_json_schema(),
     )
@@ -47,7 +47,7 @@ class MyNameResponse(BaseModel):
     for _ in range(config.llm_max_tries):
         response = config.llm.create_chat_completion(
             messages=[
-                {"role": "system", "content": system_prompt},
+                {"role": "system", "content": system_prompt},  # type: ignore[list-item,misc]
                 {"role": "user", "content": user_prompt},
             ],
             response_format={"type": "json_object", "schema": return_type.model_json_schema()},
@@ -55,13 +55,13 @@ class MyNameResponse(BaseModel):
             **kwargs,
         )
         try:
-            instance = return_type.model_validate_json(response["choices"][0]["message"]["content"])
+            instance = return_type.model_validate_json(response["choices"][0]["message"]["content"])  # type: ignore[arg-type,index]
         except (KeyError, ValueError, ValidationError):
             # Malformed response, not a JSON string, or not a valid instance of the return type.
             continue
         else:
             break
     else:
-        error_message = f"Failed to extract {return_type} from input {user_prompt}"
+        error_message = f"Failed to extract {return_type} from input {user_prompt}."
         raise ValueError(error_message)
     return instance
diff --git a/src/raglite/_index.py b/src/raglite/_index.py
index c126b4d..33674a8 100644
--- a/src/raglite/_index.py
+++ b/src/raglite/_index.py
@@ -14,14 +14,14 @@
 from raglite._markdown import document_to_markdown
 from raglite._split_chunks import split_chunks
 from raglite._split_sentences import split_sentences
+from raglite._typing import FloatMatrix
 
 
 def _create_chunk_records(
     document_id: str,
     chunks: list[str],
-    multi_vector_embeddings: list[np.ndarray],
-    *,
-    config: RAGLiteConfig | None = None,
+    multi_vector_embeddings: list[FloatMatrix],
+    config: RAGLiteConfig,
 ) -> list[Chunk]:
     """Process chunks into headings, body and improved multi-vector embeddings."""
     # Create the chunk records.
@@ -86,7 +86,7 @@ def insert_document(doc_path: Path, *, config: RAGLiteConfig | None = None) -> N
         session.commit()
         # Create the chunk records.
         chunk_records = _create_chunk_records(
-            document_record.id, chunks, multi_vector_embeddings, config=config
+            document_record.id, chunks, multi_vector_embeddings, config
         )
         # Store the chunk records.
         for chunk_record in tqdm(
@@ -129,8 +129,8 @@ def update_vector_index(config: RAGLiteConfig | None = None) -> None:
             )
             chunk_ann_index.index.prepare()
         else:
-            chunk_ann_index.index.update(X_unindexed)
-            chunk_ann_index.index.prepare()
+            chunk_ann_index.index.update(X_unindexed)  # type: ignore[union-attr]
+            chunk_ann_index.index.prepare()  # type: ignore[union-attr]
         chunk_ann_index.chunk_sizes.extend(
             [chunk.multi_vector_embedding.shape[0] for chunk in unindexed_chunks]
         )
diff --git a/src/raglite/_markdown.py b/src/raglite/_markdown.py
index feba689..09ff6b2 100644
--- a/src/raglite/_markdown.py
+++ b/src/raglite/_markdown.py
@@ -29,7 +29,7 @@ def add_heading_level_metadata(pages: list[dict[str, Any]]) -> list[dict[str, An
             for line in block["lines"]
             for span in line["spans"]
         ]
-        font_sizes = np.asarray(font_sizes)
+        font_sizes = np.asarray(font_sizes)  # type: ignore[assignment]
         font_sizes = np.round(font_sizes * 2) / 2
         unique_font_sizes, counts = np.unique(font_sizes, return_counts=True)
         # Determine the paragraph font size as the mode font size.
@@ -60,7 +60,7 @@ def add_heading_level_metadata(pages: list[dict[str, Any]]) -> list[dict[str, An
                     elif span_font_size == mode_font_size:
                         idx = 6
                     else:
-                        idx = np.argmin(np.abs(heading_font_sizes - span_font_size))
+                        idx = np.argmin(np.abs(heading_font_sizes - span_font_size))  # type: ignore[assignment]
                     span["md"]["heading_level"] = idx + 1
                     heading_level[idx] += len(span["text"])
                 line["md"]["heading_level"] = np.argmax(heading_level) + 1
@@ -153,7 +153,7 @@ def convert_to_markdown(pages: list[dict[str, Any]]) -> list[str]:  # noqa: C901
 def merge_split_headings(pages: list[str]) -> list[str]:
     """Merge headings that are split across lines."""
 
-    def _merge_split_headings(match: re.Match) -> str:
+    def _merge_split_headings(match: re.Match[str]) -> str:
         atx_headings = [line.strip("# ").strip() for line in match.group().splitlines()]
         return f"{match.group(1)} {' '.join(atx_headings)}\n\n"
 
diff --git a/src/raglite/_query_adapter.py b/src/raglite/_query_adapter.py
index a6f592b..10ebb89 100644
--- a/src/raglite/_query_adapter.py
+++ b/src/raglite/_query_adapter.py
@@ -156,7 +156,7 @@ class AnswerResponse(BaseModel):
         session.commit()
 
 
-def update_query_adapter(
+def update_query_adapter(  # noqa: C901
     *,
     max_triplets: int = 4096,
     max_triplets_per_eval: int = 64,
@@ -207,10 +207,8 @@ def update_query_adapter(
         select(Eval).order_by(Eval.id).limit(max(8, max_triplets // max_triplets_per_eval))
     ).all()
     if len(evals) * max_triplets_per_eval < config.embedder.n_embd():
-        error_message = (
-            "Run `insert_evals` to generate sufficient Evals before updating the query adapter"
-        )
-        raise ValueError
+        error_message = "First run `insert_evals()` to generate sufficient Evals."
+        raise ValueError(error_message)
     # Loop over the evals to generate (q, p, n) triplets.
     Q = np.zeros((0, config.embedder.n_embd()))  # We want double precision here.  # noqa: N806
     P = np.zeros_like(Q)  # noqa: N806
@@ -231,17 +229,23 @@ def update_query_adapter(
         # Extract (q, p, n) triplets by comparing the retrieved chunks with the eval.
         num_triplets = 0
         for i, retrieved_chunk in enumerate(retrieved_chunks):
+            # Raise an error if the retrieved chunk is None.
+            if retrieved_chunk is None:
+                error_message = (
+                    f"The chunk with rowid {chunk_rowids[i]} is missing from the database."
+                )
+                raise ValueError(error_message)
             # Select irrelevant chunks.
             if retrieved_chunk.id not in eval_.chunk_ids:
                 # Look up all positive chunks that are ranked lower than this negative one.
-                p = [
+                p_mve = [
                     np.mean(chunk.multi_vector_embedding, axis=0, keepdims=True)
                     for chunk in retrieved_chunks[i + 1 :]
-                    if chunk.id in eval_.chunk_ids
+                    if chunk is not None and chunk.id in eval_.chunk_ids
                 ]
-                if not p:
+                if not p_mve:
                     continue
-                p = np.vstack(p)
+                p = np.vstack(p_mve)
                 n = np.repeat(
                     np.mean(retrieved_chunk.multi_vector_embedding, axis=0, keepdims=True),
                     p.shape[0],
@@ -298,15 +302,15 @@ def answer_evals(
     with Session(engine) as session:
         evals = session.exec(select(Eval)).all()
     # Answer evals with RAG.
-    answers = []
+    answers: list[str] = []
     for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True):
         response = rag(eval_.question, search=search, config=config)
         answer = "".join(response)
         answers.append(answer)
     # Evaluate the answers.
-    test_set = {
+    test_set: dict[str, list[str | list[str]]] = {
         "question": [eval_.question for eval_ in evals],
-        "answer": answers,
+        "answer": answers,  # type: ignore[dict-item]
         "contexts": [eval_.contexts for eval_ in evals],
         "ground_truth": [eval_.ground_truth for eval_ in evals],
     }
diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py
index b547f47..9361f0b 100644
--- a/src/raglite/_rag.py
+++ b/src/raglite/_rag.py
@@ -17,7 +17,7 @@ def rag(
     """Retrieval-augmented generation."""
     # Retrieve relevant chunks.
     config = config or RAGLiteConfig()
-    chunk_rowids, _ = search(prompt, num_results=num_contexts, config=config)
+    chunk_rowids, _ = search(prompt, num_results=num_contexts, config=config)  # type: ignore[call-arg]
     chunks = retrieve_segments(chunk_rowids, neighbors=context_neighbors)
     # Respond with an LLM.
     contexts = "\n\n".join(
@@ -40,4 +40,5 @@
     )
     # Stream the response.
     for output in stream:
-        yield output["choices"][0]["delta"].get("content", "")
+        token: str = output["choices"][0]["delta"].get("content", "")  # type: ignore[assignment,index,union-attr]
+        yield token
diff --git a/src/raglite/_search.py b/src/raglite/_search.py
index aaac182..f8f7c43 100644
--- a/src/raglite/_search.py
+++ b/src/raglite/_search.py
@@ -16,13 +16,17 @@
 from raglite._database import Chunk, ChunkANNIndex, create_database_engine
 from raglite._embed import embed_strings
 from raglite._extract import extract_with_llm
+from raglite._typing import FloatMatrix, IntVector
 
 
 @lru_cache(maxsize=1)
-def _chunk_ann_index(config: RAGLiteConfig) -> tuple[NNDescent, np.ndarray, np.ndarray | None]:
+def _chunk_ann_index(config: RAGLiteConfig) -> tuple[NNDescent, IntVector, FloatMatrix | None]:
     engine = create_database_engine(config.db_url)
     with Session(engine) as session:
         chunk_ann_index = session.get(ChunkANNIndex, config.ann_vector_index_id)
+        if chunk_ann_index is None:
+            error_message = "First run `update_vector_index()` to create an ANN vector index."
+            raise ValueError(error_message)
         index = chunk_ann_index.index
         chunk_size_cumsum = np.cumsum(np.asarray(chunk_ann_index.chunk_sizes, dtype=np.intp))
         query_adapter = chunk_ann_index.query_adapter
@@ -30,7 +34,7 @@ def _chunk_ann_index(config: RAGLiteConfig) -> tuple[NNDescent, np.ndarray, np.n
 
 
 def vector_search(
-    prompt: str | np.ndarray,
+    prompt: str | FloatMatrix,
     *,
     num_results: int = 3,
     query_adapter: bool = True,
@@ -91,14 +95,14 @@ def keyword_search(
         statement = text(
             "SELECT chunk.rowid, bm25(chunk_fts) FROM chunk JOIN chunk_fts ON chunk.rowid = chunk_fts.rowid WHERE chunk_fts MATCH :match ORDER BY rank LIMIT :limit;"
         )
-        results = session.exec(
+        results = session.execute(
            statement, params={"match": _prompt_to_fts_query(prompt), "limit": num_results}
        )
        # Unpack the results and make FTS5's negative BM25 scores [1] positive.
        # https://www.sqlite.org/fts5.html#the_bm25_function
        chunk_rowids, bm25_score = zip(*results, strict=True)
-        chunk_rowids, bm25_score = list(chunk_rowids), [-s for s in bm25_score]
-        return chunk_rowids, bm25_score
+        chunk_rowids, bm25_score = list(chunk_rowids), [-s for s in bm25_score]  # type: ignore[assignment]
+        return chunk_rowids, bm25_score  # type: ignore[return-value]
 
 
 def reciprocal_rank_fusion(
@@ -107,7 +111,7 @@
     """Reciprocal Rank Fusion."""
     # Compute the RRF score.
     rowids = {rowid for ranking in rankings for rowid in ranking}
-    rowid_score = defaultdict(float)
+    rowid_score: defaultdict[int, float] = defaultdict(float)
     for ranking in rankings:
         rowid_index = {rowid: i for i, rowid in enumerate(ranking)}
         for rowid in rowids:
@@ -145,7 +149,7 @@ class QueriesResponse(BaseModel):
     """An array of queries that help answer the user prompt."""
 
     queries: list[Annotated[str, Field(min_length=1)]] = Field(
-        ..., description="A single query that helps answer the user prompt.", min_items=1
+        ..., description="A single query that helps answer the user prompt."
     )
     system_prompt: ClassVar[str] = """
 The user will give you a prompt in search of an answer.
@@ -191,7 +195,7 @@ def retrieve_segments(
             if chunk is not None:
                 chunks.add(chunk)
             # Extend the chunk with its neighbouring chunks.
-            if neighbors is not None and len(neighbors) > 0:
+            if chunk is not None and neighbors is not None and len(neighbors) > 0:
                 for offset in sorted(neighbors, key=abs):
                     where = (
                         Chunk.document_id == chunk.document_id,
@@ -201,11 +205,11 @@ def retrieve_segments(
                     if neighbor is not None:
                         chunks.add(neighbor)
     # Sort the chunks by document_id and index (needed for groupby).
-    chunks = sorted(chunks, key=lambda chunk: (chunk.document_id, chunk.index))
+    chunks = sorted(chunks, key=lambda chunk: (chunk.document_id, chunk.index))  # type: ignore[assignment]
     # Group the chunks into contiguous segments.
-    segments = []
+    segments: list[list[Chunk]] = []
     for _, group in groupby(chunks, key=lambda chunk: chunk.document_id):
-        segment = []
+        segment: list[Chunk] = []
         for chunk in group:
             if not segment or chunk.index == segment[-1].index + 1:
                 segment.append(chunk)
@@ -215,7 +219,7 @@ def retrieve_segments(
         segments.append(segment)
     # Convert the segments into strings.
     segments = [
-        segment[0].headings.strip() + "\n\n" + "".join(chunk.body for chunk in segment).strip()
+        segment[0].headings.strip() + "\n\n" + "".join(chunk.body for chunk in segment).strip()  # type: ignore[misc]
         for segment in segments
     ]
-    return segments
+    return segments  # type: ignore[return-value]
diff --git a/src/raglite/_split_chunks.py b/src/raglite/_split_chunks.py
index 954ac37..54216a1 100644
--- a/src/raglite/_split_chunks.py
+++ b/src/raglite/_split_chunks.py
@@ -8,14 +8,15 @@
 from scipy.sparse import coo_matrix
 
 from raglite._embed import embed_strings
+from raglite._typing import FloatMatrix
 
 
 def split_chunks(
     sentences: list[str],
     max_size: int = 1440,
     sentence_window_size: int = 3,
-    embed: Callable[[list[str]], np.ndarray] = embed_strings,
-) -> tuple[list[str], list[np.ndarray]]:
+    embed: Callable[[list[str]], FloatMatrix] = embed_strings,
+) -> tuple[list[str], list[FloatMatrix]]:
     """Split sentences into optimal semantic chunks."""
     # Window the sentences.
     whisker_size = (sentence_window_size - 1) // 2
diff --git a/src/raglite/_split_sentences.py b/src/raglite/_split_sentences.py
index c718394..a98bd6a 100644
--- a/src/raglite/_split_sentences.py
+++ b/src/raglite/_split_sentences.py
@@ -22,7 +22,7 @@ def get_markdown_heading_indexes(doc: str) -> list[tuple[int, int]]:
         char_idx.append(char_idx[-1] + len(line))
     for token in tokens:
         if token.type == "heading_open":
-            start_line, end_line = token.map
+            start_line, end_line = token.map  # type: ignore[misc]
             heading_start = char_idx[start_line]
             heading_end = char_idx[end_line]
             headings.append((heading_start, heading_end))
diff --git a/src/raglite/_typing.py b/src/raglite/_typing.py
new file mode 100644
index 0000000..c80f0e4
--- /dev/null
+++ b/src/raglite/_typing.py
@@ -0,0 +1,8 @@
+"""RAGLite typing."""
+
+from typing import Any
+
+import numpy as np
+
+FloatMatrix = np.ndarray[tuple[int, int], np.dtype[np.floating[Any]]]
+IntVector = np.ndarray[tuple[int], np.dtype[np.intp]]
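Below is a minimal, illustrative sketch of how the new FloatMatrix and IntVector aliases from src/raglite/_typing.py annotate the kind of NumPy code this diff types, such as the embedding normalisation in _embed.py and the chunk-size cumulative sum in _search.py. It is not part of the patch: normalise_rows and chunk_size_cumsum are hypothetical names, not raglite API, and it assumes NumPy >= 1.22 for runtime ndarray subscripting.

from typing import Any

import numpy as np

# Same aliases as src/raglite/_typing.py.
FloatMatrix = np.ndarray[tuple[int, int], np.dtype[np.floating[Any]]]
IntVector = np.ndarray[tuple[int], np.dtype[np.intp]]


def normalise_rows(embeddings: FloatMatrix) -> FloatMatrix:
    """Normalise each embedding row vector to unit norm, as _embed_string_batch does."""
    return embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)


def chunk_size_cumsum(chunk_sizes: list[int]) -> IntVector:
    """Cumulative sum of chunk sizes, as _chunk_ann_index computes it in _search.py."""
    return np.cumsum(np.asarray(chunk_sizes, dtype=np.intp))


if __name__ == "__main__":
    embeddings: FloatMatrix = np.random.default_rng(0).normal(size=(4, 8))
    print(normalise_rows(embeddings).shape)  # (4, 8)
    print(chunk_size_cumsum([2, 3, 5]))  # [ 2  5 10]

Annotating with these aliases keeps the intended array shape and dtype visible to mypy without changing runtime behaviour.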