diff --git a/go/plugins/firebase/retriever.go b/go/plugins/firebase/retriever.go index b9d343d56..31c0cc1b1 100644 --- a/go/plugins/firebase/retriever.go +++ b/go/plugins/firebase/retriever.go @@ -1,142 +1,112 @@ -// Copyright 2024 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - package firebase import ( "context" "fmt" + "log" "cloud.google.com/go/firestore" "github.com/firebase/genkit/go/ai" ) +type VectorType int + +const ( + Vector64 VectorType = iota +) + const provider = "firebase" -// RetrieverOptions defines the configuration for the retriever. type RetrieverOptions struct { Name string Label string Client *firestore.Client - Embedder ai.Embedder - EmbedderOptions ai.EmbedOption Collection string + Embedder ai.Embedder VectorField string - MetadataFields []string // Optional: if empty, metadata will not be retrieved + MetadataFields []string ContentField string + Limit int DistanceMeasure firestore.DistanceMeasure -} - -type RetrieverRequestOptions struct { - Limit int `json:"limit,omitempty"` // maximum number of values to retrieve - DistanceMeasure firestore.DistanceMeasure + VectorType VectorType } func DefineFirestoreRetriever(cfg RetrieverOptions) (ai.Retriever, error) { - - coll := cfg.Client.Collection(cfg.Collection) - if coll == nil { - return nil, fmt.Errorf("DefineFirestoreRetriever: collection path %q is invalid", cfg.Collection) + if cfg.VectorType != Vector64 { + return nil, fmt.Errorf("DefineFirestoreRetriever: only Vector64 is supported") } Retrieve := func(ctx context.Context, req *ai.RetrieverRequest) (*ai.RetrieverResponse, error) { - if req == nil { - return nil, fmt.Errorf("retriever request is nil") - } - - options := RetrieverRequestOptions{Limit: 10, DistanceMeasure: cfg.DistanceMeasure} - - if req.Options != nil { - // Ensure that the options are of the correct type - parsedOptions, ok := req.Options.(*RetrieverRequestOptions) - if !ok { - return nil, fmt.Errorf("firebase.Retrieve options have type %T, want %T", req.Options, &RetrieverRequestOptions{}) - } - options = *parsedOptions + if cfg.Client == nil { + return nil, fmt.Errorf("DefineFirestoreRetriever: Firestore client is not provided") } - if cfg.Embedder == nil { - return nil, fmt.Errorf("embedder is nil in config") + if req.Document == nil { + return nil, fmt.Errorf("DefineFirestoreRetriever: Request document is nil") } - // Use the embedder to convert the document we want to retrieve into a vector. - ereq := &ai.EmbedRequest{ - Documents: []*ai.Document{req.Document}, - } - - eres, err := cfg.Embedder.Embed(ctx, ereq) + // Generate query embedding using the Embedder + embedRequest := &ai.EmbedRequest{Documents: []*ai.Document{req.Document}} + embedResponse, err := cfg.Embedder.Embed(ctx, embedRequest) if err != nil { - return nil, fmt.Errorf("%s index embedding failed: %v", provider, err) + return nil, fmt.Errorf("DefineFirestoreRetriever: Embedding failed: %v", err) } - if eres == nil || len(eres.Embeddings) == 0 { - return nil, fmt.Errorf("embedding result is nil or empty") + if len(embedResponse.Embeddings) == 0 { + return nil, fmt.Errorf("DefineFirestoreRetriever: No embeddings returned") } - embedding := eres.Embeddings[0].Embedding - - distanceMeasure := cfg.DistanceMeasure - - if options.DistanceMeasure != 0 { - distanceMeasure = options.DistanceMeasure + queryEmbedding := embedResponse.Embeddings[0].Embedding + if len(queryEmbedding) == 0 { + return nil, fmt.Errorf("DefineFirestoreRetriever: Generated embedding is empty") } - fmt.Printf("Retrieving nearest documents to embedding %v\n", embedding) - - query := coll.FindNearest(cfg.VectorField, embedding, options.Limit, distanceMeasure, nil) - // Execute the query - iter := query.Documents(ctx) - gotDocs, err := iter.GetAll() - + // Convert to []float64 + queryEmbedding64 := make([]float64, len(queryEmbedding)) + for i, val := range queryEmbedding { + queryEmbedding64[i] = float64(val) + } + // Perform the FindNearest query + vectorQuery := cfg.Client.Collection(cfg.Collection).FindNearest( + cfg.VectorField, + firestore.Vector64(queryEmbedding64), + cfg.Limit, + cfg.DistanceMeasure, + nil, + ) + iter := vectorQuery.Documents(ctx) + + results, err := iter.GetAll() if err != nil { - return nil, fmt.Errorf("getting documents: %v", err) + return nil, fmt.Errorf("DefineFirestoreRetriever: FindNearest query failed: %v", err) } - genkitDocs := make([]*ai.Document, len(gotDocs)) - for i, doc := range gotDocs { - // Call doc.Data() once and cache the result - data := doc.Data() + // Prepare the documents to return in the response + var documents []*ai.Document + for _, result := range results { + data := result.Data() - // Extract the content field + // Ensure content field exists and is of type string content, ok := data[cfg.ContentField].(string) if !ok { - fmt.Printf("content field is missing or not a string in document %v", doc.Ref.ID) + // TODO: use genkit logger + log.Printf("Content field %s missing or not a string in document %s", cfg.ContentField, result.Ref.ID) continue } - out := make(map[string]any) - out["content"] = content - - metadata := make(map[string]any) - // Use the cached `data` to retrieve the metadata fields - if len(cfg.MetadataFields) > 0 { - for _, field := range cfg.MetadataFields { - metadata[field] = data[field] - } - } else { - for k, v := range data { - if k != cfg.VectorField && k != cfg.ContentField { - metadata[k] = v - } + // Extract metadata fields + metadata := make(map[string]interface{}) + for _, field := range cfg.MetadataFields { + if value, ok := data[field]; ok { + metadata[field] = value } } - out["metadata"] = metadata - genkitDocs[i] = ai.DocumentFromText(content, metadata) + doc := ai.DocumentFromText(content, metadata) + documents = append(documents, doc) } - return &ai.RetrieverResponse{Documents: genkitDocs}, nil + return &ai.RetrieverResponse{Documents: documents}, nil } return ai.DefineRetriever(provider, cfg.Name, Retrieve), nil diff --git a/go/plugins/firebase/retriever_test.go b/go/plugins/firebase/retriever_test.go index f3213647a..7ca8684e1 100644 --- a/go/plugins/firebase/retriever_test.go +++ b/go/plugins/firebase/retriever_test.go @@ -3,177 +3,174 @@ package firebase import ( "context" "flag" - "fmt" - "os" + "log" "testing" "cloud.google.com/go/firestore" + firebasev4 "firebase.google.com/go/v4" "github.com/firebase/genkit/go/ai" - vertexai "github.com/firebase/genkit/go/plugins/vertexai" ) var ( testProjectID = flag.String("test-project-id", "", "GCP Project ID to use for tests") - testCollection = flag.String("test-collection", "", "Firestore collection to use for tests") - testVectorField = flag.String("test-vector-field", "", "Firestore vector field to use for tests") - testLocation = flag.String("test-location", "us-central1", "Firestore location to use for tests") + testCollection = flag.String("test-collection", "testR2", "Firestore collection to use for tests") + testVectorField = flag.String("test-vector-field", "embedding", "Field name for vector embeddings") ) -func TestFirestoreRetriever(t *testing.T) { - // Check if the required flags are set, otherwise skip the test - if *testProjectID == "" { - t.Skip("skipping test because -test-project-id flag not used") - } - if *testCollection == "" { - t.Skip("skipping test because -test-collection flag not used") - } - if *testVectorField == "" { - t.Skip("skipping test because -test-vector-field flag not used") +// MockEmbedder implements the Embedder interface for testing purposes +type MockEmbedder struct{} + +func (e *MockEmbedder) Name() string { + return "MockEmbedder" +} + +func (e *MockEmbedder) Embed(ctx context.Context, req *ai.EmbedRequest) (*ai.EmbedResponse, error) { + var embeddings []*ai.DocumentEmbedding + for _, doc := range req.Documents { + // For simplicity, all embeddings are [1.0, 0.0, 0.0] + embedding := []float32{1.0, 0.0, 0.0} + embeddings = append(embeddings, &ai.DocumentEmbedding{Embedding: embedding}) + _ = doc } + return &ai.EmbedResponse{Embeddings: embeddings}, nil +} - // Set environment variables for Firebase emulators - os.Setenv("FIRESTORE_EMULATOR_HOST", "127.0.0.1:8080") - os.Setenv("FIREBASE_AUTH_EMULATOR_HOST", "127.0.0.1:9099") - os.Setenv("FIREBASE_STORAGE_EMULATOR_HOST", "127.0.0.1:9199") - os.Setenv("FIREBASE_DATABASE_EMULATOR_HOST", "127.0.0.1:9000") +// To run this test you must have a Firestore database initialized in a GCP project, with a vector indexed collection (of dimension 3). - // Use context for initializing Firebase app and Vertex AI embedder +func TestFirestoreRetriever(t *testing.T) { ctx := context.Background() - // Initialize Vertex AI configuration - vertexAiConfig := vertexai.Config{ - ProjectID: *testProjectID, - Location: *testLocation, + // Initialize Firebase app + conf := &firebasev4.Config{ProjectID: *testProjectID} + app, err := firebasev4.NewApp(ctx, conf) + if err != nil { + t.Fatalf("Failed to create Firebase app: %v", err) } - err := vertexai.Init(ctx, &vertexAiConfig) + + // Initialize Firestore client + client, err := app.Firestore(ctx) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to create Firestore client: %v", err) } + defer client.Close() - // Get the embedder - testEmbedder := vertexai.Embedder("textembedding-gecko@003") + // Clean up the collection before the test + deleteCollection(ctx, client, *testCollection, t) - if testEmbedder == nil { - t.Fatal("embedder is nil") - } + // Initialize the embedder + embedder := &MockEmbedder{} - // Initialize Firebase plugin configuration - pluginConfig := FirebasePluginConfig{ - ProjectID: *testProjectID, + // Insert test documents with embeddings generated by the embedder + testDocs := []struct { + ID string + Text string + Data map[string]interface{} + }{ + {"doc1", "This is document one", map[string]interface{}{"metadata": "meta1"}}, + {"doc2", "This is document two", map[string]interface{}{"metadata": "meta2"}}, + {"doc3", "This is document three", map[string]interface{}{"metadata": "meta3"}}, } - // Initialize Firebase - if err := Init(ctx, &pluginConfig); err != nil { - t.Fatal(err) - } - defer unInit() + for _, doc := range testDocs { + // Create an ai.Document + aiDoc := ai.DocumentFromText(doc.Text, doc.Data) - // Create Firestore client - client, err := firestore.NewClient(ctx, *testProjectID) - if err != nil { - t.Fatal(err) - } - defer client.Close() + // Generate embedding + embedRequest := &ai.EmbedRequest{Documents: []*ai.Document{aiDoc}} + embedResponse, err := embedder.Embed(ctx, embedRequest) + if err != nil { + t.Fatalf("Failed to generate embedding for document %s: %v", doc.ID, err) + } + + if len(embedResponse.Embeddings) == 0 { + t.Fatalf("No embeddings returned for document %s", doc.ID) + } - // Set up test data in Firestore - if err := setupTestCollection(ctx, client, *testCollection, *testVectorField, testEmbedder); err != nil { - t.Fatalf("failed to set up test collection: %v", err) + embedding := embedResponse.Embeddings[0].Embedding + if len(embedding) == 0 { + t.Fatalf("Generated embedding is empty for document %s", doc.ID) + } + + // Convert to []float64 + embedding64 := make([]float64, len(embedding)) + for i, val := range embedding { + embedding64[i] = float64(val) + } + + // Store in Firestore + _, err = client.Collection(*testCollection).Doc(doc.ID).Set(ctx, map[string]interface{}{ + "text": doc.Text, + "metadata": doc.Data["metadata"], + *testVectorField: firestore.Vector64(embedding64), + }) + if err != nil { + t.Fatalf("Failed to insert document %s: %v", doc.ID, err) + } + t.Logf("Inserted document: %s with embedding: %v", doc.ID, embedding64) } - // Define retriever configuration - retrieverConfig := RetrieverOptions{ + // Define retriever options + retrieverOptions := RetrieverOptions{ Name: "test-retriever", Label: "Test Retriever", Client: client, - Embedder: testEmbedder, Collection: *testCollection, + Embedder: embedder, VectorField: *testVectorField, + MetadataFields: []string{"metadata"}, ContentField: "text", + Limit: 2, DistanceMeasure: firestore.DistanceMeasureEuclidean, + VectorType: Vector64, } - // Define the Firestore retriever - retriever, err := DefineFirestoreRetriever(retrieverConfig) + // Define the retriever + retriever, err := DefineFirestoreRetriever(retrieverOptions) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to define retriever: %v", err) } - // Create a test document - testDocument := ai.DocumentFromText("Test document", map[string]any{"metadata": "test"}) + // Create a retriever request with the input document + queryText := "This is a query similar to document one" + inputDocument := ai.DocumentFromText(queryText, nil) - // Create a retriever request req := &ai.RetrieverRequest{ - Document: testDocument, - Options: &RetrieverRequestOptions{Limit: 2, DistanceMeasure: firestore.DistanceMeasureEuclidean}, + Document: inputDocument, } - // Retrieve documents using the retriever + // Perform the retrieval resp, err := retriever.Retrieve(ctx, req) if err != nil { - t.Fatal(err) + t.Fatalf("Retriever failed: %v", err) } - // Log and validate the response - if resp == nil { - t.Fatal("expected non-nil response, got nil") + // Check the retrieved documents + if len(resp.Documents) == 0 { + t.Fatalf("No documents retrieved") } - t.Logf("Retrieved %d documents", len(resp.Documents)) - if len(resp.Documents) != 2 { - t.Errorf("expected 2 documents, got %d", len(resp.Documents)) + + for i, doc := range resp.Documents { + t.Logf("Retrieved Document %d: %s", i+1, doc.Content[0].Text) } - for _, doc := range resp.Documents { - if doc == nil { - t.Error("retrieved document is nil") - } + // Optionally, check if the top retrieved document is the expected one + expectedFirstDoc := "This is document one" + if resp.Documents[0].Content[0].Text != expectedFirstDoc { + t.Errorf("Expected first retrieved document to be '%s', but got '%s'", expectedFirstDoc, resp.Documents[0].Content[0].Text) } - t.Logf("Doc with content \n\n %s \n\n retrieved \n\n", resp.Documents[0].Content[0].Text) - t.Logf("Doc with content \n\n %s \n\n retrieved \n\n", resp.Documents[1].Content[0].Text) } -// setupTestCollection initializes a Firestore collection with sample documents. -func setupTestCollection(ctx context.Context, client *firestore.Client, collection string, vectorField string, embedder ai.Embedder) error { - // Delete existing documents in the collection - iter := client.Collection(collection).Documents(ctx) - docs, err := iter.GetAll() +// Helper function to delete all documents in a collection +func deleteCollection(ctx context.Context, client *firestore.Client, collectionName string, t *testing.T) { + docs, err := client.Collection(collectionName).Documents(ctx).GetAll() if err != nil { - return fmt.Errorf("failed to list documents for deletion: %v", err) + t.Logf("Failed to get documents for deletion: %v\n", err) + return } for _, doc := range docs { - if _, err := doc.Ref.Delete(ctx); err != nil { - return fmt.Errorf("failed to delete document %s: %v", doc.Ref.ID, err) - } - } - - // Add 10 sample documents with embeddings and text content - for i := 0; i < 10; i++ { - docID := fmt.Sprintf("doc-%d", i) - text := fmt.Sprintf("This is test document number %d", i) - - doc := ai.DocumentFromText("Test document", map[string]any{"metadata": "test"}) // Create a document from text - - docs := []*ai.Document{doc} - - // Generate embedding for the text - embedReq := &ai.EmbedRequest{ - Documents: docs, - } - embedResp, err := embedder.Embed(ctx, embedReq) - if err != nil { - return fmt.Errorf("failed to generate embedding for document %s: %v", docID, err) - } - - data := map[string]interface{}{ - "text": text, - vectorField: embedResp.Embeddings[0].Embedding, - "metadata": map[string]interface{}{ - "index": i, - }, - } - _, err = client.Collection(collection).Doc(docID).Set(ctx, data) + _, err := doc.Ref.Delete(ctx) if err != nil { - return fmt.Errorf("failed to create document %s: %v", docID, err) + log.Printf("Failed to delete document %s: %v\n", doc.Ref.ID, err) } } - return nil }