From e17de264e934e2b89a872eb67845aba7da652778 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Fri, 14 Jun 2024 12:11:46 +0530 Subject: [PATCH] feat: Qdrant vector storage --- .env.example | 19 +++ poetry.lock | 178 +++++++++++++++++++++- pyproject.toml | 1 + warc_gpt/commands/ingest.py | 43 ++---- warc_gpt/commands/visualize.py | 20 ++- warc_gpt/utils/vector_storage/__init__.py | 5 + warc_gpt/utils/vector_storage/base.py | 46 ++++++ warc_gpt/utils/vector_storage/chroma.py | 49 ++++++ warc_gpt/utils/vector_storage/qdrant.py | 144 +++++++++++++++++ warc_gpt/views/api/search.py | 50 ++---- wsgi.py | 1 + 11 files changed, 480 insertions(+), 76 deletions(-) create mode 100644 warc_gpt/utils/vector_storage/__init__.py create mode 100644 warc_gpt/utils/vector_storage/base.py create mode 100644 warc_gpt/utils/vector_storage/chroma.py create mode 100644 warc_gpt/utils/vector_storage/qdrant.py diff --git a/.env.example b/.env.example index 0c8418e..371de39 100644 --- a/.env.example +++ b/.env.example @@ -71,6 +71,7 @@ VECTOR_SEARCH_PATH="./chromadb" #------------------------------------------------------------------------------- # Vector Store / Search settings #------------------------------------------------------------------------------- +VECTOR_SEARCH_DATABASE="chroma" # Available options are "chroma", "qdrant" VECTOR_SEARCH_COLLECTION_NAME="collection" VECTOR_SEARCH_SENTENCE_TRANSFORMER_MODEL="intfloat/e5-large-v2" # Can be any Sentence-Transformers compatible model available on Hugging Face @@ -83,6 +84,24 @@ VECTOR_SEARCH_QUERY_PREFIX="query: " # Can be used to add prefix to text embeddi VECTOR_SEARCH_TEXT_SPLITTER_CHUNK_OVERLAP=25 # Determines, for a given chunk of text, how many tokens must overlap with adjacent chunks. VECTOR_SEARCH_SEARCH_N_RESULTS=4 # How many entries should the vector search return? +#------------------------------------------------------------------------------- +# Qdrant settings +#------------------------------------------------------------------------------- +# NOTE: +# - This set of variables allows to configure Qdrant vector storage. +# - Applicable when "VECTOR_SEARCH_DATABASE" is set to "qdrant". +# QDRANT_LOCATION="http://localhost:6333" +# QDRANT_PORT=6333 +# QDRANT_GRPC_PORT=6334 +# QDRANT_PREFER_GRPC=False +# QDRANT_HTTPS=False +# QDRANT_URL= +# QDRANT_API_KEY= +# QDRANT_PREFIX= +# QDRANT_TIMEOUT= +# QDRANT_HOST= +# QDRANT_PATH= + #------------------------------------------------------------------------------- # Basic Rate Limiting #------------------------------------------------------------------------------- diff --git a/poetry.lock b/poetry.lock index d5a9193..0713c23 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohttp" @@ -1169,6 +1169,74 @@ files = [ [package.extras] protobuf = ["grpcio-tools (>=1.63.0)"] +[[package]] +name = "grpcio-tools" +version = "1.62.2" +description = "Protobuf code generator for gRPC" +optional = false +python-versions = ">=3.7" +files = [ + {file = "grpcio-tools-1.62.2.tar.gz", hash = "sha256:5fd5e1582b678e6b941ee5f5809340be5e0724691df5299aae8226640f94e18f"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:1679b4903aed2dc5bd8cb22a452225b05dc8470a076f14fd703581efc0740cdb"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:9d41e0e47dd075c075bb8f103422968a65dd0d8dc8613288f573ae91eb1053ba"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:987e774f74296842bbffd55ea8826370f70c499e5b5f71a8cf3103838b6ee9c3"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40cd4eeea4b25bcb6903b82930d579027d034ba944393c4751cdefd9c49e6989"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6746bc823958499a3cf8963cc1de00072962fb5e629f26d658882d3f4c35095"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2ed775e844566ce9ce089be9a81a8b928623b8ee5820f5e4d58c1a9d33dfc5ae"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bdc5dd3f57b5368d5d661d5d3703bcaa38bceca59d25955dff66244dbc987271"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-win32.whl", hash = "sha256:3a8d6f07e64c0c7756f4e0c4781d9d5a2b9cc9cbd28f7032a6fb8d4f847d0445"}, + {file = "grpcio_tools-1.62.2-cp310-cp310-win_amd64.whl", hash = "sha256:e33b59fb3efdddeb97ded988a871710033e8638534c826567738d3edce528752"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-linux_armv7l.whl", hash = "sha256:472505d030135d73afe4143b0873efe0dcb385bd6d847553b4f3afe07679af00"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:ec674b4440ef4311ac1245a709e87b36aca493ddc6850eebe0b278d1f2b6e7d1"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:184b4174d4bd82089d706e8223e46c42390a6ebac191073b9772abc77308f9fa"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c195d74fe98541178ece7a50dad2197d43991e0f77372b9a88da438be2486f12"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a34d97c62e61bfe9e6cff0410fe144ac8cca2fc979ad0be46b7edf026339d161"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cbb8453ae83a1db2452b7fe0f4b78e4a8dd32be0f2b2b73591ae620d4d784d3d"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4f989e5cebead3ae92c6abf6bf7b19949e1563a776aea896ac5933f143f0c45d"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-win32.whl", hash = "sha256:c48fabe40b9170f4e3d7dd2c252e4f1ff395dc24e49ac15fc724b1b6f11724da"}, + {file = "grpcio_tools-1.62.2-cp311-cp311-win_amd64.whl", hash = "sha256:8c616d0ad872e3780693fce6a3ac8ef00fc0963e6d7815ce9dcfae68ba0fc287"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-linux_armv7l.whl", hash = "sha256:10cc3321704ecd17c93cf68c99c35467a8a97ffaaed53207e9b2da6ae0308ee1"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:9be84ff6d47fd61462be7523b49d7ba01adf67ce4e1447eae37721ab32464dd8"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:d82f681c9a9d933a9d8068e8e382977768e7779ddb8870fa0cf918d8250d1532"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:04c607029ae3660fb1624ed273811ffe09d57d84287d37e63b5b802a35897329"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72b61332f1b439c14cbd3815174a8f1d35067a02047c32decd406b3a09bb9890"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8214820990d01b52845f9fbcb92d2b7384a0c321b303e3ac614c219dc7d1d3af"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:462e0ab8dd7c7b70bfd6e3195eebc177549ede5cf3189814850c76f9a340d7ce"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-win32.whl", hash = "sha256:fa107460c842e4c1a6266150881694fefd4f33baa544ea9489601810c2210ef8"}, + {file = "grpcio_tools-1.62.2-cp312-cp312-win_amd64.whl", hash = "sha256:759c60f24c33a181bbbc1232a6752f9b49fbb1583312a4917e2b389fea0fb0f2"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-linux_armv7l.whl", hash = "sha256:45db5da2bcfa88f2b86b57ef35daaae85c60bd6754a051d35d9449c959925b57"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:ab84bae88597133f6ea7a2bdc57b2fda98a266fe8d8d4763652cbefd20e73ad7"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:7a49bccae1c7d154b78e991885c3111c9ad8c8fa98e91233de425718f47c6139"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7e439476b29d6dac363b321781a113794397afceeb97dad85349db5f1cb5e9a"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ea369c4d1567d1acdf69c8ea74144f4ccad9e545df7f9a4fc64c94fa7684ba3"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4f955702dc4b530696375251319d05223b729ed24e8673c2129f7a75d2caefbb"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:3708a747aa4b6b505727282ca887041174e146ae030ebcadaf4c1d346858df62"}, + {file = "grpcio_tools-1.62.2-cp37-cp37m-win_amd64.whl", hash = "sha256:2ce149ea55eadb486a7fb75a20f63ef3ac065ee6a0240ed25f3549ce7954c653"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:58cbb24b3fa6ae35aa9c210fcea3a51aa5fef0cd25618eb4fd94f746d5a9b703"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:6413581e14a80e0b4532577766cf0586de4dd33766a31b3eb5374a746771c07d"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:47117c8a7e861382470d0e22d336e5a91fdc5f851d1db44fa784b9acea190d87"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9f1ba79a253df9e553d20319c615fa2b429684580fa042dba618d7f6649ac7e4"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:04a394cf5e51ba9be412eb9f6c482b6270bd81016e033e8eb7d21b8cc28fe8b5"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3c53b221378b035ae2f1881cbc3aca42a6075a8e90e1a342c2f205eb1d1aa6a1"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c384c838b34d1b67068e51b5bbe49caa6aa3633acd158f1ab16b5da8d226bc53"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-win32.whl", hash = "sha256:19ea69e41c3565932aa28a202d1875ec56786aea46a2eab54a3b28e8a27f9517"}, + {file = "grpcio_tools-1.62.2-cp38-cp38-win_amd64.whl", hash = "sha256:1d768a5c07279a4c461ebf52d0cec1c6ca85c6291c71ec2703fe3c3e7e28e8c4"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:5b07b5874187e170edfbd7aa2ca3a54ebf3b2952487653e8c0b0d83601c33035"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:d58389fe8be206ddfb4fa703db1e24c956856fcb9a81da62b13577b3a8f7fda7"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:7d8b4e00c3d7237b92260fc18a561cd81f1da82e8be100db1b7d816250defc66"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fe08d2038f2b7c53259b5c49e0ad08c8e0ce2b548d8185993e7ef67e8592cca"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:19216e1fb26dbe23d12a810517e1b3fbb8d4f98b1a3fbebeec9d93a79f092de4"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:b8574469ecc4ff41d6bb95f44e0297cdb0d95bade388552a9a444db9cd7485cd"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4f6f32d39283ea834a493fccf0ebe9cfddee7577bdcc27736ad4be1732a36399"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-win32.whl", hash = "sha256:76eb459bdf3fb666e01883270beee18f3f11ed44488486b61cd210b4e0e17cc1"}, + {file = "grpcio_tools-1.62.2-cp39-cp39-win_amd64.whl", hash = "sha256:217c2ee6a7ce519a55958b8622e21804f6fdb774db08c322f4c9536c35fdce7c"}, +] + +[package.dependencies] +grpcio = ">=1.62.2" +protobuf = ">=4.21.6,<5.0dev" +setuptools = "*" + [[package]] name = "h11" version = "0.14.0" @@ -1180,6 +1248,32 @@ files = [ {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, ] +[[package]] +name = "h2" +version = "4.1.0" +description = "HTTP/2 State-Machine based protocol implementation" +optional = false +python-versions = ">=3.6.1" +files = [ + {file = "h2-4.1.0-py3-none-any.whl", hash = "sha256:03a46bcf682256c95b5fd9e9a99c1323584c3eec6440d379b9903d709476bc6d"}, + {file = "h2-4.1.0.tar.gz", hash = "sha256:a83aca08fbe7aacb79fec788c9c0bac936343560ed9ec18b82a13a12c28d2abb"}, +] + +[package.dependencies] +hpack = ">=4.0,<5" +hyperframe = ">=6.0,<7" + +[[package]] +name = "hpack" +version = "4.0.0" +description = "Pure-Python HPACK header compression" +optional = false +python-versions = ">=3.6.1" +files = [ + {file = "hpack-4.0.0-py3-none-any.whl", hash = "sha256:84a076fad3dc9a9f8063ccb8041ef100867b1878b25ef0ee63847a5d53818a6c"}, + {file = "hpack-4.0.0.tar.gz", hash = "sha256:fc41de0c63e687ebffde81187a948221294896f6bdc0ae2312708df339430095"}, +] + [[package]] name = "httpcore" version = "1.0.5" @@ -1263,6 +1357,7 @@ files = [ [package.dependencies] anyio = "*" certifi = "*" +h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""} httpcore = "==1.*" idna = "*" sniffio = "*" @@ -1321,6 +1416,17 @@ files = [ [package.dependencies] pyreadline3 = {version = "*", markers = "sys_platform == \"win32\" and python_version >= \"3.8\""} +[[package]] +name = "hyperframe" +version = "6.0.1" +description = "HTTP/2 framing layer for Python" +optional = false +python-versions = ">=3.6.1" +files = [ + {file = "hyperframe-6.0.1-py3-none-any.whl", hash = "sha256:0ec6bafd80d8ad2195c4f03aacba3a8265e57bc4cff261e802bf39970ed02a15"}, + {file = "hyperframe-6.0.1.tar.gz", hash = "sha256:ae510046231dc8e9ecb1a6586f63d2347bf4c8905914aa84ba585ae85f28a914"}, +] + [[package]] name = "idna" version = "3.7" @@ -2693,6 +2799,25 @@ files = [ packaging = "*" tenacity = ">=6.2.0" +[[package]] +name = "portalocker" +version = "2.8.2" +description = "Wraps the portalocker recipe for easy usage" +optional = false +python-versions = ">=3.8" +files = [ + {file = "portalocker-2.8.2-py3-none-any.whl", hash = "sha256:cfb86acc09b9aa7c3b43594e19be1345b9d16af3feb08bf92f23d4dce513a28e"}, + {file = "portalocker-2.8.2.tar.gz", hash = "sha256:2b035aa7828e46c58e9b31390ee1f169b98e1066ab10b9a6a861fe7e25ee4f33"}, +] + +[package.dependencies] +pywin32 = {version = ">=226", markers = "platform_system == \"Windows\""} + +[package.extras] +docs = ["sphinx (>=1.7.1)"] +redis = ["redis"] +tests = ["pytest (>=5.4.1)", "pytest-cov (>=2.8.1)", "pytest-mypy (>=0.8.0)", "pytest-timeout (>=2.1.0)", "redis", "sphinx (>=6.0.0)", "types-redis"] + [[package]] name = "posthog" version = "3.5.0" @@ -3038,6 +3163,29 @@ files = [ {file = "pytz-2024.1.tar.gz", hash = "sha256:2a29735ea9c18baf14b448846bde5a48030ed267578472d8955cd0e7443a9812"}, ] +[[package]] +name = "pywin32" +version = "306" +description = "Python for Window Extensions" +optional = false +python-versions = "*" +files = [ + {file = "pywin32-306-cp310-cp310-win32.whl", hash = "sha256:06d3420a5155ba65f0b72f2699b5bacf3109f36acbe8923765c22938a69dfc8d"}, + {file = "pywin32-306-cp310-cp310-win_amd64.whl", hash = "sha256:84f4471dbca1887ea3803d8848a1616429ac94a4a8d05f4bc9c5dcfd42ca99c8"}, + {file = "pywin32-306-cp311-cp311-win32.whl", hash = "sha256:e65028133d15b64d2ed8f06dd9fbc268352478d4f9289e69c190ecd6818b6407"}, + {file = "pywin32-306-cp311-cp311-win_amd64.whl", hash = "sha256:a7639f51c184c0272e93f244eb24dafca9b1855707d94c192d4a0b4c01e1100e"}, + {file = "pywin32-306-cp311-cp311-win_arm64.whl", hash = "sha256:70dba0c913d19f942a2db25217d9a1b726c278f483a919f1abfed79c9cf64d3a"}, + {file = "pywin32-306-cp312-cp312-win32.whl", hash = "sha256:383229d515657f4e3ed1343da8be101000562bf514591ff383ae940cad65458b"}, + {file = "pywin32-306-cp312-cp312-win_amd64.whl", hash = "sha256:37257794c1ad39ee9be652da0462dc2e394c8159dfd913a8a4e8eb6fd346da0e"}, + {file = "pywin32-306-cp312-cp312-win_arm64.whl", hash = "sha256:5821ec52f6d321aa59e2db7e0a35b997de60c201943557d108af9d4ae1ec7040"}, + {file = "pywin32-306-cp37-cp37m-win32.whl", hash = "sha256:1c73ea9a0d2283d889001998059f5eaaba3b6238f767c9cf2833b13e6a685f65"}, + {file = "pywin32-306-cp37-cp37m-win_amd64.whl", hash = "sha256:72c5f621542d7bdd4fdb716227be0dd3f8565c11b280be6315b06ace35487d36"}, + {file = "pywin32-306-cp38-cp38-win32.whl", hash = "sha256:e4c092e2589b5cf0d365849e73e02c391c1349958c5ac3e9d5ccb9a28e017b3a"}, + {file = "pywin32-306-cp38-cp38-win_amd64.whl", hash = "sha256:e8ac1ae3601bee6ca9f7cb4b5363bf1c0badb935ef243c4733ff9a393b1690c0"}, + {file = "pywin32-306-cp39-cp39-win32.whl", hash = "sha256:e25fd5b485b55ac9c057f67d94bc203f3f6595078d1fb3b458c9c28b7153a802"}, + {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"}, +] + [[package]] name = "pyyaml" version = "6.0.1" @@ -3098,6 +3246,32 @@ files = [ {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, ] +[[package]] +name = "qdrant-client" +version = "1.9.1" +description = "Client library for the Qdrant vector search engine" +optional = false +python-versions = ">=3.8" +files = [ + {file = "qdrant_client-1.9.1-py3-none-any.whl", hash = "sha256:b9b7e0e5c1a51410d8bb5106a869a51e12f92ab45a99030f27aba790553bd2c8"}, + {file = "qdrant_client-1.9.1.tar.gz", hash = "sha256:186b9c31d95aefe8f2db84b7746402d7365bd63b305550e530e31bde2002ce79"}, +] + +[package.dependencies] +grpcio = ">=1.41.0" +grpcio-tools = ">=1.41.0" +httpx = {version = ">=0.20.0", extras = ["http2"]} +numpy = [ + {version = ">=1.21", markers = "python_version >= \"3.8\" and python_version < \"3.12\""}, + {version = ">=1.26", markers = "python_version >= \"3.12\""}, +] +portalocker = ">=2.7.0,<3.0.0" +pydantic = ">=1.10.8" +urllib3 = ">=1.26.14,<3" + +[package.extras] +fastembed = ["fastembed (==0.2.6)"] + [[package]] name = "regex" version = "2024.5.10" @@ -4615,4 +4789,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "639d8fcd761b8b12f4ad8b99a0d2160104ec4dc5edf896ba455cf9f649990fb5" +content-hash = "4212c804077790a60810a32ce45052141628e28c4659b93d04e250d897cb4ede" diff --git a/pyproject.toml b/pyproject.toml index d73c43b..f855d2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ scikit-learn = "^1.5.0" sentence-transformers = "^3.0.1" torch = "^2.3.1" warcio = "^1.7.4" +qdrant-client = "^1.9.1" [tool.poetry.group.dev.dependencies] black = "^24.3.0" diff --git a/warc_gpt/commands/ingest.py b/warc_gpt/commands/ingest.py index d992582..8268944 100644 --- a/warc_gpt/commands/ingest.py +++ b/warc_gpt/commands/ingest.py @@ -9,7 +9,6 @@ from shutil import rmtree import click -import chromadb from bs4 import BeautifulSoup from bs4 import Comment as HTMLComment from pypdf import PdfReader @@ -21,6 +20,7 @@ from statistics import mean from warc_gpt import WARC_RECORD_DATA +from warc_gpt.utils.vector_storage import VectorStorage @current_app.cli.command("ingest") @@ -44,7 +44,7 @@ def ingest(batch_size) -> None: warc_files = [] embedding_model = None chroma_client = None - chroma_collection = None + vector_storage = None total_records = 0 total_embeddings = 0 @@ -85,15 +85,7 @@ def ingest(batch_size) -> None: ) # Note: The text splitter adjusts its cut-off based on the models' max_seq_length # Init vector store - chroma_client = chromadb.PersistentClient( - path=environ["VECTOR_SEARCH_PATH"], - settings=chromadb.Settings(anonymized_telemetry=False), - ) - - chroma_collection = chroma_client.create_collection( - name=environ["VECTOR_SEARCH_COLLECTION_NAME"], - metadata={"hnsw:space": environ["VECTOR_SEARCH_DISTANCE_FUNCTION"]}, - ) + vector_storage = VectorStorage.make_storage(environ["VECTOR_SEARCH_DATABASE"]) # # For each WARC: @@ -202,25 +194,20 @@ def ingest(batch_size) -> None: text_chunks = [chunk_prefix + chunk for chunk in text_chunks] # Generate embeddings and metadata for each chunk - ( - documents, - ids, - metadatas, - embeddings, - multi_chunk_mode, - encoding_timings - ) = chunk_objects( - record_data, - text_chunks, - embedding_model, - multi_chunk_mode, - encoding_timings, - batch_size + (documents, ids, metadatas, embeddings, multi_chunk_mode, encoding_timings) = ( + chunk_objects( + record_data, + text_chunks, + embedding_model, + multi_chunk_mode, + encoding_timings, + batch_size, + ) ) total_embeddings += len(embeddings) # Store embeddings and metadata - chroma_collection.add( + vector_storage.ingest( documents=documents, embeddings=embeddings, metadatas=metadatas, @@ -236,7 +223,7 @@ def chunk_objects( embedding_model: SentenceTransformer, multi_chunk_mode: bool, encoding_timings: list[float], - batch_size: int + batch_size: int, ): """ Return one document, metadata, id, and embedding object per chunk; also return @@ -254,7 +241,7 @@ def chunk_objects( ids = [f"{record_data['warc_record_id']}-{i+1}" for i in chunk_range] metadatas = [ - dict(record_data, **{"warc_record_text": text_chunks[i][len(chunk_prefix):]}) + dict(record_data, **{"warc_record_text": text_chunks[i][len(chunk_prefix) :]}) for i in chunk_range ] diff --git a/warc_gpt/commands/visualize.py b/warc_gpt/commands/visualize.py index 2acdbb3..f19609e 100644 --- a/warc_gpt/commands/visualize.py +++ b/warc_gpt/commands/visualize.py @@ -1,16 +1,17 @@ """ `commands.visualize` module: Controller for the `visualize` CLI command. """ + import os from textwrap import wrap import click -import chromadb import pandas as pd from sklearn.manifold import TSNE import plotly.express as px from sentence_transformers import SentenceTransformer from flask import current_app +from warc_gpt.utils.vector_storage import VectorStorage @current_app.cli.command("visualize") @@ -26,7 +27,7 @@ default=30.0, type=float, help="TSNE default setting; reduce for small input sets.", - show_default=True + show_default=True, ) def visualize(questions: str, perplexity: float) -> None: """ @@ -51,15 +52,10 @@ def visualize(questions: str, perplexity: float) -> None: ) # Init vector store - chroma_client = chromadb.PersistentClient( - path=environ["VECTOR_SEARCH_PATH"], - settings=chromadb.Settings(anonymized_telemetry=False), - ) - - chroma_collection = chroma_client.get_collection(name=environ["VECTOR_SEARCH_COLLECTION_NAME"]) + vector_storage = VectorStorage.make_storage(environ["VECTOR_SEARCH_DATABASE"]) # Pull everything out of the vector store - all_vectors = chroma_collection.get(include=["metadatas", "documents", "embeddings"]) + all_vectors = vector_storage.get_all() # # If a question was provided, generate embeddings for it so it can be placed on the plot @@ -83,8 +79,10 @@ def visualize(questions: str, perplexity: float) -> None: try: scatter_plot_data = TSNE(perplexity=perplexity).fit_transform(scatter_plot_data) except ValueError as e: - if f'{e}' == "perplexity must be less than n_samples": - click.echo("You may not have enough input data; add some or reduce perplexity to less than n_samples.") # noqa + if f"{e}" == "perplexity must be less than n_samples": + click.echo( + "You may not have enough input data; add some or reduce perplexity to less than n_samples." + ) # noqa return 1 else: raise diff --git a/warc_gpt/utils/vector_storage/__init__.py b/warc_gpt/utils/vector_storage/__init__.py new file mode 100644 index 0000000..b29911e --- /dev/null +++ b/warc_gpt/utils/vector_storage/__init__.py @@ -0,0 +1,5 @@ +from warc_gpt.utils.vector_storage.base import VectorStorage +from warc_gpt.utils.vector_storage.chroma import Chroma +from warc_gpt.utils.vector_storage.qdrant import Qdrant + +__all__ = ["VectorStorage", "Chroma", "Qdrant"] diff --git a/warc_gpt/utils/vector_storage/base.py b/warc_gpt/utils/vector_storage/base.py new file mode 100644 index 0000000..d9e55d6 --- /dev/null +++ b/warc_gpt/utils/vector_storage/base.py @@ -0,0 +1,46 @@ +from abc import ABC, abstractmethod +from typing import Dict, List, Mapping, Sequence, TypedDict, Union + +Embedding = Sequence[float] +Metadata = Mapping[str, Union[str, int, float, bool]] + + +class StorageResponse(TypedDict): + ids: List[str] + embeddings: List[Embedding] + documents: List[str] + metadatas: List[Metadata] + + +class VectorStorage(ABC): + """ + Abstract class to define a common interface for vector databases. + """ + + @abstractmethod + def ingest( + self, + documents: List[str], + embeddings: List[Sequence[float]], + metadatas: List[Dict], + ids: List[str], + ) -> None: + """Stores the data to a vector storage.""" + raise NotImplementedError() + + @abstractmethod + def get_all(self) -> StorageResponse: + """Returns all the data stored in vector storage.""" + raise NotImplementedError() + + @abstractmethod + def search(self, query_embedding: Embedding, limit: int) -> StorageResponse: + """Returns semantically similar data from a vector storage with limit.""" + raise NotImplementedError() + + @classmethod + def make_storage(cls, name, *args, **kwargs): + for sc in cls.__subclasses__(): + if sc.__name__.lower() == name.lower(): + return sc(*args, **kwargs) + raise ValueError("Vector storage implemementation not found.") diff --git a/warc_gpt/utils/vector_storage/chroma.py b/warc_gpt/utils/vector_storage/chroma.py new file mode 100644 index 0000000..d5386fe --- /dev/null +++ b/warc_gpt/utils/vector_storage/chroma.py @@ -0,0 +1,49 @@ +from os import environ +from typing import Dict, List, Sequence + +import chromadb +from warc_gpt.utils.vector_storage.base import VectorStorage, StorageResponse + + +class Chroma(VectorStorage): + def __init__(self): + chroma_client = chromadb.PersistentClient( + path=environ["VECTOR_SEARCH_PATH"], + settings=chromadb.Settings(anonymized_telemetry=False), + ) + + self.chroma_collection = chroma_client.get_or_create_collection( + name=environ["VECTOR_SEARCH_COLLECTION_NAME"], + metadata={"hnsw:space": environ["VECTOR_SEARCH_DISTANCE_FUNCTION"]}, + ) + + def ingest( + self, + documents: List[str], + embeddings: List[Sequence[float]], + metadatas: List[Dict], + ids: List[str], + ) -> None: + return self.chroma_collection.add( + documents=documents, + embeddings=embeddings, + metadatas=metadatas, + ids=ids, + ) + + def get_all(self) -> StorageResponse: + return self.chroma_collection.get(include=["metadatas", "documents", "embeddings"]) + + def search(self, query_embedding: Sequence[float], limit: int) -> StorageResponse: + results = self.chroma_collection.query( + query_embeddings=query_embedding, + n_results=limit, + include=["documents", "metadatas", "embeddings"], + ) + + return { + "documents": results["documents"][0], + "embeddings": results["embeddings"][0], + "metadatas": results["metadatas"][0], + "ids": results["ids"][0], + } diff --git a/warc_gpt/utils/vector_storage/qdrant.py b/warc_gpt/utils/vector_storage/qdrant.py new file mode 100644 index 0000000..c998d47 --- /dev/null +++ b/warc_gpt/utils/vector_storage/qdrant.py @@ -0,0 +1,144 @@ +from copy import deepcopy +from os import environ, getenv +from typing import Dict, List, Sequence +import uuid + +from qdrant_client import QdrantClient, models +from warc_gpt.utils.vector_storage.base import VectorStorage, StorageResponse + +ID_KEY = "_id" +DOCUMENT_KEY = "_document" +SCROLL_SIZE = 64 + + +class Qdrant(VectorStorage): + def __init__(self): + self.collection_name = environ["VECTOR_SEARCH_COLLECTION_NAME"] + self.client = QdrantClient( + location=getenv("QDRANT_LOCATION"), + port=int(getenv("QDRANT_PORT", 6333)), + grpc_port=int(getenv("QDRANT_GRPC_PORT", 6334)), + prefer_grpc=bool(getenv("QDRANT_PREFER_GRPC", False)), + https=bool(getenv("QDRANT_HTTPS", False)), + api_key=getenv("QDRANT_API_KEY"), + prefix=getenv("QDRANT_PREFIX"), + timeout=int(getenv("QDRANT_TIMEOUT", 0)) or None, + host=getenv("QDRANT_HOST"), + path=getenv("QDRANT_PATH"), + ) + + def ingest( + self, + documents: List[str], + embeddings: List[Sequence[float]], + metadatas: List[Dict], + ids: List[str], + ) -> None: + self._ensure_collection(size=len(embeddings[0])) + payloads = [ + {**metadata, ID_KEY: id, DOCUMENT_KEY: document} + for document, metadata, id in zip(documents, metadatas, ids) + ] + + # Qdrant onlly allows UUIDs and unsigned integers as point IDs + # https://qdrant.tech/documentation/concepts/points/#point-ids + ids = [uuid.uuid4().hex for _ in ids] + + self.client.upsert( + self.collection_name, + points=models.Batch(ids=ids, vectors=embeddings, payloads=payloads), + ) + + def get_all(self) -> StorageResponse: + points = self._scroll_points() + + ids, embeddings, documents, metadatas = [], [], [], [] + + for point in points: + payload = deepcopy(point.payload) + ids.append(payload.pop(ID_KEY)) + documents.append(payload.pop(DOCUMENT_KEY)) + metadatas.append(payload) + embeddings.append(point.vector) + + return { + "documents": documents, + "embeddings": embeddings, + "ids": ids, + "metadatas": metadatas, + } + + def search(self, query_embedding: Sequence[float], limit: int) -> StorageResponse: + points = self.client.search( + self.collection_name, + query_vector=query_embedding, + with_payload=True, + with_vectors=True, + limit=True, + ) + ids, embeddings, documents, metadatas = [], [], [], [] + + for point in points: + payload = deepcopy(point.payload) + ids.append(payload.pop(ID_KEY)) + documents.append(payload.pop(DOCUMENT_KEY)) + metadatas.append(payload) + embeddings.append(point.vector) + + return { + "documents": documents, + "embeddings": embeddings, + "ids": ids, + "metadatas": metadatas, + } + + def _scroll_points(self) -> List[models.Record]: + """ + Scroll through and return all points in a collection + """ + + from qdrant_client import grpc + + records = [] + next_offset = None + stop_scrolling = False + while not stop_scrolling: + response, next_offset = self.client.scroll( + self.collection_name, + limit=SCROLL_SIZE, + offset=next_offset, + with_payload=True, + with_vectors=True, + ) + + stop_scrolling = next_offset is None or ( + isinstance(next_offset, grpc.PointId) + and next_offset.num == 0 + and next_offset.uuid == "" + ) + + records.extend(response) + + return records + + def _ensure_collection(self, size: int): + if not self.client.collection_exists(self.collection_name): + distance = self._convert_metric(getenv("VECTOR_SEARCH_DISTANCE_FUNCTION", "cosine")) + self.client.create_collection( + collection_name=self.collection_name, + vectors_config=models.VectorParams(size=size, distance=distance), + ) + + def _convert_metric(self, metric: str): + from qdrant_client.models import Distance + + mapping = { + "cosine": Distance.COSINE, + "l2": Distance.EUCLID, + "ip": Distance.DOT, + } + + if metric not in mapping: + raise ValueError(f"Unsupported Qdrant similarity metric: {metric}") + + return mapping[metric] diff --git a/warc_gpt/views/api/search.py b/warc_gpt/views/api/search.py index a7ce9f8..9417ac4 100644 --- a/warc_gpt/views/api/search.py +++ b/warc_gpt/views/api/search.py @@ -1,19 +1,19 @@ import os +from os import environ import traceback -import chromadb from sentence_transformers import SentenceTransformer from flask import current_app, jsonify, request from warc_gpt import WARC_RECORD_DATA from warc_gpt.utils import get_limiter +from warc_gpt.utils.vector_storage import VectorStorage API_SEARCH_RATE_LIMIT = os.environ["API_SEARCH_RATE_LIMIT"] vector_store_cache = { - "chroma_client": None, + "vector_storage": None, "embedding_model": None, - "chroma_collection": None, } """ Module-level "caching" for vector store connection. """ @@ -29,10 +29,8 @@ def post_search(): Returns a JSON object of WARC_RECORD_DATA entries. """ - environ = os.environ - chroma_client = None - chroma_collection = None + vector_storage = None embedding_model = None input = request.get_json() @@ -76,34 +74,16 @@ def post_search(): # Chroma client try: - if vector_store_cache.get("chroma_client", None) is None: - chroma_client = chromadb.PersistentClient( - path=environ["VECTOR_SEARCH_PATH"], - settings=chromadb.Settings(anonymized_telemetry=False), - ) - vector_store_cache["chroma_client"] = chroma_client - else: - chroma_client = vector_store_cache["chroma_client"] - - assert chroma_client - except Exception: - current_app.logger.debug(traceback.format_exc()) - return jsonify({"error": "Could not load ChromaDB client."}), 500 - - # Chroma collection - try: - if vector_store_cache.get("chroma_collection", None) is None: - chroma_collection = chroma_client.get_collection( - name=environ["VECTOR_SEARCH_COLLECTION_NAME"], - ) - vector_store_cache["chroma_collection"] = chroma_collection + if vector_store_cache.get("vector_storage", None) is None: + vector_storage = VectorStorage.make_storage(environ["VECTOR_SEARCH_DATABASE"]) + vector_store_cache["vector_storage"] = vector_storage else: - chroma_collection = vector_store_cache["chroma_collection"] + vector_storage = vector_store_cache["vector_storage"] - assert chroma_collection + assert vector_storage except Exception: current_app.logger.debug(traceback.format_exc()) - return jsonify({"error": "Could not load ChromaDB collection."}), 500 + return jsonify({"error": "Could not load vector storage."}), 500 # # Retrieve context chunks @@ -114,19 +94,19 @@ def post_search(): normalize_embeddings=normalize_embeddings, ).tolist() - vector_search_results = chroma_collection.query( - query_embeddings=message_embedding, - n_results=int(environ["VECTOR_SEARCH_SEARCH_N_RESULTS"]), + vector_search_results = vector_storage.search( + query_embedding=message_embedding, + limit=int(environ["VECTOR_SEARCH_SEARCH_N_RESULTS"]), ) + except Exception: current_app.logger.debug(traceback.format_exc()) return jsonify({"error": "Could not retrieve context from vector store."}), 500 - # # Filter and return metadata # if vector_search_results: - for vector in vector_search_results["metadatas"][0]: + for vector in vector_search_results["metadatas"]: metadata = {} for key in WARC_RECORD_DATA.keys(): diff --git a/wsgi.py b/wsgi.py index 7958444..32bc18f 100644 --- a/wsgi.py +++ b/wsgi.py @@ -1,4 +1,5 @@ """ WSGI hook """ + from warc_gpt import create_app if __name__ == "__main__":