fix: address mypy issues
lsorber committed Aug 14, 2024
1 parent 55f4449 commit c05357a
Showing 14 changed files with 89 additions and 112 deletions.
1 change: 1 addition & 0 deletions .devcontainer/devcontainer.json
@@ -40,6 +40,7 @@
"files.autoSave": "onFocusChange",
"jupyter.kernels.excludePythonEnvironments": ["/usr/local/bin/python"],
"mypy-type-checker.importStrategy": "fromEnvironment",
"mypy-type-checker.preferDaemon": true,
"notebook.codeActionsOnSave": {
"notebook.source.fixAll": "explicit",
"notebook.source.organizeImports": "explicit"
51 changes: 1 addition & 50 deletions pyproject.toml
@@ -123,56 +123,7 @@ src = ["src", "tests"]
target-version = "py310"

[tool.ruff.lint]
ignore-init-module-imports = true
select = [
"A",
"ASYNC",
"B",
"BLE",
"C4",
"C90",
"D",
"DTZ",
"E",
"EM",
"ERA",
"F",
"FBT",
"FLY",
"FURB",
"G",
"I",
"ICN",
"INP",
"INT",
"ISC",
"LOG",
"N",
"NPY",
"PERF",
"PGH",
"PIE",
"PL",
"PT",
"PTH",
"PYI",
"Q",
"RET",
"RSE",
"RUF",
"S",
"SIM",
"SLF",
"SLOT",
"T10",
"T20",
"TCH",
"TID",
"TRY",
"UP",
"W",
"YTT",
]
select = ["A", "ASYNC", "B", "BLE", "C4", "C90", "D", "DTZ", "E", "EM", "ERA", "F", "FBT", "FLY", "FURB", "G", "I", "ICN", "INP", "INT", "ISC", "LOG", "N", "NPY", "PERF", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "Q", "RET", "RSE", "RUF", "S", "SIM", "SLF", "SLOT", "T10", "T20", "TCH", "TID", "TRY", "UP", "W", "YTT"]
ignore = ["D203", "D213", "E501", "RET504", "RUF002", "S101", "S307"]
unfixable = ["ERA001", "F401", "F841", "T201", "T203"]

2 changes: 1 addition & 1 deletion src/raglite/_config.py
@@ -5,7 +5,7 @@

import numpy as np
import numpy.typing as npt
from llama_cpp import Llama, LlamaRAMCache
from llama_cpp import Llama, LlamaRAMCache # type: ignore[attr-defined]
from sqlalchemy.engine import URL


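A note on the ignore code used here: attr-defined fires when mypy's view of a module, typically its stubs, does not declare a name that exists at runtime. A minimal sketch of the pattern, assuming llama-cpp-python is installed; the helper function below is illustrative and not part of this commit.

from llama_cpp import Llama, LlamaRAMCache  # type: ignore[attr-defined]

def load_cached_llm(model_path: str) -> Llama:
    # Illustrative helper: load a GGUF model and attach an in-memory prompt cache.
    llm = Llama(model_path=model_path, verbose=False)
    llm.set_cache(LlamaRAMCache())
    return llm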
34 changes: 20 additions & 14 deletions src/raglite/_database.py
@@ -14,33 +14,39 @@
from sqlalchemy.types import LargeBinary, TypeDecorator
from sqlmodel import JSON, Column, Field, Relationship, Session, SQLModel, create_engine, text

from raglite._typing import FloatMatrix


def hash_bytes(data: bytes, max_len: int = 16) -> str:
"""Hash bytes to a hexadecimal string."""
return sha256(data, usedforsecurity=False).hexdigest()[:max_len]


class NumpyArray(TypeDecorator):
class NumpyArray(TypeDecorator[np.ndarray[Any, np.dtype[np.floating[Any]]]]):
"""A NumPy array column type for SQLAlchemy."""

impl = LargeBinary

def process_bind_param(self, value: np.ndarray | None, dialect: Dialect) -> bytes | None:
def process_bind_param(
self, value: np.ndarray[Any, np.dtype[np.floating[Any]]] | None, dialect: Dialect
) -> bytes | None:
"""Convert a NumPy array to bytes."""
if value is None:
return None
buffer = io.BytesIO()
np.save(buffer, value, allow_pickle=False, fix_imports=False)
return buffer.getvalue()

def process_result_value(self, value: bytes | None, dialect: Dialect) -> np.ndarray | None:
def process_result_value(
self, value: bytes | None, dialect: Dialect
) -> np.ndarray[Any, np.dtype[np.floating[Any]]] | None:
"""Convert bytes to a NumPy array."""
if value is None:
return None
return np.load(io.BytesIO(value), allow_pickle=False, fix_imports=False)
return np.load(io.BytesIO(value), allow_pickle=False, fix_imports=False) # type: ignore[no-any-return]


class PickledObject(TypeDecorator):
class PickledObject(TypeDecorator[object]):
"""A pickled object column type for SQLAlchemy."""

impl = LargeBinary
@@ -55,7 +61,7 @@ def process_result_value(self, value: bytes | None, dialect: Dialect) -> object
"""Convert bytes to a Python object."""
if value is None:
return None
return pickle.loads(value, fix_imports=False) # noqa: S301
return pickle.loads(value, fix_imports=False) # type: ignore[no-any-return] # noqa: S301


class Document(SQLModel, table=True):
@@ -99,7 +105,7 @@ class Chunk(SQLModel, table=True):
index: int = Field(..., index=True)
headings: str
body: str
multi_vector_embedding: np.ndarray = Field(..., sa_column=Column(NumpyArray))
multi_vector_embedding: FloatMatrix = Field(..., sa_column=Column(NumpyArray))
metadata_: dict[str, Any] = Field(default={}, sa_column=Column("metadata", JSON))

# Add relationship so we can access chunk.document.
@@ -111,7 +117,7 @@ def from_body(
index: int,
body: str,
headings: str = "",
multi_vector_embedding: np.ndarray | None = None,
multi_vector_embedding: FloatMatrix | None = None,
**kwargs: Any,
) -> "Chunk":
"""Create a chunk from Markdown."""
@@ -167,7 +173,7 @@ class ChunkANNIndex(SQLModel, table=True):
id: str = Field(..., primary_key=True)
chunk_sizes: list[int] = Field(default=[], sa_column=Column(JSON))
index: NNDescent | None = Field(default=None, sa_column=Column(PickledObject))
query_adapter: np.ndarray | None = Field(default=None, sa_column=Column(NumpyArray))
query_adapter: FloatMatrix | None = Field(default=None, sa_column=Column(NumpyArray))
metadata_: dict[str, Any] = Field(default={}, sa_column=Column("metadata", JSON))

# Enable support for JSON, PickledObject, and NumpyArray columns.
@@ -199,7 +205,7 @@ def from_chunks(
contexts: list[Chunk],
ground_truth: str,
**kwargs: Any,
) -> "Chunk":
) -> "Eval":
"""Create a chunk from Markdown."""
document_id = contexts[0].document_id
chunk_ids = [context.id for context in contexts]
@@ -237,26 +243,26 @@ def create_database_engine(db_url: str | URL = "sqlite:///raglite.sqlite") -> En
# We use the chunk table as an external content table [1] to avoid duplicating the data.
# [1] https://www.sqlite.org/fts5.html#external_content_tables
with Session(engine) as session:
session.exec(
session.execute(
text("""
CREATE VIRTUAL TABLE IF NOT EXISTS chunk_fts USING fts5(body, content='chunk', content_rowid='rowid');
""")
)
session.exec(
session.execute(
text("""
CREATE TRIGGER IF NOT EXISTS chunk_fts_auto_insert AFTER INSERT ON chunk BEGIN
INSERT INTO chunk_fts(rowid, body) VALUES (new.rowid, new.body);
END;
""")
)
session.exec(
session.execute(
text("""
CREATE TRIGGER IF NOT EXISTS chunk_fts_auto_delete AFTER DELETE ON chunk BEGIN
INSERT INTO chunk_fts(chunk_fts, rowid, body) VALUES('delete', old.rowid, old.body);
END;
""")
)
session.exec(
session.execute(
text("""
CREATE TRIGGER IF NOT EXISTS chunk_fts_auto_update AFTER UPDATE ON chunk BEGIN
INSERT INTO chunk_fts(chunk_fts, rowid, body) VALUES('delete', old.rowid, old.body);
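The changes above parameterize TypeDecorator so that mypy knows the Python-side type of each column. A standalone sketch of the same idea, assuming SQLAlchemy 2.x where TypeDecorator is generic; the class name, type alias, and the cache_ok flag are illustrative additions, not part of this commit.

import io
from typing import Any

import numpy as np
from sqlalchemy.engine import Dialect
from sqlalchemy.types import LargeBinary, TypeDecorator

FloatArray = np.ndarray[Any, np.dtype[np.floating[Any]]]

class FloatArrayType(TypeDecorator[FloatArray]):
    """Store a NumPy float array as a binary blob."""

    impl = LargeBinary
    cache_ok = True  # The type holds no mutable state, so SQLAlchemy may cache it.

    def process_bind_param(self, value: FloatArray | None, dialect: Dialect) -> bytes | None:
        # Serialise the array into an in-memory buffer with np.save.
        if value is None:
            return None
        buffer = io.BytesIO()
        np.save(buffer, value, allow_pickle=False)
        return buffer.getvalue()

    def process_result_value(self, value: bytes | None, dialect: Dialect) -> FloatArray | None:
        # Deserialise the blob back into a NumPy array.
        if value is None:
            return None
        return np.load(io.BytesIO(value), allow_pickle=False)  # type: ignore[no-any-return]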
9 changes: 5 additions & 4 deletions src/raglite/_embed.py
@@ -6,18 +6,19 @@
from tqdm.auto import trange

from raglite._config import RAGLiteConfig
from raglite._typing import FloatMatrix


@lru_cache(maxsize=128)
def _embed_string_batch(
string_batch: tuple[str], *, config: RAGLiteConfig | None = None
) -> np.ndarray:
string_batch: tuple[str, ...], *, config: RAGLiteConfig | None = None
) -> FloatMatrix:
# Embed a batch of strings.
config = config or RAGLiteConfig()
if len(string_batch) == 0:
embeddings = np.zeros((0, config.embedder.n_embd()))
else:
embeddings = np.asarray(config.embedder.embed(string_batch))
embeddings = np.asarray(config.embedder.embed(string_batch)) # type: ignore[arg-type]
# Normalise embeddings to unit norm.
if config.embedder_normalize:
embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
@@ -26,7 +27,7 @@ def _embed_string_batch(
return embeddings


def embed_strings(strings: list[str], *, config: RAGLiteConfig | None = None) -> np.ndarray:
def embed_strings(strings: list[str], *, config: RAGLiteConfig | None = None) -> FloatMatrix:
"""Embed a list of strings as a NumPy array of row vectors."""
assert isinstance(strings, list), "Input must be a list of strings"
config = config or RAGLiteConfig()
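The signature change from tuple[str] to tuple[str, ...] matters because functools.lru_cache hashes its arguments: the batch must be an immutable tuple, and the annotation should allow any length. A self-contained sketch of the caching pattern with a stand-in embedder in place of the configured model; the names and the embedding dimension are illustrative.

from functools import lru_cache

import numpy as np
import numpy.typing as npt

@lru_cache(maxsize=128)
def embed_batch(strings: tuple[str, ...]) -> npt.NDArray[np.float64]:
    # Stand-in embedder: deterministic pseudo-embeddings, normalised to unit norm.
    if not strings:
        return np.zeros((0, 4))
    rows = [np.random.default_rng(len(s)).standard_normal(4) for s in strings]
    embeddings = np.asarray(rows)
    embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings

# Lists are unhashable, so callers convert to a tuple before hitting the cache.
print(embed_batch(("hello", "world")).shape)  # (2, 4)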
8 changes: 4 additions & 4 deletions src/raglite/_extract.py
@@ -33,7 +33,7 @@ class MyNameResponse(BaseModel):
config = config or RAGLiteConfig()
# Update the system prompt with the JSON schema of the return type to help the LLM.
system_prompt = (
return_type.system_prompt.strip() + "\n",
return_type.system_prompt.strip() + "\n", # type: ignore[attr-defined]
"Format your response according to this JSON schema:\n",
return_type.model_json_schema(),
)
@@ -47,21 +47,21 @@ class MyNameResponse(BaseModel):
for _ in range(config.llm_max_tries):
response = config.llm.create_chat_completion(
messages=[
{"role": "system", "content": system_prompt},
{"role": "system", "content": system_prompt}, # type: ignore[list-item,misc]
{"role": "user", "content": user_prompt},
],
response_format={"type": "json_object", "schema": return_type.model_json_schema()},
temperature=config.llm_temperature,
**kwargs,
)
try:
instance = return_type.model_validate_json(response["choices"][0]["message"]["content"])
instance = return_type.model_validate_json(response["choices"][0]["message"]["content"]) # type: ignore[arg-type,index]
except (KeyError, ValueError, ValidationError):
# Malformed response, not a JSON string, or not a valid instance of the return type.
continue
else:
break
else:
error_message = f"Failed to extract {return_type} from input {user_prompt}"
error_message = f"Failed to extract {return_type} from input {user_prompt}."
raise ValueError(error_message)
return instance
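The loop above uses Python's for...else: the else branch runs only if no break occurred, that is, if every attempt failed to produce a valid instance. A condensed sketch of the extract-and-validate pattern with a fake completion function standing in for the LLM call; everything here is illustrative.

from pydantic import BaseModel, ValidationError

class NameResponse(BaseModel):
    name: str

def fake_llm_completion(prompt: str) -> str:
    # Stand-in for the chat completion call; returns a JSON string.
    return '{"name": "Ada"}'

max_tries = 3
for _ in range(max_tries):
    raw = fake_llm_completion("What is your name?")
    try:
        instance = NameResponse.model_validate_json(raw)
    except (ValueError, ValidationError):
        continue  # Malformed JSON or not a valid instance: try again.
    else:
        break  # Valid instance: stop retrying.
else:
    error_message = "Failed to extract NameResponse from the response."
    raise ValueError(error_message)
print(instance)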
12 changes: 6 additions & 6 deletions src/raglite/_index.py
@@ -14,14 +14,14 @@
from raglite._markdown import document_to_markdown
from raglite._split_chunks import split_chunks
from raglite._split_sentences import split_sentences
from raglite._typing import FloatMatrix


def _create_chunk_records(
document_id: str,
chunks: list[str],
multi_vector_embeddings: list[np.ndarray],
*,
config: RAGLiteConfig | None = None,
multi_vector_embeddings: list[FloatMatrix],
config: RAGLiteConfig,
) -> list[Chunk]:
"""Process chunks into headings, body and improved multi-vector embeddings."""
# Create the chunk records.
@@ -86,7 +86,7 @@ def insert_document(doc_path: Path, *, config: RAGLiteConfig | None = None) -> N
session.commit()
# Create the chunk records.
chunk_records = _create_chunk_records(
document_record.id, chunks, multi_vector_embeddings, config=config
document_record.id, chunks, multi_vector_embeddings, config
)
# Store the chunk records.
for chunk_record in tqdm(
@@ -129,8 +129,8 @@ def update_vector_index(config: RAGLiteConfig | None = None) -> None:
)
chunk_ann_index.index.prepare()
else:
chunk_ann_index.index.update(X_unindexed)
chunk_ann_index.index.prepare()
chunk_ann_index.index.update(X_unindexed) # type: ignore[union-attr]
chunk_ann_index.index.prepare() # type: ignore[union-attr]
chunk_ann_index.chunk_sizes.extend(
[chunk.multi_vector_embedding.shape[0] for chunk in unindexed_chunks]
)
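The type: ignore[union-attr] comments are needed because chunk_ann_index.index is typed as NNDescent | None and mypy cannot see that the else branch only runs when it is not None. The commit silences the error; an alternative, not used here, is to narrow the optional into a local name first. A generic sketch of that narrowing pattern with illustrative stand-in classes:

from typing import Optional

class Index:
    def update(self, data: list[float]) -> None:
        print(f"updating with {len(data)} values")

    def prepare(self) -> None:
        print("prepared")

class Holder:
    index: Optional[Index] = None

holder = Holder()
holder.index = Index()

# Binding the optional attribute to a local variable lets mypy prove it is not None,
# so the calls below need no type: ignore[union-attr].
index = holder.index
if index is not None:
    index.update([0.1, 0.2])
    index.prepare()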
6 changes: 3 additions & 3 deletions src/raglite/_markdown.py
@@ -29,7 +29,7 @@ def add_heading_level_metadata(pages: list[dict[str, Any]]) -> list[dict[str, An
for line in block["lines"]
for span in line["spans"]
]
font_sizes = np.asarray(font_sizes)
font_sizes = np.asarray(font_sizes) # type: ignore[assignment]
font_sizes = np.round(font_sizes * 2) / 2
unique_font_sizes, counts = np.unique(font_sizes, return_counts=True)
# Determine the paragraph font size as the mode font size.
@@ -60,7 +60,7 @@ def add_heading_level_metadata(pages: list[dict[str, Any]]) -> list[dict[str, An
elif span_font_size == mode_font_size:
idx = 6
else:
idx = np.argmin(np.abs(heading_font_sizes - span_font_size))
idx = np.argmin(np.abs(heading_font_sizes - span_font_size)) # type: ignore[assignment]
span["md"]["heading_level"] = idx + 1
heading_level[idx] += len(span["text"])
line["md"]["heading_level"] = np.argmax(heading_level) + 1
@@ -153,7 +153,7 @@ def convert_to_markdown(pages: list[dict[str, Any]]) -> list[str]: # noqa: C901
def merge_split_headings(pages: list[str]) -> list[str]:
"""Merge headings that are split across lines."""

def _merge_split_headings(match: re.Match) -> str:
def _merge_split_headings(match: re.Match[str]) -> str:
atx_headings = [line.strip("# ").strip() for line in match.group().splitlines()]
return f"{match.group(1)} {' '.join(atx_headings)}\n\n"

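The two type: ignore[assignment] comments come from re-binding a name to a value of a different type (a list of floats becomes a NumPy array, an integer index becomes a NumPy scalar). A small worked example of the mode-font-size computation that avoids the re-binding by using fresh names; the font sizes are made up.

import numpy as np

span_font_sizes = [10.1, 9.9, 10.0, 10.0, 14.2, 10.1]  # Illustrative data.
# Round to the nearest 0.5 pt and take the most frequent size as the paragraph font size.
rounded_sizes = np.round(np.asarray(span_font_sizes) * 2) / 2
unique_sizes, counts = np.unique(rounded_sizes, return_counts=True)
mode_font_size = unique_sizes[np.argmax(counts)]
print(mode_font_size)  # 10.0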
28 changes: 16 additions & 12 deletions src/raglite/_query_adapter.py
@@ -156,7 +156,7 @@ class AnswerResponse(BaseModel):
session.commit()


def update_query_adapter(
def update_query_adapter( # noqa: C901
*,
max_triplets: int = 4096,
max_triplets_per_eval: int = 64,
@@ -207,10 +207,8 @@ def update_query_adapter(
select(Eval).order_by(Eval.id).limit(max(8, max_triplets // max_triplets_per_eval))
).all()
if len(evals) * max_triplets_per_eval < config.embedder.n_embd():
error_message = (
"Run `insert_evals` to generate sufficient Evals before updating the query adapter"
)
raise ValueError
error_message = "First run `insert_evals()` to generate sufficient Evals."
raise ValueError(error_message)
# Loop over the evals to generate (q, p, n) triplets.
Q = np.zeros((0, config.embedder.n_embd())) # We want double precision here. # noqa: N806
P = np.zeros_like(Q) # noqa: N806
@@ -231,17 +229,23 @@
# Extract (q, p, n) triplets by comparing the retrieved chunks with the eval.
num_triplets = 0
for i, retrieved_chunk in enumerate(retrieved_chunks):
# Raise an error if the retrieved chunk is None.
if retrieved_chunk is None:
error_message = (
f"The chunk with rowid {chunk_rowids[i]} is missing from the database."
)
raise ValueError(error_message)
# Select irrelevant chunks.
if retrieved_chunk.id not in eval_.chunk_ids:
# Look up all positive chunks that are ranked lower than this negative one.
p = [
p_mve = [
np.mean(chunk.multi_vector_embedding, axis=0, keepdims=True)
for chunk in retrieved_chunks[i + 1 :]
if chunk.id in eval_.chunk_ids
if chunk is not None and chunk.id in eval_.chunk_ids
]
if not p:
if not p_mve:
continue
p = np.vstack(p)
p = np.vstack(p_mve)
n = np.repeat(
np.mean(retrieved_chunk.multi_vector_embedding, axis=0, keepdims=True),
p.shape[0],
@@ -298,15 +302,15 @@ def answer_evals(
with Session(engine) as session:
evals = session.exec(select(Eval)).all()
# Answer evals with RAG.
answers = []
answers: list[str] = []
for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True):
response = rag(eval_.question, search=search, config=config)
answer = "".join(response)
answers.append(answer)
# Evaluate the answers.
test_set = {
test_set: dict[str, list[str | list[str]]] = {
"question": [eval_.question for eval_ in evals],
"answer": answers,
"answer": answers, # type: ignore[dict-item]
"contexts": [eval_.contexts for eval_ in evals],
"ground_truth": [eval_.ground_truth for eval_ in evals],
}
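In the triplet extraction above, each chunk's multi-vector embedding is a matrix of row vectors; the positives are mean-pooled and stacked, and the negative's mean is repeated so that the two matrices have matching shapes. A small sketch with random data and an illustrative embedding dimension:

import numpy as np

rng = np.random.default_rng(42)
d = 8  # Illustrative embedding dimension.
# Each chunk's multi-vector embedding has shape (num_vectors, d).
positive_chunks = [rng.standard_normal((3, d)), rng.standard_normal((5, d))]
negative_chunk = rng.standard_normal((4, d))

# Mean-pool each positive chunk to a single row vector and stack them.
p = np.vstack([np.mean(chunk, axis=0, keepdims=True) for chunk in positive_chunks])  # (2, d)
# Repeat the mean-pooled negative so each positive row is paired with it.
n = np.repeat(np.mean(negative_chunk, axis=0, keepdims=True), p.shape[0], axis=0)  # (2, d)
assert p.shape == n.shape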
