Skip to content

Commit

Permalink
Add in memory vectorstore
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Mar 9, 2024
1 parent f84e83d commit b0f1fd4
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 0 deletions.
121 changes: 121 additions & 0 deletions vectorstore/in_memory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package vectorstore

import (
"context"
"sort"

"github.com/hupe1980/golc/internal/util"
"github.com/hupe1980/golc/metric"
"github.com/hupe1980/golc/schema"
)

// Compile time check to ensure InMemory satisfies the VectorStore interface.
var _ schema.VectorStore = (*InMemory)(nil)

// InMemoryItem represents an item stored in memory with its content, vector, and metadata.
type InMemoryItem struct {
Content string `json:"content"`
Vector []float32 `json:"vector"`
Metadata map[string]any `json:"metadata"`
}

// InMemoryOptions represents options for the in-memory vector store.
type InMemoryOptions struct {
TopK int
}

// InMemory represents an in-memory vector store.
// Note: This implementation is intended for testing and demonstration purposes, not for production use.
type InMemory struct {
embedder schema.Embedder
data []InMemoryItem
opts InMemoryOptions
}

// NewInMemory creates a new instance of the in-memory vector store.
func NewInMemory(embedder schema.Embedder, optFns ...func(*InMemoryOptions)) *InMemory {
opts := InMemoryOptions{
TopK: 3,
}

for _, fn := range optFns {
fn(&opts)
}

return &InMemory{
data: make([]InMemoryItem, 0),
embedder: embedder,
opts: opts,
}
}

// AddDocuments adds a batch of documents to the InMemory vector store.
func (vs *InMemory) AddDocuments(ctx context.Context, docs []schema.Document) error {
texts := make([]string, len(docs))
for i, doc := range docs {
texts[i] = doc.PageContent
}

vectors, err := vs.embedder.BatchEmbedText(ctx, texts)
if err != nil {
return err
}

for i, doc := range docs {
vs.data = append(vs.data, InMemoryItem{
Content: doc.PageContent,
Vector: vectors[i],
Metadata: doc.Metadata,
})
}

return nil
}

// AddItem adds a single item to the InMemory vector store.
func (vs *InMemory) AddItem(item InMemoryItem) {
vs.data = append(vs.data, item)
}

// Data returns the underlying data stored in the InMemory vector store.
func (vs *InMemory) Data() []InMemoryItem {
return vs.data
}

// SimilaritySearch performs a similarity search with the given query in the InMemory vector store.
func (vs *InMemory) SimilaritySearch(ctx context.Context, query string) ([]schema.Document, error) {
queryVector, err := vs.embedder.EmbedText(ctx, query)
if err != nil {
return nil, err
}

type searchResult struct {
Item InMemoryItem
Similarity float32
}

results := make([]searchResult, len(vs.data))

for i, item := range vs.data {
similarity := metric.CosineSimilarity(queryVector, item.Vector)
results[i] = searchResult{Item: item, Similarity: similarity}
}

// Sort results by similarity in descending order
sort.Slice(results, func(i, j int) bool {
return results[i].Similarity > results[j].Similarity
})

docLen := util.Min(len(results), vs.opts.TopK)

// Extract documents from sorted results
documents := make([]schema.Document, docLen)
for i := 0; i < docLen; i++ {
documents[i] = schema.Document{
PageContent: results[i].Item.Content,
Metadata: results[i].Item.Metadata,
}
}

return documents, nil
}
72 changes: 72 additions & 0 deletions vectorstore/in_memory_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package vectorstore

import (
"context"
"testing"

"github.com/stretchr/testify/assert"

"github.com/hupe1980/golc/schema"
)

func TestInMemory(t *testing.T) {
// Setup
embedder := &mockEmbedder{}
vs := NewInMemory(embedder)

// Test AddDocuments method
t.Run("AddDocuments", func(t *testing.T) {
// Given
documents := []schema.Document{
{PageContent: "document1"},
{PageContent: "document2"},
{PageContent: "document3"},
}

// When
err := vs.AddDocuments(context.Background(), documents)

// Then
assert.NoError(t, err)
assert.Len(t, vs.Data(), 3)
})

// Test SimilaritySearch method
t.Run("SimilaritySearch", func(t *testing.T) {
// Given
query := "query"
expectedDocuments := []schema.Document{
{PageContent: "document1"},
{PageContent: "document2"},
{PageContent: "document3"},
}

// When
documents, err := vs.SimilaritySearch(context.Background(), query)

// Then
assert.NoError(t, err)
assert.Len(t, documents, 3)

for i, doc := range documents {
assert.Equal(t, expectedDocuments[i].PageContent, doc.PageContent)
}
})
}

// mockEmbedder implements the schema.Embedder interface for testing purposes.
type mockEmbedder struct{}

func (m *mockEmbedder) BatchEmbedText(ctx context.Context, texts []string) ([][]float32, error) {
// Mock implementation for batch embedding text
return [][]float32{
{1.0, 2.0, 3.0},
{2.0, 3.0, 4.0},
{3.0, 4.0, 5.0},
}, nil
}

func (m *mockEmbedder) EmbedText(ctx context.Context, text string) ([]float32, error) {
// Mock implementation for embedding text
return []float32{1.0, 2.0, 3.0}, nil
}

0 comments on commit b0f1fd4

Please sign in to comment.