feat: added basic context generation and retrieval
sauravpanda committed Aug 17, 2024
1 parent ad403d9 commit 9584fd2
Showing 11 changed files with 438 additions and 212 deletions.
25 changes: 19 additions & 6 deletions db_setup/init.sql
@@ -1,13 +1,13 @@
-- Enable vector extension
CREATE EXTENSION IF NOT EXISTS vector;

-- Table to store repository information
CREATE TABLE repositories (
    repo_id SERIAL PRIMARY KEY,
    repo_name TEXT NOT NULL,
    repo_owner TEXT NOT NULL,
    repo_url TEXT NOT NULL,
-     repo_description TEXT
+     repo_description TEXT,
+     CONSTRAINT unique_repo UNIQUE (repo_name, repo_owner)
);

-- Table to store file information
@@ -17,7 +17,8 @@ CREATE TABLE files (
    file_path TEXT NOT NULL,
    file_name TEXT NOT NULL,
    file_ext TEXT NOT NULL,
-     programming_language TEXT
+     programming_language TEXT,
+     CONSTRAINT unique_repo_file UNIQUE (repo_id, file_path)
);

-- Table to store function abstractions
@@ -37,10 +38,10 @@ CREATE TABLE function_abstractions (
CREATE TABLE function_embeddings (
    embedding_id SERIAL PRIMARY KEY,
    function_id INTEGER NOT NULL REFERENCES function_abstractions(function_id),
-     vector VECTOR(1536) NOT NULL
+     vector VECTOR(1536) NOT NULL,
+     CONSTRAINT unique_function_embedding UNIQUE (function_id)
);


CREATE TABLE syntax_nodes (
    node_id SERIAL PRIMARY KEY,
    file_id INTEGER NOT NULL REFERENCES files(file_id),
@@ -65,4 +66,16 @@ CREATE TABLE node_properties (
    node_id INTEGER NOT NULL REFERENCES syntax_nodes(node_id),
    property_name TEXT NOT NULL,
    property_value TEXT NOT NULL
- );
+ );
+
+ -- Create an index on the file_path column for faster lookups
+ CREATE INDEX idx_file_path ON files(file_path);
+
+ -- Create an index on the function_name column for faster lookups
+ CREATE INDEX idx_function_name ON function_abstractions(function_name);
+
+ -- Create an index on the node_type column for faster lookups
+ CREATE INDEX idx_node_type ON syntax_nodes(node_type);
+
+ -- Create an index on the vector column for faster similarity searches
+ CREATE INDEX idx_function_embeddings_vector ON function_embeddings USING ivfflat (vector vector_l2_ops);
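
For a quick sense of how the new schema and ivfflat index are meant to be queried, here is a minimal smoke-test sketch in Python. psycopg2 and the pgvector extension are assumed; the connection settings and the flat 1536-dimensional query vector are placeholders, not values from this commit.

# Hypothetical smoke test for the schema above; connection details are placeholders.
import psycopg2

conn = psycopg2.connect(host="localhost", dbname="kaizen", user="postgres", password="postgres")
with conn, conn.cursor() as cur:
    # pgvector accepts a bracketed text literal cast to ::vector.
    vec_literal = "[" + ",".join(["0.1"] * 1536) + "]"
    cur.execute(
        """
        SELECT fa.function_name, fe.vector <-> %s::vector AS l2_distance
        FROM function_embeddings fe
        JOIN function_abstractions fa ON fa.function_id = fe.function_id
        ORDER BY l2_distance
        LIMIT 5
        """,
        (vec_literal,),
    )
    for name, distance in cur.fetchall():
        print(name, distance)  # <-> is L2 distance, matching vector_l2_ops
conn.close()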
13 changes: 12 additions & 1 deletion docker-compose-dev.yml
@@ -12,6 +12,8 @@ services:
    restart: always
    secrets:
      - github_app_pem
+     networks:
+       - app-network

  redis:
    image: "redis:alpine"
@@ -20,8 +22,11 @@
      - REDIS_PASSWORD=${REDIS_PASSWORD}
    ports:
      - "6379:6379"
+     networks:
+       - app-network

  postgres:
+     container_name: postgres
    build:
      context: .
      dockerfile: Dockerfile-postgres
@@ -31,7 +36,13 @@
      - ./db_setup/init.sql:/docker-entrypoint-initdb.d/init.sql
    ports:
      - "5432:5432"
+     networks:
+       - app-network

secrets:
  github_app_pem:
-     file: ./GITHUB_APP_NIGHTLY.pem
+     file: ./GITHUB_APP_NIGHTLY.pem
+
+ networks:
+   app-network:
+     driver: bridge
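
With all three services attached to the same app-network bridge, each container can reach the others by service name. A minimal connectivity sketch, assuming psycopg2 and redis-py are available in the app container and that the usual POSTGRES_*/REDIS_PASSWORD variables are set (the defaults below are illustrative):

# Hypothetical connectivity check from a container on app-network.
import os

import psycopg2
import redis

pg = psycopg2.connect(
    host="postgres",  # the compose service name doubles as the hostname
    dbname=os.environ.get("POSTGRES_DB", "postgres"),
    user=os.environ.get("POSTGRES_USER", "postgres"),
    password=os.environ.get("POSTGRES_PASSWORD", ""),
)
r = redis.Redis(host="redis", port=6379, password=os.environ.get("REDIS_PASSWORD"))
print("postgres ok:", pg.closed == 0, "| redis ok:", r.ping())
pg.close()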
4 changes: 2 additions & 2 deletions examples/ragify_codebase/main.py
@@ -4,14 +4,14 @@
analyzer = RepositoryAnalyzer()

# Set up the repository (do this when you first analyze a repo or when you want to update it)
- analyzer.setup_repository("./github_app/")
+ # analyzer.setup_repository("./github_app/")

# Perform queries (you can do this as many times as you want without calling setup_repository again)
results = analyzer.query("Find functions that handle authentication")
for result in results:
    print(f"File: {result['file_path']}")
    print(f"Abstraction: {result['abstraction']}")
-     print(f"Code:\n{result['code']}")
+     print(f"result:\n{result}")
    print(f"Relevance Score: {result['relevance_score']}")
    print("---")
8 changes: 2 additions & 6 deletions install_tree_sitter_languages.sh
@@ -30,12 +30,8 @@ install_language() {
    # Update submodules
    git submodule update --init

-     # Compile the parser
-     cc -fPIC -c -I./src src/parser.c
-     cc -shared *.o -o "../$lang.so"
-
-     # Clean up object files
-     rm *.o
+     # Build the parser using tree-sitter CLI
+     tree-sitter generate

    # Navigate back to the original directory
    cd ../..
Expand Down
2 changes: 1 addition & 1 deletion kaizen/llms/provider.py
@@ -238,6 +238,6 @@ def get_text_embedding(self, text):
        # if model["model_name"] == "embedding":
        #     break
        response = self.provider.embedding(
-             model="embedding", input=[text], dimensions=1536
+             model="embedding", input=[text], dimensions=1536, encoding_format="float"
        )
        return response["data"], response["usage"]
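
The added encoding_format="float" asks the backend to return raw float arrays rather than base64-encoded vectors. A hedged usage sketch of the method above; the LLMProvider class name and no-argument construction are assumptions based on this file's path, and the OpenAI-style response shape is inferred from the dimensions=1536 call:

# Assumed usage; construction details may differ in the actual codebase.
from kaizen.llms.provider import LLMProvider

provider = LLMProvider()
data, usage = provider.get_text_embedding("def authenticate(user, password): ...")
embedding = data[0]["embedding"]  # expected: a list of 1536 floats
print(len(embedding), usage)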
94 changes: 5 additions & 89 deletions kaizen/retriever/code_chunker.py
@@ -1,95 +1,11 @@
- import os
- import subprocess
- from tree_sitter import Language, Parser
- from typing import Dict, List, Any
+ from typing import Dict, Any
+ from kaizen.retriever.tree_sitter_utils import parse_code, ParserFactory

ParsedBody = Dict[str, Dict[str, Any]]

- # Define the languages and their GitHub repositories
- LANGUAGES = {
-     "python": "https://github.com/tree-sitter/tree-sitter-python",
-     "javascript": "https://github.com/tree-sitter/tree-sitter-javascript",
-     "typescript": "https://github.com/tree-sitter/tree-sitter-typescript",
-     "rust": "https://github.com/tree-sitter/tree-sitter-rust",
- }
-
- # Directory to store the language libraries
- LANGUAGE_DIR = os.path.join(os.path.dirname(__file__), "tree_sitter_languages")
-
-
- def ensure_language_installed(language: str) -> None:
-     if not os.path.exists(LANGUAGE_DIR):
-         os.makedirs(LANGUAGE_DIR)
-
-     lang_file = os.path.join(LANGUAGE_DIR, f"{language}.so")
-     if not os.path.exists(lang_file):
-         repo_url = LANGUAGES[language]
-         repo_dir = os.path.join(LANGUAGE_DIR, f"tree-sitter-{language}")
-
-         if not os.path.exists(repo_dir):
-             subprocess.run(["git", "clone", repo_url, repo_dir], check=True)
-
-         subprocess.run(
-             ["bash", "-c", f"cd {repo_dir} && git submodule update --init"], check=True
-         )
-         Language.build_library(lang_file, [repo_dir])
-
-
- def get_parser(language: str) -> Parser:
-     ensure_language_installed(language)
-     parser = Parser()
-     lang_file = os.path.join(LANGUAGE_DIR, f"{language}.so")
-     lang = Language(lang_file, language)
-     parser.set_language(lang)
-     return parser
-
-
- def traverse_tree(node, code_bytes: bytes) -> Dict[str, Any]:
-     if node.type in [
-         "function_definition",
-         "function_declaration",
-         "arrow_function",
-         "method_definition",
-     ]:
-         return {
-             "type": "function",
-             "name": (
-                 node.child_by_field_name("name").text.decode("utf8")
-                 if node.child_by_field_name("name")
-                 else "anonymous"
-             ),
-             "code": code_bytes[node.start_byte : node.end_byte].decode("utf8"),
-         }
-     elif node.type in ["class_definition", "class_declaration"]:
-         return {
-             "type": "class",
-             "name": node.child_by_field_name("name").text.decode("utf8"),
-             "code": code_bytes[node.start_byte : node.end_byte].decode("utf8"),
-         }
-     elif node.type in ["jsx_element", "jsx_self_closing_element"]:
-         return {
-             "type": "component",
-             "name": (
-                 node.child_by_field_name("opening_element")
-                 .child_by_field_name("name")
-                 .text.decode("utf8")
-                 if node.type == "jsx_element"
-                 else node.child_by_field_name("name").text.decode("utf8")
-             ),
-             "code": code_bytes[node.start_byte : node.end_byte].decode("utf8"),
-         }
-     elif node.type == "impl_item":
-         return {
-             "type": "impl",
-             "name": node.child_by_field_name("type").text.decode("utf8"),
-             "code": code_bytes[node.start_byte : node.end_byte].decode("utf8"),
-         }
-     else:
-         return None
-
-
def chunk_code(code: str, language: str) -> ParsedBody:
-     parser = get_parser(language)
+     parser = ParserFactory.get_parser(language)
    tree = parser.parse(code.encode("utf8"))

    body: ParsedBody = {
@@ -99,10 +15,10 @@ def chunk_code(code: str, language: str) -> ParsedBody:
        "components": {},
        "other_blocks": [],
    }
-     code_bytes = code.encode("utf8")
+     # code_bytes = code.encode("utf8")

    def process_node(node):
-         result = traverse_tree(node, code_bytes)
+         result = parse_code(code, language)
        if result:
            if result["type"] == "function":
                if is_react_hook(result["name"]):
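
A short usage sketch of the slimmed-down chunker. Parsing is now delegated to parse_code and ParserFactory in kaizen.retriever.tree_sitter_utils, which this diff does not show, so the exact shape of each stored chunk is assumed here:

# Illustrative call; assumes the relevant tree-sitter grammar is available to ParserFactory.
from kaizen.retriever.code_chunker import chunk_code

source = '''
def greet(name):
    return f"Hello, {name}!"
'''

body = chunk_code(source, "python")
for name, chunk in body["functions"].items():
    print(name, "->", chunk)  # chunk shape comes from parse_code, not shown in this diff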
52 changes: 52 additions & 0 deletions kaizen/retriever/custom_vector_store.py
@@ -0,0 +1,52 @@
from llama_index.vector_stores.postgres import PGVectorStore
from typing import List
import numpy as np
from psycopg2.extras import Json


class CustomPGVectorStore(PGVectorStore):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Store the table name in a new attribute
        self.table_name = kwargs.get('table_name', 'embeddings')

    def custom_query(self, query_embedding: List[float], repo_id: int, similarity_top_k: int) -> List[dict]:
        # Normalize the query embedding
        query_embedding_np = np.array(query_embedding)
        query_embedding_normalized = query_embedding_np / np.linalg.norm(query_embedding_np)

        # SQL query with repo_id filter and cosine similarity
        query = f"""
        SELECT
            e.node_id,
            e.text,
            e.metadata,
            1 - (e.embedding <=> %s::vector) as similarity
        FROM
            {self.table_name} e
        JOIN
            function_abstractions fa ON e.node_id = fa.function_id::text
        JOIN
            files f ON fa.file_id = f.file_id
        WHERE
            f.repo_id = %s
        ORDER BY
            similarity DESC
        LIMIT
            %s
        """

        with self.get_client() as client:
            with client.cursor() as cur:
                cur.execute(query, (query_embedding_normalized.tolist(), repo_id, similarity_top_k))
                results = cur.fetchall()

        return [
            {
                "id": row[0],
                "text": row[1],
                "metadata": row[2] if isinstance(row[2], dict) else Json(row[2]),
                "similarity": row[3]
            }
            for row in results
        ]
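
A hedged usage sketch of the new store. from_params is inherited from llama-index's PGVectorStore; the connection values, repo_id, and the flat query embedding are placeholders rather than values from this commit:

# Illustrative only; real settings and embeddings come from the application config.
from kaizen.retriever.custom_vector_store import CustomPGVectorStore

store = CustomPGVectorStore.from_params(
    database="kaizen",
    host="localhost",
    password="postgres",
    port=5432,
    user="postgres",
    table_name="embeddings",
    embed_dim=1536,
)
hits = store.custom_query(query_embedding=[0.1] * 1536, repo_id=1, similarity_top_k=5)
for hit in hits:
    print(hit["id"], round(hit["similarity"], 3), hit["text"][:60])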