Skip to content

Commit

Permalink
fix: defer cleanup of test
Browse files Browse the repository at this point in the history
  • Loading branch information
cabljac committed Sep 26, 2024
1 parent 2b263d2 commit 521a059
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 27 deletions.
24 changes: 18 additions & 6 deletions go/plugins/firebase/retriever.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,22 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package firebase

import (
"context"
"fmt"
"log"

"cloud.google.com/go/firestore"
"github.com/firebase/genkit/go/ai"
Expand Down Expand Up @@ -35,11 +48,11 @@ func DefineFirestoreRetriever(cfg RetrieverOptions) (ai.Retriever, error) {
if cfg.VectorType != Vector64 {
return nil, fmt.Errorf("DefineFirestoreRetriever: only Vector64 is supported")
}
if cfg.Client == nil {
return nil, fmt.Errorf("DefineFirestoreRetriever: Firestore client is not provided")
}

Retrieve := func(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) {
if cfg.Client == nil {
return nil, fmt.Errorf("DefineFirestoreRetriever: Firestore client is not provided")
}

if req.Document == nil {
return nil, fmt.Errorf("DefineFirestoreRetriever: Request document is nil")
Expand Down Expand Up @@ -89,8 +102,7 @@ func DefineFirestoreRetriever(cfg RetrieverOptions) (ai.Retriever, error) {
// Ensure content field exists and is of type string
content, ok := data[cfg.ContentField].(string)
if !ok {
// TODO: use genkit logger
log.Printf("Content field %s missing or not a string in document %s", cfg.ContentField, result.Ref.ID)
fmt.Printf("Content field %s missing or not a string in document %s", cfg.ContentField, result.Ref.ID)
continue
}

Expand Down
90 changes: 69 additions & 21 deletions go/plugins/firebase/retriever_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package firebase

import (
"context"
"flag"
"log"
"testing"

"cloud.google.com/go/firestore"
firebasev4 "firebase.google.com/go/v4"
"github.com/firebase/genkit/go/ai"
"google.golang.org/api/iterator"
)

var (
Expand All @@ -27,15 +41,32 @@ func (e *MockEmbedder) Name() string {
func (e *MockEmbedder) Embed(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) {
var embeddings []*ai.DocumentEmbedding
for _, doc := range req.Documents {
// For simplicity, all embeddings are [1.0, 0.0, 0.0]
embedding := []float32{1.0, 0.0, 0.0}
var embedding []float32
switch doc.Content[0].Text {
case "This is document one":
// Embedding for document one is the closest to the query
embedding = []float32{0.9, 0.1, 0.0}
case "This is document two":
// Embedding for document two is less close to the query
embedding = []float32{0.7, 0.2, 0.1}
case "This is document three":
// Embedding for document three is even further from the query
embedding = []float32{0.4, 0.3, 0.3}
case "This is input query":
// Embedding for the input query
embedding = []float32{0.9, 0.1, 0.0}
default:
// Default embedding for any other documents
embedding = []float32{0.0, 0.0, 0.0}
}

embeddings = append(embeddings, &ai.DocumentEmbedding{Embedding: embedding})
_ = doc
}
return &ai.EmbedResponse{Embeddings: embeddings}, nil
}

// To run this test you must have a Firestore database initialized in a GCP project, with a vector indexed collection (of dimension 3).
// Warning: This test will delete all documents in the collection in cleanup.

func TestFirestoreRetriever(t *testing.T) {
ctx := context.Background()
Expand All @@ -55,7 +86,7 @@ func TestFirestoreRetriever(t *testing.T) {
defer client.Close()

// Clean up the collection before the test
deleteCollection(ctx, client, *testCollection, t)
defer deleteCollection(ctx, client, *testCollection, t)

// Initialize the embedder
embedder := &MockEmbedder{}
Expand All @@ -71,6 +102,12 @@ func TestFirestoreRetriever(t *testing.T) {
{"doc3", "This is document three", map[string]interface{}{"metadata": "meta3"}},
}

// Expected document text content in order of relevance for the query
expectedTexts := []string{
"This is document one",
"This is document two",
}

for _, doc := range testDocs {
// Create an ai.Document
aiDoc := ai.DocumentFromText(doc.Text, doc.Data)
Expand Down Expand Up @@ -131,7 +168,7 @@ func TestFirestoreRetriever(t *testing.T) {
}

// Create a retriever request with the input document
queryText := "This is a query similar to document one"
queryText := "This is input query"
inputDocument := ai.DocumentFromText(queryText, nil)

req := &ai.RetrieverRequest{
Expand All @@ -149,28 +186,39 @@ func TestFirestoreRetriever(t *testing.T) {
t.Fatalf("No documents retrieved")
}

// Verify the content of all retrieved documents against the expected list
for i, doc := range resp.Documents {
t.Logf("Retrieved Document %d: %s", i+1, doc.Content[0].Text)
}
if i >= len(expectedTexts) {
t.Errorf("More documents retrieved than expected. Retrieved: %d, Expected: %d", len(resp.Documents), len(expectedTexts))
break
}

// Optionally, check if the top retrieved document is the expected one
expectedFirstDoc := "This is document one"
if resp.Documents[0].Content[0].Text != expectedFirstDoc {
t.Errorf("Expected first retrieved document to be '%s', but got '%s'", expectedFirstDoc, resp.Documents[0].Content[0].Text)
if doc.Content[0].Text != expectedTexts[i] {
t.Errorf("Mismatch in document %d content. Expected: '%s', Got: '%s'", i+1, expectedTexts[i], doc.Content[0].Text)
} else {
t.Logf("Retrieved Document %d matches expected content: '%s'", i+1, expectedTexts[i])
}
}
}

// Helper function to delete all documents in a collection
func deleteCollection(ctx context.Context, client *firestore.Client, collectionName string, t *testing.T) {
docs, err := client.Collection(collectionName).Documents(ctx).GetAll()
if err != nil {
t.Logf("Failed to get documents for deletion: %v\n", err)
return
}
for _, doc := range docs {
_, err := doc.Ref.Delete(ctx)
// Get all documents in the collection
iter := client.Collection(collectionName).Documents(ctx)
for {
doc, err := iter.Next()
if err == iterator.Done {
break // No more documents
}
if err != nil {
t.Fatalf("Failed to iterate documents for deletion: %v", err)
}

// Delete each document
_, err = doc.Ref.Delete(ctx)
if err != nil {
log.Printf("Failed to delete document %s: %v\n", doc.Ref.ID, err)
t.Errorf("Failed to delete document %s: %v", doc.Ref.ID, err)
} else {
t.Logf("Deleted document: %s", doc.Ref.ID)
}
}
}

0 comments on commit 521a059

Please sign in to comment.