Add set() method and clear() method to VLite class
sdan committed Mar 31, 2024
1 parent 233a20c commit 7e44a71
Showing 6 changed files with 228 additions and 65 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/unit-tests.yml
@@ -0,0 +1,19 @@
name: Run Unit Tests

on: [push]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python 3.8
        uses: actions/setup-python@v2
        with:
          python-version: '3.8'
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          # Add any other dependencies here
      - name: Run tests
        run: python ./tests/unit.py
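
Note that, as committed, the "Install dependencies" step only upgrades pip, while tests/unit.py imports numpy and matplotlib and tests/tokengen.py imports tiktoken and random_word. A hedged sketch of what that step might need to install (the PyPI package names are inferred from the imports, not confirmed by this commit):

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          # Assumed package names, inferred from the test imports:
          pip install numpy matplotlib tiktoken random-word
          pip install .  # install vlite itself from the repo root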
1 change: 1 addition & 0 deletions tests/data/text-8192tokens.txt

Large diffs are not rendered by default.

58 changes: 58 additions & 0 deletions tests/tokengen.py
@@ -0,0 +1,58 @@
from random_word import RandomWords
import tiktoken

def num_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int:
    encoding = tiktoken.get_encoding(encoding_name)
    num_tokens = len(encoding.encode(string))
    return num_tokens

def generate_string_of_length(target_tokens: int) -> str:
    r = RandomWords()
    generated_string = ""
    current_tokens = 0

    while current_tokens < target_tokens:
        word = r.get_random_word()
        word_tokens = num_tokens_from_string(word)

        if current_tokens + word_tokens <= target_tokens:
            generated_string += word + " "
            current_tokens += word_tokens + 1  # Add 1 for the space
        else:
            break

    # Remove the trailing space
    generated_string = generated_string.strip()

    # If the token count is less than the target, append words one by one
    while current_tokens < target_tokens:
        word = r.get_random_word()
        word_tokens = num_tokens_from_string(word)

        if current_tokens + word_tokens <= target_tokens:
            generated_string += " " + word
            current_tokens += word_tokens + 1  # Add 1 for the space
        else:
            break

    # If the token count is greater than the target, remove words one by one
    while current_tokens > target_tokens:
        words = generated_string.split()
        last_word = words.pop()
        last_word_tokens = num_tokens_from_string(last_word)
        current_tokens -= last_word_tokens + 1  # Subtract 1 for the space
        generated_string = " ".join(words)

    return generated_string

# Generate a string of 512 tokens
string_512_tokens = generate_string_of_length(512)
print(f"String of 512 tokens:\n{string_512_tokens}")
print(f"Actual token count: {num_tokens_from_string(string_512_tokens)}")

print("\n" + "-" * 50 + "\n")

# Generate a string of 8192 tokens
string_8192_tokens = generate_string_of_length(8192)
print(f"String of 8192 tokens:\n{string_8192_tokens}")
print(f"Actual token count: {num_tokens_from_string(string_8192_tokens)}")
63 changes: 0 additions & 63 deletions tests/unit-test.py

This file was deleted.

103 changes: 103 additions & 0 deletions tests/unit.py
@@ -0,0 +1,103 @@
import unittest
import numpy as np
from vlite.main import VLite
import os
from vlite.utils import process_pdf
import cProfile
from pstats import Stats
import matplotlib.pyplot as plt
import time

class TestVLite(unittest.TestCase):
    test_times = {}

    def setUp(self):
        self.vlite = VLite("vlite-unit")

    def tearDown(self):
        # Remove the collection file
        if os.path.exists('vlite-unit'):
            print("[+] Removing vlite")
            os.remove('vlite-unit')

    def test_add_text(self):
        start_time = time.time()
        text = "This is a test text."
        metadata = {"source": "test"}
        self.vlite.add(text, metadata=metadata)
        self.assertEqual(self.vlite.count(), 1)
        end_time = time.time()
        TestVLite.test_times["add_single_text"] = end_time - start_time
        print(f"Count of texts in the collection: {self.vlite.count()}")

    def test_add_texts(self):
        start_time = time.time()
        text_512tokens = "underreckoning fleckiness hairstane paradigmatic eligibility sublevate xviii achylia reremice flung outpurl questing gilia unosmotic unsuckled plecopterid excludable phenazine fricando unfledgedness spiritsome incircle desmogenous subclavate redbug semihoral district chrysocolla protocoled servius readings propolises javali dujan stickman attendee hambone obtusipennate tightropes monitorially signaletics diestrums preassigning spriggy yestermorning margaritic tankfuls aseptify linearity hilasmic twinning tokonoma seminormalness cerebrospinant refroid doghouse kochab dacryocystalgia saltbushes newcomer provoker berberid platycoria overpersuaded reoverflow constrainable headless forgivably syzygal purled reese polyglottonic decennary embronze pluripotent equivocally myoblasts thymelaeaceous confervae perverted preanticipate mammalogical desalinizing tackets misappearance subflexuose concludence effluviums runtish gras cuckolded hemostasia coatroom chelidon policizer trichinised frontstall impositions unta outrance scholium fibrochondritis furcates fleaweed housefront helipads hemachate snift appellativeness knobwood superinclination tsures haberdasheries unparliamented reexecution nontangential waddied desolated subdistinctively undiscernibleness swishiest dextral progs koprino bruisingly unloanably bardash uncuckoldedunderreckoning fleckiness hairstane paradigmatic eligibility sublevate xviii achylia reremice flung outpurl questing gilia unosmotic unsuckled plecopterid excludable phenazine fricando unfledgedness spiritsome incircle desmogenous subclavate redbug semihoral district chrysocolla spriggy yestermorning margaritic tankfuls aseptify linearity hilasmic twinning tokonoma seminormalness cerebrospinant refroequivocally myoblasts thymelaeaceous confervae perverted preantiest dextral progs koprino bruisingly unloanably bardash uncuckolded"
        metadata = {"source": "test_512tokens"}
        self.vlite.add(text_512tokens, metadata=metadata)
        with open("data/text-8192tokens.txt", "r") as file:
            text_8192tokens = file.read()
        metadata = {"source": "test_8192tokens"}
        self.vlite.add(text_8192tokens, metadata=metadata)
        end_time = time.time()
        TestVLite.test_times["add_multiple_texts"] = end_time - start_time
        print(f"Count of texts in the collection: {self.vlite.count()}")

    def test_add_pdf(self):
        start_time = time.time()
        process_pdf('data/gpt-4.pdf')
        end_time = time.time()
        TestVLite.test_times["add_pdf"] = end_time - start_time
        # Time to add 71067 tokens from the GPT-4 paper
        print(f"Time to add 71067 tokens: {TestVLite.test_times['add_pdf']} seconds")

    def test_retrieve(self):
        queries = [
            "What is the architecture of GPT-4?",
            "How does GPT-4 handle contextual understanding?",
            "What are the key improvements in GPT-4 over GPT-3?",
            "How many parameters does GPT-4 have?",
            "What are the limitations of GPT-4?",
            "What datasets were used to train GPT-4?",
            "How does GPT-4 handle longer context?",
            "What is the computational requirement for training GPT-4?",
            "What techniques were used to train GPT-4?",
            "What is the impact of GPT-4 on natural language processing?",
            "What are the use cases demonstrated in the GPT-4 paper?",
            "What are the evaluation metrics used in GPT-4's paper?",
            "What kind of ethical considerations are discussed in the GPT-4 paper?",
            "How does GPT-4 handle tokenization?",
            "What are the novel contributions of the GPT-4 model?"
        ]
        process_pdf('data/gpt-4.pdf')
        start_time = time.time()
        for query in queries:
            _, top_sims, _ = self.vlite.retrieve(query)
            print(f"Top similarities for query '{query}': {top_sims}")
        end_time = time.time()
        TestVLite.test_times["retrieve"] = end_time - start_time

    def test_delete(self):
        self.vlite.add("This is a test text.")
        start_time = time.time()
        self.vlite.delete(0)
        end_time = time.time()
        TestVLite.test_times["delete"] = end_time - start_time
        print(f"Count of texts in the collection: {self.vlite.count()}")

    def test_update(self):
        self.vlite.add("This is a test text.")
        start_time = time.time()
        self.vlite.update(0, "This is an updated text.")
        end_time = time.time()
        TestVLite.test_times["update"] = end_time - start_time
        print(f"Count of texts in the collection: {self.vlite.count()}")

    @classmethod
    def tearDownClass(cls):
        print("\nTest times:")
        for test_name, test_time in cls.test_times.items():
            print(f"{test_name}: {test_time:.4f} seconds")

if __name__ == '__main__':
    unittest.main(verbosity=2)
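
matplotlib, cProfile, and pstats are imported above but never used; presumably they were intended for profiling and charting the collected timings. A hedged sketch of what charting test_times after a run might look like (the chart type and output filename are assumptions):

# Sketch: bar-chart the timings collected in TestVLite.test_times (assumed usage).
import matplotlib.pyplot as plt

times = TestVLite.test_times
plt.bar(list(times.keys()), list(times.values()))
plt.ylabel("seconds")
plt.title("VLite unit test timings")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.savefig("test-times.png")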
49 changes: 47 additions & 2 deletions vlite/main.py
@@ -154,7 +154,31 @@ def get(self, ids=None, where=None):
            return [(self.texts[idx], self.metadata[idx]) for idx in ids if idx in self.metadata]
        else:
            return [(self.texts[idx], self.metadata[idx]) for idx in ids if idx in self.metadata and all(self.metadata[idx].get(k) == v for k, v in where.items())]

    def set(self, id, text=None, metadata=None, vector=None):
        """
        Updates the attributes of an item in the collection by ID.

        Args:
            id (str): ID of the item to update.
            text (str, optional): Updated text content of the item.
            metadata (dict, optional): Updated metadata of the item.
            vector (numpy.ndarray, optional): Updated embedding vector of the item.
        """
        print(f"Setting attributes for item with ID: {id}")
        idx = next((i for i, x in enumerate(self.metadata) if x.get('index') == id), None)
        if idx is not None:
            if text is not None:
                self.texts[idx] = text
            if metadata is not None:
                self.metadata[idx].update(metadata)
            if vector is not None:
                self.vectors[idx] = vector
            self.save()
        else:
            print(f"Item with ID {id} not found.")

    def count(self):
        """
        Returns the number of items in the collection.
@@ -171,4 +195,25 @@ def save(self):
print(f"Saving collection to {self.collection}")
with open(self.collection, 'wb') as f:
np.savez(f, texts=self.texts, metadata=self.metadata, vectors=self.vectors)
print("Collection saved successfully.")
print("Collection saved successfully.")

    def clear(self):
        """
        Clears the entire collection, removing all items and resetting the attributes.
        """
        print("Clearing the collection...")
        self.texts = []
        self.metadata = {}
        self.vectors = np.empty((0, self.model.dimension))
        self.save()
        print("Collection cleared.")

    def info(self):
        """
        Prints information about the collection, including the number of items, collection file path,
        and the embedding model used.
        """
        print("Collection Information:")
        print(f"  Items: {self.count()}")
        print(f"  Collection file: {self.collection}")
        print(f"  Embedding model: {self.model}")

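Taken together, a hedged usage sketch of the new surface (the VLite constructor and add() call follow tests/unit.py; set() assumes each item carries an 'index' key in its metadata, per the lookup in the diff above):

from vlite.main import VLite

vlite = VLite("vlite-demo")
vlite.add("This is a test text.", metadata={"source": "demo"})
vlite.set(0, text="This is an updated text.")  # assumes metadata['index'] == 0 for the first item
vlite.info()   # prints item count, collection file, and embedding model
vlite.clear()  # resets texts, metadata, and vectors, then saves
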