make stuff fit in 8GB VRAM and don't lock text2text api calls #70

Merged
merged 2 commits on Sep 16, 2024
2 changes: 1 addition & 1 deletion appinfo/info.xml
@@ -25,7 +25,7 @@ Install the given apps for Context Chat to work as desired **in the given order**
	<bugs>https://github.com/nextcloud/context_chat_backend/issues</bugs>
	<repository type="git">https://github.com/nextcloud/context_chat_backend.git</repository>
	<dependencies>
-		<nextcloud min-version="30" max-version="30"/>
+		<nextcloud min-version="30" max-version="31"/>
	</dependencies>
	<external-app>
		<docker-install>
1 change: 1 addition & 0 deletions config.cpu.yaml
@@ -4,6 +4,7 @@ httpx_verify_ssl: true
model_offload_timeout: 15 # 15 minutes
use_colors: true
uvicorn_workers: 1
+embedding_chunk_size: 1000

# model files download configuration
disable_custom_model_download: false
1 change: 1 addition & 0 deletions config.gpu.yaml
@@ -4,6 +4,7 @@ httpx_verify_ssl: true
model_offload_timeout: 15 # 15 minutes
use_colors: true
uvicorn_workers: 1
+embedding_chunk_size: 1000

# model files download configuration
disable_custom_model_download: false
6 changes: 3 additions & 3 deletions context_chat_backend/chain/ingest/doc_splitter.py
@@ -5,10 +5,10 @@
)


-def get_splitter_for(mimetype: str = 'text/plain') -> TextSplitter:
+def get_splitter_for(chunk_size: int, mimetype: str = 'text/plain') -> TextSplitter:
	kwargs = {
-		'chunk_size': 2000,
-		'chunk_overlap': 200,
+		'chunk_size': chunk_size,
+		'chunk_overlap': int(chunk_size / 10),
		'add_start_index': True,
		'strip_whitespace': True,
		'is_separator_regex': True,
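The overlap now scales with the configured size (10% of it) instead of the old fixed 2000/200 pair, and halving the default chunk size shrinks the text batches handed to the embedder, which appears to be the "8GB VRAM" half of the PR title. A minimal sketch of what the new defaults produce, assuming langchain's RecursiveCharacterTextSplitter as a stand-in for whatever get_splitter_for actually returns for text/plain (the mimetype dispatch is outside this hunk):

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

chunk_size = 1000  # the new embedding_chunk_size default
splitter = RecursiveCharacterTextSplitter(
    chunk_size=chunk_size,
    chunk_overlap=int(chunk_size / 10),  # 100 chars; the old code pinned 2000/200
    add_start_index=True,
    strip_whitespace=True,
    is_separator_regex=True,
)

# ~4800 characters of input yield roughly 5-6 overlapping chunks of <= 1000 chars
chunks = splitter.split_text('lorem ipsum ' * 400)
print(len(chunks), max(len(c) for c in chunks))
```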
8 changes: 5 additions & 3 deletions context_chat_backend/chain/ingest/injest.py
@@ -4,6 +4,7 @@
from fastapi.datastructures import UploadFile
from langchain.schema import Document

+from ...config_parser import TConfig
from ...utils import not_none, to_int
from ...vectordb import BaseVectorDB
from .doc_loader import decode_source
@@ -111,7 +112,7 @@ def _bucket_by_type(documents: list[Document]) -> dict[str, list[Document]]:
	return bucketed_documents


-def _process_sources(vectordb: BaseVectorDB, sources: list[UploadFile]) -> bool:
+def _process_sources(vectordb: BaseVectorDB, config: TConfig, sources: list[UploadFile]) -> bool:
	filtered_sources = _filter_sources(sources[0].headers['userId'], vectordb, sources)

	if len(filtered_sources) == 0:
@@ -132,7 +133,7 @@ def _process_sources(vectordb: BaseVectorDB, sources: list[UploadFile]) -> bool:
	type_bucketed_docs = _bucket_by_type(documents)

	for _type, _docs in type_bucketed_docs.items():
-		text_splitter = get_splitter_for(_type)
+		text_splitter = get_splitter_for(config['embedding_chunk_size'], _type)
		split_docs = text_splitter.split_documents(_docs)
		split_documents.extend(split_docs)

@@ -158,6 +159,7 @@ def _process_sources(vectordb: BaseVectorDB, sources: list[UploadFile]) -> bool:

def embed_sources(
	vectordb: BaseVectorDB,
+	config: TConfig,
	sources: list[UploadFile],
) -> bool:
	# either not a file or a file that is allowed
@@ -172,4 +174,4 @@
		'\n'.join([f'{source.filename} ({source.headers.get("title", "")})' for source in sources_filtered]),
		flush=True,
	)
-	return _process_sources(vectordb, sources_filtered)
+	return _process_sources(vectordb, config, sources_filtered)
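embed_sources and _process_sources now take the parsed TConfig and thread it down to the single consumer, get_splitter_for, rather than reading a module-level constant. A skeleton of the pattern, with bodies elided and only the names from the diff:

```python
from typing import TypedDict

class TConfig(TypedDict, total=False):
    embedding_chunk_size: int

def get_splitter_for(chunk_size: int, mimetype: str = 'text/plain'):
    ...  # builds the TextSplitter shown in doc_splitter.py above

def _process_sources(vectordb, config: TConfig, sources) -> bool:
    # the config travels down to the one place that needs a value from it
    text_splitter = get_splitter_for(config['embedding_chunk_size'], 'text/plain')
    ...
    return True

def embed_sources(vectordb, config: TConfig, sources) -> bool:
    return _process_sources(vectordb, config, sources)
```

Widening a few signatures is noisier than a global, but it keeps chunking configurable per deployment and the splitter easy to test in isolation.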
2 changes: 2 additions & 0 deletions context_chat_backend/config_parser.py
@@ -13,6 +13,7 @@ class TConfig(TypedDict):
	model_offload_timeout: int
	use_colors: bool
	uvicorn_workers: int
+	embedding_chunk_size: int

	# model files download configuration
	disable_custom_model_download: bool
@@ -74,6 +75,7 @@ def get_config(file_path: str) -> TConfig:
		'model_offload_timeout': config.get('model_offload_timeout', 15),
		'use_colors': config.get('use_colors', True),
		'uvicorn_workers': config.get('uvicorn_workers', 1),
+		'embedding_chunk_size': config.get('embedding_chunk_size', 1000),

		'disable_custom_model_download': config.get('disable_custom_model_download', False),
		'model_download_uri': config.get('model_download_uri', 'https://download.nextcloud.com/server/apps/context_chat_backend'),
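Because get_config supplies a fallback, configs written before this key existed keep working unchanged; setting embedding_chunk_size in the YAML is only needed to override 1000. A small sketch of the fallback behaviour (PyYAML is used here purely for illustration; the backend's own loader may differ):

```python
import yaml

# a pre-existing config.yaml that predates the new key
raw = yaml.safe_load('model_offload_timeout: 15\nuvicorn_workers: 1\n')

# the get_config pattern: missing keys fall back to documented defaults
embedding_chunk_size = raw.get('embedding_chunk_size', 1000)
assert embedding_chunk_size == 1000
```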
69 changes: 38 additions & 31 deletions context_chat_backend/controller.py
@@ -273,7 +273,7 @@ def _(sources: list[UploadFile]):
		return JSONResponse('Invaild/missing headers', 400)

	db: BaseVectorDB = vectordb_loader.load()
-	result = embed_sources(db, sources)
+	result = embed_sources(db, app.extra['CONFIG'], sources)
	if not result:
		return JSONResponse('Error: All sources were not loaded, check logs for more info', 500)

@@ -305,40 +305,47 @@ def at_least_one_context(cls, value: int):
		return value


+def execute_query(query: Query) -> LLMOutput:
+	# todo: migrate to Depends during db schema change
+	llm: LLM = llm_loader.load()
+
+	template = app.extra.get('LLM_TEMPLATE')
+	no_ctx_template = app.extra['LLM_NO_CTX_TEMPLATE']
+	# todo: array
+	end_separator = app.extra.get('LLM_END_SEPARATOR', '')
+
+	if query.useContext:
+		db: BaseVectorDB = vectordb_loader.load()
+		return process_context_query(
+			user_id=query.userId,
+			vectordb=db,
+			llm=llm,
+			app_config=app_config,
+			query=query.query,
+			ctx_limit=query.ctxLimit,
+			template=template,
+			end_separator=end_separator,
+			scope_type=query.scopeType,
+			scope_list=query.scopeList,
+		)
+
+	return process_query(
+		llm=llm,
+		app_config=app_config,
+		query=query.query,
+		no_ctx_template=no_ctx_template,
+		end_separator=end_separator,
+	)
+
+
@app.post('/query')
@enabled_guard(app)
def _(query: Query) -> LLMOutput:
	global llm_lock
	print('query:', query, flush=True)

+	if app_config['llm'][0] == 'nc_texttotext':
+		return execute_query(query)
+
	with llm_lock:
-		# todo: migrate to Depends during db schema change
-		llm: LLM = llm_loader.load()
-
-		template = app.extra.get('LLM_TEMPLATE')
-		no_ctx_template = app.extra['LLM_NO_CTX_TEMPLATE']
-		# todo: array
-		end_separator = app.extra.get('LLM_END_SEPARATOR', '')
-
-		if query.useContext:
-			db: BaseVectorDB = vectordb_loader.load()
-			return process_context_query(
-				user_id=query.userId,
-				vectordb=db,
-				llm=llm,
-				app_config=app_config,
-				query=query.query,
-				ctx_limit=query.ctxLimit,
-				template=template,
-				end_separator=end_separator,
-				scope_type=query.scopeType,
-				scope_list=query.scopeList,
-			)
-
-		return process_query(
-			llm=llm,
-			app_config=app_config,
-			query=query.query,
-			no_ctx_template=no_ctx_template,
-			end_separator=end_separator,
-		)
+		return execute_query(query)
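llm_lock serializes access to a locally loaded model so concurrent requests don't fight over the same VRAM; with the nc_texttotext backend, generation is delegated to Nextcloud's text2text provider, so there is nothing local to guard and the handler now skips the lock entirely — the "don't lock text2text api calls" half of the title. A hypothetical condensed form of the same dispatch using contextlib.nullcontext (not the PR's code, just an equivalent shape):

```python
import threading
from contextlib import nullcontext

llm_lock = threading.Lock()

def execute_query(query: str) -> str:
    return f'answer to {query!r}'  # stand-in for the controller's execute_query

def handle_query(query: str, llm_backend: str) -> str:
    # nc_texttotext proxies to the Nextcloud server: no local model, no VRAM
    # contention, so concurrent calls are safe and the lock is skipped.
    guard = nullcontext() if llm_backend == 'nc_texttotext' else llm_lock
    with guard:
        return execute_query(query)

print(handle_query('hello', 'nc_texttotext'))
```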