From 7ec95826dd2b6e65db93ec253172214b7dfcd141 Mon Sep 17 00:00:00 2001 From: Manolo Santos Date: Sun, 15 Dec 2024 12:43:14 +0100 Subject: [PATCH] fix: support pgvector v0.7.0+ (#63) --- poetry.lock | 2 +- pyproject.toml | 2 ++ src/raglite/_database.py | 28 ++++++++++++++++++++++------ 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/poetry.lock b/poetry.lock index 7353b43..4429f28 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6813,4 +6813,4 @@ ragas = ["ragas"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "b3a14066711fe4caec356d0aa18514495d44ac253371d2560fc0c5aea890aaef" +content-hash = "239db3c85866a30b063fa6dfe538bbbe92ba659419e47446d5529fbd5bb3831a" diff --git a/pyproject.toml b/pyproject.toml index 10e3f57..fe253bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,8 @@ ragas = { version = ">=0.1.12", optional = true } typer = ">=0.12.5" # Frontend: chainlit = { version = ">=1.2.0", optional = true } +# Utilities: +packaging = ">=23.0" [tool.poetry.extras] # https://python-poetry.org/docs/pyproject/#extras chainlit = ["chainlit"] diff --git a/src/raglite/_database.py b/src/raglite/_database.py index 573a3cc..8d6c179 100644 --- a/src/raglite/_database.py +++ b/src/raglite/_database.py @@ -6,11 +6,13 @@ from functools import lru_cache from hashlib import sha256 from pathlib import Path -from typing import Any +from typing import Any, cast from xml.sax.saxutils import escape import numpy as np from markdown_it import MarkdownIt +from packaging import version +from packaging.version import Version from pydantic import ConfigDict from sqlalchemy.engine import Engine, make_url from sqlmodel import JSON, Column, Field, Relationship, Session, SQLModel, create_engine, text @@ -310,6 +312,18 @@ def from_chunks( ) +def _pgvector_version(session: Session) -> Version: + try: + result = session.execute( + text("SELECT extversion FROM pg_extension WHERE extname = 'vector'") + ) + pgvector_version = version.parse(cast(str, result.scalar_one())) + except Exception as e: + error_message = "Unable to parse pgvector version, is pgvector installed?" + raise ValueError(error_message) from e + return pgvector_version + + @lru_cache(maxsize=1) def create_database_engine(config: RAGLiteConfig | None = None) -> Engine: """Create a database engine and initialize it.""" @@ -358,17 +372,19 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine: CREATE INDEX IF NOT EXISTS keyword_search_chunk_index ON chunk USING GIN (to_tsvector('simple', body)); """) ) - session.execute( - text(f""" + create_vector_index_sql = f""" CREATE INDEX IF NOT EXISTS vector_search_chunk_index ON chunk_embedding USING hnsw ( (embedding::halfvec({embedding_dim})) halfvec_{metrics[config.vector_search_index_metric]}_ops ); SET hnsw.ef_search = {20 * 4 * 8}; - SET hnsw.iterative_scan = {'relaxed_order' if config.reranker else 'strict_order'}; - """) - ) + """ + # Enable iterative scan for pgvector v0.8.0 and up. + pgvector_version = _pgvector_version(session) + if pgvector_version and pgvector_version >= version.parse("0.8.0"): + create_vector_index_sql += f"\nSET hnsw.iterative_scan = {'relaxed_order' if config.reranker else 'strict_order'};" + session.execute(text(create_vector_index_sql)) session.commit() elif db_backend == "sqlite": # Create a virtual table for keyword search on the chunk table.