Skip to content

Commit

Permalink
tons of improvements
Browse files (browse the repository at this point in the history)
  • Loading branch information
fuziontech committed Jun 9, 2023
1 parent 1d36753 commit 8c48815
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 18 deletions.
11 changes: 6 additions & 5 deletions ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,18 @@ async def ai_chat_thread(thread):
[
{"page_content": doc.page_content, "metadata": doc.metadata}
for doc in documents
]
],
indent=2,
)

print(json_docs)

SYSTEM_PROMPT = """
You are the trusty PostHog support AI named Max. You are also PostHog's Mascot!
Please continue the conversation in a way that is helpful to the user and also makes the user feel like they are talking to a human.
Only suggest using PostHog products or services. Do not suggest products or services from other companies.
Only suggest using PostHog and ClickHouse products or services. Do not suggest products or services from other companies.
Please answer the question according to the following context.
Do not create links. Only reference the source from the metadata.source object in the context and prefix it with "https://github.com/PostHog/posthog.com/tree/master".
Do not create links. Only reference the source from the metadata.source object in the context.
If you get a question about pricing please refer to the reasonable and transparent pricing on the pricing page at https://posthog.com/pricing.
If you are unsure of the answer, please say "I'm not sure" and encourage the user to ask PostHog staff.
Try not to mention <@*> in the response.
Expand All @@ -73,8 +76,6 @@ async def ai_chat_thread(thread):
*follow_up_thread,
]

print(prompt)

completion = openai.ChatCompletion.create(model=OPENAI_MODEL, messages=prompt)

completion = completion.choices[0].message.content
Expand Down
18 changes: 10 additions & 8 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
from typing import List
from typing import List, Optional


import sentry_sdk
from dotenv import load_dotenv
Expand Down Expand Up @@ -47,10 +48,15 @@ class Message(BaseModel):
role: str
content: str


class Query(BaseModel):
query: str


class GitHubRepo(BaseModel):
repo: Optional[str]


pipeline = MaxPipeline(openai_token=os.getenv("OPENAI_TOKEN"))


Expand All @@ -61,11 +67,8 @@ def create_entries(entries: Entries):


@app.post("/_git")
def create_git_entries(repo_url: str):
print("git")
if not repo_url:
repo_url = "https://github.com/posthog/posthog.com"
pipeline.embed_git_repo(repo_url=repo_url)
def create_git_entries(gh_repo: GitHubRepo):
pipeline.embed_git_repo(gh_repo=gh_repo.repo)
return {"status": "ok"}


Expand All @@ -87,12 +90,11 @@ def receive_spawn():

@app.post("/update")
def update_oncall():
return "nope"
return "nope"


@app.post("/chat")
async def chat(messages: List[Message]):
print(messages)
msgs = [msg.dict() for msg in messages]
response = await ai_chat_thread(msgs)
return response
Expand Down
15 changes: 10 additions & 5 deletions pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ def embed_documents(self, documents: List[Document]):
self.document_store.add_documents(documents)

def retrieve_context(self, query: str):
result = self.retriever.get_relevant_documents(query)
return result
return self.retriever.get_relevant_documents(query)

def chat(self, query: str):
chain = RetrievalQAWithSourcesChain.from_chain_type(
Expand All @@ -88,8 +87,9 @@ def chat(self, query: str):
)
return results

def embed_git_repo(self, repo_url):
repo_dir = repo_url.split("/")[-1].replace(".git", "")
def embed_git_repo(self, gh_repo):
repo_url = f"https://github.com/{gh_repo}.git"
repo_dir = gh_repo.split("/")[-1]
path = os.path.join(EXAMPLE_DATA_DIR, repo_dir)
if not os.path.exists(path):
print("Repo not found, cloning...")
Expand All @@ -106,14 +106,19 @@ def embed_git_repo(self, repo_url):
loader = GitLoader(
repo_path=path,
branch=branch,
file_filter=lambda file_path: file_path.endswith(".md"),
file_filter=lambda file_path: file_path.endswith((".md", ".mdx")),
)
data = loader.load()
for page in data:
docs = []
text = self.splitter.split_text(page.page_content)
metadata = page.metadata
print(f"Adding {page.metadata['source']}")
page.metadata[
"source"
] = f"https://github.com/{gh_repo}/blob/master/{page.metadata['source']} "
for token in text:
docs.append(Document(page_content=token, metadata=metadata))
self.document_store.add_documents(docs)
print("Done")
return

0 comments on commit 8c48815

Please sign in to comment.