-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
193 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
package vectorstore | ||
|
||
import ( | ||
"context" | ||
"sort" | ||
|
||
"github.com/hupe1980/golc/internal/util" | ||
"github.com/hupe1980/golc/metric" | ||
"github.com/hupe1980/golc/schema" | ||
) | ||
|
||
// Compile time check to ensure InMemory satisfies the VectorStore interface. | ||
var _ schema.VectorStore = (*InMemory)(nil) | ||
|
||
// InMemoryItem represents an item stored in memory with its content, vector, and metadata. | ||
type InMemoryItem struct { | ||
Content string `json:"content"` | ||
Vector []float32 `json:"vector"` | ||
Metadata map[string]any `json:"metadata"` | ||
} | ||
|
||
// InMemoryOptions represents options for the in-memory vector store. | ||
type InMemoryOptions struct { | ||
TopK int | ||
} | ||
|
||
// InMemory represents an in-memory vector store. | ||
// Note: This implementation is intended for testing and demonstration purposes, not for production use. | ||
type InMemory struct { | ||
embedder schema.Embedder | ||
data []InMemoryItem | ||
opts InMemoryOptions | ||
} | ||
|
||
// NewInMemory creates a new instance of the in-memory vector store. | ||
func NewInMemory(embedder schema.Embedder, optFns ...func(*InMemoryOptions)) *InMemory { | ||
opts := InMemoryOptions{ | ||
TopK: 3, | ||
} | ||
|
||
for _, fn := range optFns { | ||
fn(&opts) | ||
} | ||
|
||
return &InMemory{ | ||
data: make([]InMemoryItem, 0), | ||
embedder: embedder, | ||
opts: opts, | ||
} | ||
} | ||
|
||
// AddDocuments adds a batch of documents to the InMemory vector store. | ||
func (vs *InMemory) AddDocuments(ctx context.Context, docs []schema.Document) error { | ||
texts := make([]string, len(docs)) | ||
for i, doc := range docs { | ||
texts[i] = doc.PageContent | ||
} | ||
|
||
vectors, err := vs.embedder.BatchEmbedText(ctx, texts) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
for i, doc := range docs { | ||
vs.data = append(vs.data, InMemoryItem{ | ||
Content: doc.PageContent, | ||
Vector: vectors[i], | ||
Metadata: doc.Metadata, | ||
}) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// AddItem adds a single item to the InMemory vector store. | ||
func (vs *InMemory) AddItem(item InMemoryItem) { | ||
vs.data = append(vs.data, item) | ||
} | ||
|
||
// Data returns the underlying data stored in the InMemory vector store. | ||
func (vs *InMemory) Data() []InMemoryItem { | ||
return vs.data | ||
} | ||
|
||
// SimilaritySearch performs a similarity search with the given query in the InMemory vector store. | ||
func (vs *InMemory) SimilaritySearch(ctx context.Context, query string) ([]schema.Document, error) { | ||
queryVector, err := vs.embedder.EmbedText(ctx, query) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
type searchResult struct { | ||
Item InMemoryItem | ||
Similarity float32 | ||
} | ||
|
||
results := make([]searchResult, len(vs.data)) | ||
|
||
for i, item := range vs.data { | ||
similarity := metric.CosineSimilarity(queryVector, item.Vector) | ||
results[i] = searchResult{Item: item, Similarity: similarity} | ||
} | ||
|
||
// Sort results by similarity in descending order | ||
sort.Slice(results, func(i, j int) bool { | ||
return results[i].Similarity > results[j].Similarity | ||
}) | ||
|
||
docLen := util.Min(len(results), vs.opts.TopK) | ||
|
||
// Extract documents from sorted results | ||
documents := make([]schema.Document, docLen) | ||
for i := 0; i < docLen; i++ { | ||
documents[i] = schema.Document{ | ||
PageContent: results[i].Item.Content, | ||
Metadata: results[i].Item.Metadata, | ||
} | ||
} | ||
|
||
return documents, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
package vectorstore | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
|
||
"github.com/hupe1980/golc/schema" | ||
) | ||
|
||
func TestInMemory(t *testing.T) { | ||
// Setup | ||
embedder := &mockEmbedder{} | ||
vs := NewInMemory(embedder) | ||
|
||
// Test AddDocuments method | ||
t.Run("AddDocuments", func(t *testing.T) { | ||
// Given | ||
documents := []schema.Document{ | ||
{PageContent: "document1"}, | ||
{PageContent: "document2"}, | ||
{PageContent: "document3"}, | ||
} | ||
|
||
// When | ||
err := vs.AddDocuments(context.Background(), documents) | ||
|
||
// Then | ||
assert.NoError(t, err) | ||
assert.Len(t, vs.Data(), 3) | ||
}) | ||
|
||
// Test SimilaritySearch method | ||
t.Run("SimilaritySearch", func(t *testing.T) { | ||
// Given | ||
query := "query" | ||
expectedDocuments := []schema.Document{ | ||
{PageContent: "document1"}, | ||
{PageContent: "document2"}, | ||
{PageContent: "document3"}, | ||
} | ||
|
||
// When | ||
documents, err := vs.SimilaritySearch(context.Background(), query) | ||
|
||
// Then | ||
assert.NoError(t, err) | ||
assert.Len(t, documents, 3) | ||
|
||
for i, doc := range documents { | ||
assert.Equal(t, expectedDocuments[i].PageContent, doc.PageContent) | ||
} | ||
}) | ||
} | ||
|
||
// mockEmbedder implements the schema.Embedder interface for testing purposes. | ||
type mockEmbedder struct{} | ||
|
||
func (m *mockEmbedder) BatchEmbedText(ctx context.Context, texts []string) ([][]float32, error) { | ||
// Mock implementation for batch embedding text | ||
return [][]float32{ | ||
{1.0, 2.0, 3.0}, | ||
{2.0, 3.0, 4.0}, | ||
{3.0, 4.0, 5.0}, | ||
}, nil | ||
} | ||
|
||
func (m *mockEmbedder) EmbedText(ctx context.Context, text string) ([]float32, error) { | ||
// Mock implementation for embedding text | ||
return []float32{1.0, 2.0, 3.0}, nil | ||
} |