chat.py
from langchain.llms import OpenAI
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.vectorstores import Chroma
import textwrap
import gradio
import os

os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["OPENAI_API_KEY"] = "dummy-key"
# TODO: Callbacks support token-wise streaming.
# callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
# Verbose is required to pass to the callback manager.
temperature = 0.1  # Use a value between 0 and 2. Lower = factual, higher = creative.
n_gpu_layers = 43  # Unused with the HTTP endpoint below; relevant only when loading a local model. Adjust to your GPU VRAM pool.
n_batch = 512      # Unused here as well; should be between 1 and n_ctx, considering the amount of VRAM in your GPU.

# Point the OpenAI client at a local OpenAI-compatible server (no real API key needed).
llm = OpenAI(
    openai_api_base='http://localhost:1234/v1',
    openai_api_key='dummy-key',
    temperature=temperature,
)
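
# Not part of the original script: a minimal sketch of how the TODO above could
# be acted on, assuming a langchain version whose OpenAI wrapper accepts the
# `streaming` and `callbacks` arguments. Token-wise output would then be printed
# to stdout as chunks arrive:
#
#   llm = OpenAI(
#       openai_api_base='http://localhost:1234/v1',
#       openai_api_key='dummy-key',
#       temperature=temperature,
#       streaming=True,
#       callbacks=[StreamingStdOutCallbackHandler()],
#   )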
## Follow the default prompt style from the OpenOrca-Platypus2 Hugging Face model card.
def get_prompt():
    return """Use the following Context information to answer the user's question. If you don't know the answer, just say that you don't know, don't try to make up an answer.
### Instruction:
Context: {context}
User Question: {question}
###
Response:
"""
def wrap_text_preserve_newlines(text, width=110):
    # Split the input text into lines based on newline characters
    lines = text.split('\n')
    # Wrap each line individually
    wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
    # Join the wrapped lines back together using newline characters
    wrapped_text = '\n'.join(wrapped_lines)
    return wrapped_text
def process_llm_response(llm_response):
    if not llm_response:
        return "Please enter a question"
    # Log the wrapped answer and its source documents to stdout
    print(wrap_text_preserve_newlines(llm_response['result']))
    print('\n\nSources:')
    for source in llm_response["source_documents"]:
        print(source.metadata['source'])
    # Return only the text that precedes any trailing "### Response" marker
    response = llm_response['result']
    response = response.split("### Response")[0]
    return response
def startChat():
    # Load the persisted Chroma index built with BGE embeddings
    embedding_directory = "./content/chroma_db"
    embedding_model = HuggingFaceBgeEmbeddings(model_name='BAAI/bge-base-en', model_kwargs={'device': 'cpu'})
    embedding_db = Chroma(persist_directory=embedding_directory, embedding_function=embedding_model)

    # Build the prompt template expected by the QA chain
    prompt_template = get_prompt()
    llama_prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    chain_type_kwargs = {"prompt": llama_prompt}

    # Retrieve the top 5 chunks using maximal marginal relevance (MMR)
    retriever = embedding_db.as_retriever(search_type="mmr", search_kwargs={'k': 5})

    # Create the chain to answer questions
    qa_chain = RetrievalQA.from_chain_type(llm=llm,
                                           chain_type="stuff",
                                           retriever=retriever,
                                           chain_type_kwargs=chain_type_kwargs,
                                           return_source_documents=True)

    def runChain(query, history):
        return process_llm_response(qa_chain(query))

    app = gradio.ChatInterface(runChain)
    app.queue()
    app.launch(share=False, debug=True)
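
# Assumed entry point (the original listing defines startChat() without showing a
# call site): run this module directly to launch the Gradio chat UI. This expects
# an OpenAI-compatible server listening on http://localhost:1234/v1 and a
# pre-built Chroma index under ./content/chroma_db.
if __name__ == "__main__":
    startChat()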