diff --git a/appinfo/info.xml b/appinfo/info.xml
index 3eb5a30..6fb1534 100644
--- a/appinfo/info.xml
+++ b/appinfo/info.xml
@@ -25,7 +25,7 @@ Install the given apps for Context Chat to work as desired **in the given order*
https://github.com/nextcloud/context_chat_backend/issues
https://github.com/nextcloud/context_chat_backend.git
-
+
diff --git a/config.cpu.yaml b/config.cpu.yaml
index 4bb73c4..0ff0d7c 100644
--- a/config.cpu.yaml
+++ b/config.cpu.yaml
@@ -4,6 +4,7 @@ httpx_verify_ssl: true
model_offload_timeout: 15 # 15 minutes
use_colors: true
uvicorn_workers: 1
+embedding_chunk_size: 1000
# model files download configuration
disable_custom_model_download: false
diff --git a/config.gpu.yaml b/config.gpu.yaml
index e8cff47..9910f61 100644
--- a/config.gpu.yaml
+++ b/config.gpu.yaml
@@ -4,6 +4,7 @@ httpx_verify_ssl: true
model_offload_timeout: 15 # 15 minutes
use_colors: true
uvicorn_workers: 1
+embedding_chunk_size: 1000
# model files download configuration
disable_custom_model_download: false
diff --git a/context_chat_backend/chain/ingest/doc_splitter.py b/context_chat_backend/chain/ingest/doc_splitter.py
index 6d95d04..78354fe 100644
--- a/context_chat_backend/chain/ingest/doc_splitter.py
+++ b/context_chat_backend/chain/ingest/doc_splitter.py
@@ -5,10 +5,10 @@
)
-def get_splitter_for(mimetype: str = 'text/plain') -> TextSplitter:
+def get_splitter_for(chunk_size: int, mimetype: str = 'text/plain') -> TextSplitter:
kwargs = {
- 'chunk_size': 2000,
- 'chunk_overlap': 200,
+ 'chunk_size': chunk_size,
+		'chunk_overlap': chunk_size // 10,
'add_start_index': True,
'strip_whitespace': True,
'is_separator_regex': True,
diff --git a/context_chat_backend/chain/ingest/injest.py b/context_chat_backend/chain/ingest/injest.py
index c8668e5..f3cc607 100644
--- a/context_chat_backend/chain/ingest/injest.py
+++ b/context_chat_backend/chain/ingest/injest.py
@@ -4,6 +4,7 @@
from fastapi.datastructures import UploadFile
from langchain.schema import Document
+from ...config_parser import TConfig
from ...utils import not_none, to_int
from ...vectordb import BaseVectorDB
from .doc_loader import decode_source
@@ -111,7 +112,7 @@ def _bucket_by_type(documents: list[Document]) -> dict[str, list[Document]]:
return bucketed_documents
-def _process_sources(vectordb: BaseVectorDB, sources: list[UploadFile]) -> bool:
+def _process_sources(vectordb: BaseVectorDB, config: TConfig, sources: list[UploadFile]) -> bool:
filtered_sources = _filter_sources(sources[0].headers['userId'], vectordb, sources)
if len(filtered_sources) == 0:
@@ -132,7 +133,7 @@ def _process_sources(vectordb: BaseVectorDB, sources: list[UploadFile]) -> bool:
type_bucketed_docs = _bucket_by_type(documents)
for _type, _docs in type_bucketed_docs.items():
- text_splitter = get_splitter_for(_type)
+ text_splitter = get_splitter_for(config['embedding_chunk_size'], _type)
split_docs = text_splitter.split_documents(_docs)
split_documents.extend(split_docs)
@@ -158,6 +159,7 @@ def _process_sources(vectordb: BaseVectorDB, sources: list[UploadFile]) -> bool:
def embed_sources(
vectordb: BaseVectorDB,
+ config: TConfig,
sources: list[UploadFile],
) -> bool:
# either not a file or a file that is allowed
@@ -172,4 +174,4 @@ def embed_sources(
'\n'.join([f'{source.filename} ({source.headers.get("title", "")})' for source in sources_filtered]),
flush=True,
)
- return _process_sources(vectordb, sources_filtered)
+ return _process_sources(vectordb, config, sources_filtered)
diff --git a/context_chat_backend/config_parser.py b/context_chat_backend/config_parser.py
index 30c4de2..51091d6 100644
--- a/context_chat_backend/config_parser.py
+++ b/context_chat_backend/config_parser.py
@@ -13,6 +13,7 @@ class TConfig(TypedDict):
model_offload_timeout: int
use_colors: bool
uvicorn_workers: int
+ embedding_chunk_size: int
# model files download configuration
disable_custom_model_download: bool
@@ -74,6 +75,7 @@ def get_config(file_path: str) -> TConfig:
'model_offload_timeout': config.get('model_offload_timeout', 15),
'use_colors': config.get('use_colors', True),
'uvicorn_workers': config.get('uvicorn_workers', 1),
+ 'embedding_chunk_size': config.get('embedding_chunk_size', 1000),
'disable_custom_model_download': config.get('disable_custom_model_download', False),
'model_download_uri': config.get('model_download_uri', 'https://download.nextcloud.com/server/apps/context_chat_backend'),
diff --git a/context_chat_backend/controller.py b/context_chat_backend/controller.py
index 0b67bda..e93a843 100644
--- a/context_chat_backend/controller.py
+++ b/context_chat_backend/controller.py
@@ -273,7 +273,7 @@ def _(sources: list[UploadFile]):
return JSONResponse('Invaild/missing headers', 400)
db: BaseVectorDB = vectordb_loader.load()
- result = embed_sources(db, sources)
+ result = embed_sources(db, app.extra['CONFIG'], sources)
if not result:
return JSONResponse('Error: All sources were not loaded, check logs for more info', 500)
@@ -305,40 +305,47 @@ def at_least_one_context(cls, value: int):
return value
+def execute_query(query: Query) -> LLMOutput:
+ # todo: migrate to Depends during db schema change
+ llm: LLM = llm_loader.load()
+
+ template = app.extra.get('LLM_TEMPLATE')
+ no_ctx_template = app.extra['LLM_NO_CTX_TEMPLATE']
+ # todo: array
+ end_separator = app.extra.get('LLM_END_SEPARATOR', '')
+
+ if query.useContext:
+ db: BaseVectorDB = vectordb_loader.load()
+ return process_context_query(
+ user_id=query.userId,
+ vectordb=db,
+ llm=llm,
+ app_config=app_config,
+ query=query.query,
+ ctx_limit=query.ctxLimit,
+ template=template,
+ end_separator=end_separator,
+ scope_type=query.scopeType,
+ scope_list=query.scopeList,
+ )
+
+ return process_query(
+ llm=llm,
+ app_config=app_config,
+ query=query.query,
+ no_ctx_template=no_ctx_template,
+ end_separator=end_separator,
+ )
+
+
@app.post('/query')
@enabled_guard(app)
def _(query: Query) -> LLMOutput:
global llm_lock
print('query:', query, flush=True)
+ if app_config['llm'][0] == 'nc_texttotext':
+ return execute_query(query)
+
with llm_lock:
- # todo: migrate to Depends during db schema change
- llm: LLM = llm_loader.load()
-
- template = app.extra.get('LLM_TEMPLATE')
- no_ctx_template = app.extra['LLM_NO_CTX_TEMPLATE']
- # todo: array
- end_separator = app.extra.get('LLM_END_SEPARATOR', '')
-
- if query.useContext:
- db: BaseVectorDB = vectordb_loader.load()
- return process_context_query(
- user_id=query.userId,
- vectordb=db,
- llm=llm,
- app_config=app_config,
- query=query.query,
- ctx_limit=query.ctxLimit,
- template=template,
- end_separator=end_separator,
- scope_type=query.scopeType,
- scope_list=query.scopeList,
- )
-
- return process_query(
- llm=llm,
- app_config=app_config,
- query=query.query,
- no_ctx_template=no_ctx_template,
- end_separator=end_separator,
- )
+ return execute_query(query)