diff --git a/appinfo/info.xml b/appinfo/info.xml
index 3eb5a30..6fb1534 100644
--- a/appinfo/info.xml
+++ b/appinfo/info.xml
@@ -25,7 +25,7 @@ Install the given apps for Context Chat to work as desired **in the given order*
 	https://github.com/nextcloud/context_chat_backend/issues
 	https://github.com/nextcloud/context_chat_backend.git
-	
+	
diff --git a/config.cpu.yaml b/config.cpu.yaml
index 4bb73c4..0ff0d7c 100644
--- a/config.cpu.yaml
+++ b/config.cpu.yaml
@@ -4,6 +4,7 @@ httpx_verify_ssl: true
 model_offload_timeout: 15 # 15 minutes
 use_colors: true
 uvicorn_workers: 1
+embedding_chunk_size: 1000
 
 # model files download configuration
 disable_custom_model_download: false
diff --git a/config.gpu.yaml b/config.gpu.yaml
index e8cff47..9910f61 100644
--- a/config.gpu.yaml
+++ b/config.gpu.yaml
@@ -4,6 +4,7 @@ httpx_verify_ssl: true
 model_offload_timeout: 15 # 15 minutes
 use_colors: true
 uvicorn_workers: 1
+embedding_chunk_size: 1000
 
 # model files download configuration
 disable_custom_model_download: false
diff --git a/context_chat_backend/chain/ingest/doc_splitter.py b/context_chat_backend/chain/ingest/doc_splitter.py
index 6d95d04..78354fe 100644
--- a/context_chat_backend/chain/ingest/doc_splitter.py
+++ b/context_chat_backend/chain/ingest/doc_splitter.py
@@ -5,10 +5,10 @@
 )
 
 
-def get_splitter_for(mimetype: str = 'text/plain') -> TextSplitter:
+def get_splitter_for(chunk_size: int, mimetype: str = 'text/plain') -> TextSplitter:
 	kwargs = {
-		'chunk_size': 2000,
-		'chunk_overlap': 200,
+		'chunk_size': chunk_size,
+		'chunk_overlap': int(chunk_size / 10),
 		'add_start_index': True,
 		'strip_whitespace': True,
 		'is_separator_regex': True,
diff --git a/context_chat_backend/chain/ingest/injest.py b/context_chat_backend/chain/ingest/injest.py
index c8668e5..f3cc607 100644
--- a/context_chat_backend/chain/ingest/injest.py
+++ b/context_chat_backend/chain/ingest/injest.py
@@ -4,6 +4,7 @@
 from fastapi.datastructures import UploadFile
 from langchain.schema import Document
 
+from ...config_parser import TConfig
 from ...utils import not_none, to_int
 from ...vectordb import BaseVectorDB
 from .doc_loader import decode_source
@@ -111,7 +112,7 @@ def _bucket_by_type(documents: list[Document]) -> dict[str, list[Document]]:
 	return bucketed_documents
 
 
-def _process_sources(vectordb: BaseVectorDB, sources: list[UploadFile]) -> bool:
+def _process_sources(vectordb: BaseVectorDB, config: TConfig, sources: list[UploadFile]) -> bool:
 	filtered_sources = _filter_sources(sources[0].headers['userId'], vectordb, sources)
 
 	if len(filtered_sources) == 0:
@@ -132,7 +133,7 @@ def _process_sources(vectordb: BaseVectorDB, sources: list[UploadFile]) -> bool:
 	type_bucketed_docs = _bucket_by_type(documents)
 
 	for _type, _docs in type_bucketed_docs.items():
-		text_splitter = get_splitter_for(_type)
+		text_splitter = get_splitter_for(config['embedding_chunk_size'], _type)
 		split_docs = text_splitter.split_documents(_docs)
 		split_documents.extend(split_docs)
 
@@ -158,6 +159,7 @@ def _process_sources(vectordb: BaseVectorDB, sources: list[UploadFile]) -> bool:
 
 def embed_sources(
 	vectordb: BaseVectorDB,
+	config: TConfig,
 	sources: list[UploadFile],
 ) -> bool:
 	# either not a file or a file that is allowed
@@ -172,4 +174,4 @@ def embed_sources(
 		'\n'.join([f'{source.filename} ({source.headers.get("title", "")})' for source in sources_filtered]),
 		flush=True,
 	)
-	return _process_sources(vectordb, sources_filtered)
+	return _process_sources(vectordb, config, sources_filtered)
diff --git a/context_chat_backend/config_parser.py b/context_chat_backend/config_parser.py
index 30c4de2..51091d6 100644
--- a/context_chat_backend/config_parser.py
+++ b/context_chat_backend/config_parser.py
@@ -13,6 +13,7 @@ class TConfig(TypedDict):
 	model_offload_timeout: int
 	use_colors: bool
 	uvicorn_workers: int
+	embedding_chunk_size: int
 
 	# model files download configuration
 	disable_custom_model_download: bool
@@ -74,6 +75,7 @@ def get_config(file_path: str) -> TConfig:
 		'model_offload_timeout': config.get('model_offload_timeout', 15),
 		'use_colors': config.get('use_colors', True),
 		'uvicorn_workers': config.get('uvicorn_workers', 1),
+		'embedding_chunk_size': config.get('embedding_chunk_size', 1000),
 		'disable_custom_model_download': config.get('disable_custom_model_download', False),
 		'model_download_uri': config.get('model_download_uri',
 			'https://download.nextcloud.com/server/apps/context_chat_backend'),
diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py
index 0b67bda..e93a843 100644
--- a/context_chat_backend/controller.py
+++ b/context_chat_backend/controller.py
@@ -273,7 +273,7 @@ def _(sources: list[UploadFile]):
 		return JSONResponse('Invaild/missing headers', 400)
 
 	db: BaseVectorDB = vectordb_loader.load()
-	result = embed_sources(db, sources)
+	result = embed_sources(db, app.extra['CONFIG'], sources)
 	if not result:
 		return JSONResponse('Error: All sources were not loaded, check logs for more info', 500)
 
@@ -305,40 +305,47 @@ def at_least_one_context(cls, value: int):
 		return value
 
 
+def execute_query(query: Query) -> LLMOutput:
+	# todo: migrate to Depends during db schema change
+	llm: LLM = llm_loader.load()
+
+	template = app.extra.get('LLM_TEMPLATE')
+	no_ctx_template = app.extra['LLM_NO_CTX_TEMPLATE']
+	# todo: array
+	end_separator = app.extra.get('LLM_END_SEPARATOR', '')
+
+	if query.useContext:
+		db: BaseVectorDB = vectordb_loader.load()
+		return process_context_query(
+			user_id=query.userId,
+			vectordb=db,
+			llm=llm,
+			app_config=app_config,
+			query=query.query,
+			ctx_limit=query.ctxLimit,
+			template=template,
+			end_separator=end_separator,
+			scope_type=query.scopeType,
+			scope_list=query.scopeList,
+		)
+
+	return process_query(
+		llm=llm,
+		app_config=app_config,
+		query=query.query,
+		no_ctx_template=no_ctx_template,
+		end_separator=end_separator,
+	)
+
+
 @app.post('/query')
 @enabled_guard(app)
 def _(query: Query) -> LLMOutput:
 	global llm_lock
 	print('query:', query, flush=True)
 
+	if app_config['llm'][0] == 'nc_texttotext':
+		return execute_query(query)
+
 	with llm_lock:
-		# todo: migrate to Depends during db schema change
-		llm: LLM = llm_loader.load()
-
-		template = app.extra.get('LLM_TEMPLATE')
-		no_ctx_template = app.extra['LLM_NO_CTX_TEMPLATE']
-		# todo: array
-		end_separator = app.extra.get('LLM_END_SEPARATOR', '')
-
-		if query.useContext:
-			db: BaseVectorDB = vectordb_loader.load()
-			return process_context_query(
-				user_id=query.userId,
-				vectordb=db,
-				llm=llm,
-				app_config=app_config,
-				query=query.query,
-				ctx_limit=query.ctxLimit,
-				template=template,
-				end_separator=end_separator,
-				scope_type=query.scopeType,
-				scope_list=query.scopeList,
-			)
-
-		return process_query(
-			llm=llm,
-			app_config=app_config,
-			query=query.query,
-			no_ctx_template=no_ctx_template,
-			end_separator=end_separator,
-		)
+		return execute_query(query)
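
Net effect on ingestion: chunks shrink from a hardcoded 2000 characters (overlap 200) to a configurable embedding_chunk_size, with the overlap always pinned at 10% of the chosen size, and config.get('embedding_chunk_size', 1000) keeps older config.yaml files working. A minimal sketch of the resulting splitter behaviour, assuming the LangChain RecursiveCharacterTextSplitter that doc_splitter.py builds these kwargs for; the separator list below is a placeholder, the real per-mimetype lists live in doc_splitter.py:

	from langchain.text_splitter import RecursiveCharacterTextSplitter

	chunk_size = 1000  # config['embedding_chunk_size']; get_config() defaults to 1000

	splitter = RecursiveCharacterTextSplitter(
		chunk_size=chunk_size,               # was hardcoded to 2000
		chunk_overlap=int(chunk_size / 10),  # 1000 -> 100; was hardcoded to 200
		add_start_index=True,                # store each chunk's start offset in doc metadata
		strip_whitespace=True,
		is_separator_regex=True,
		separators=['\n\n', '\n', ' ', ''],  # placeholder, not the app's real lists
	)

	chunks = splitter.split_text('lorem ipsum ' * 500)  # ~6000 chars of input
	# each chunk is at most 1000 chars and repeats ~100 chars of its predecessor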
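The controller change is behaviour-preserving for local models: the query path moves wholesale into execute_query(), and the handler still serialises those requests behind llm_lock. Only the nc_texttotext provider, where generation happens in Nextcloud rather than in this process, now skips the lock and can serve queries concurrently. A stripped-down sketch of the pattern with stand-in names (run_query and handle_query are hypothetical, playing the roles of execute_query and the /query handler):

	import threading

	llm_lock = threading.Lock()  # protects the single locally loaded model

	def run_query(query: str) -> str:
		# stand-in for execute_query() in controller.py
		return f'answer to {query!r}'

	def handle_query(query: str, llm_name: str) -> str:
		if llm_name == 'nc_texttotext':
			# the remote provider schedules its own work; locking here
			# would only serialise independent requests
			return run_query(query)
		with llm_lock:
			# a local model can handle one request at a time
			return run_query(query)

Clients of /query see no API change; only throughput for nc_texttotext should improve.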