Skip to content

Commit

Permalink
Refactor google genai
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Dec 15, 2023
1 parent 4f5ae01 commit 6048a5e
Show file tree
Hide file tree
Showing 30 changed files with 662 additions and 815 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
---
title: Google GenAI
description: All about Google GenAI.
weight: 40
---

```go
ctx := context.Background()

client, err := generativelanguage.NewGenerativeClient(ctx)
if err != nil {
// Error handling
}

defer client.Close()

llm, err := chatmodel.NewGoogleGenAI(client)
if err != nil {
// Error handling
}
```
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
---
title: Ollama
description: All about Ollama.
weight: 40
weight: 50
---

```go
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
---
title: OpenAI
description: All about OpenAI.
weight: 50
weight: 60
---

```go
Expand Down
21 changes: 0 additions & 21 deletions docs/content/en/docs/llms_and_prompts/models/chatmodels/palm.md

This file was deleted.

21 changes: 0 additions & 21 deletions docs/content/en/docs/llms_and_prompts/models/llms/gemini.md

This file was deleted.

21 changes: 21 additions & 0 deletions docs/content/en/docs/llms_and_prompts/models/llms/google_genai.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
---
title: Google GenAI
description: All about Google GenAI.
weight: 40
---

```go
ctx := context.Background()

client, err := generativelanguage.NewGenerativeClient(ctx)
if err != nil {
// Error handling
}

defer client.Close()

llm, err := llm.NewGoogleGenAI(client)
if err != nil {
// Error handling
}
```
21 changes: 0 additions & 21 deletions docs/content/en/docs/llms_and_prompts/models/llms/palm.md

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
---
title: Sagemaker Endpoint
description: All about Sagemaker Endpoint.
weight: 80
weight: 70
---

1. Create a ContentHandler for Input/Output Transformation
Expand Down
91 changes: 91 additions & 0 deletions embedding/google_genai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package embedding

import (
"context"

"cloud.google.com/go/ai/generativelanguage/apiv1/generativelanguagepb"
"github.com/googleapis/gax-go/v2"
"github.com/hupe1980/golc/schema"
"github.com/hupe1980/golc/util"
)

// Compile time check to ensure GoogleGenAI satisfies the Embedder interface.
var _ schema.Embedder = (*GoogleGenAI)(nil)

// GoogleGenAIClient is an interface for the GoogleGenAI client.
type GoogleGenAIClient interface {
EmbedContent(context.Context, *generativelanguagepb.EmbedContentRequest, ...gax.CallOption) (*generativelanguagepb.EmbedContentResponse, error)
BatchEmbedContents(context.Context, *generativelanguagepb.BatchEmbedContentsRequest, ...gax.CallOption) (*generativelanguagepb.BatchEmbedContentsResponse, error)
}

// GoogleGenAIOptions contains options for configuring the GoogleGenAI client.
type GoogleGenAIOptions struct {
ModelName string
}

// GoogleGenAI is a client for the GoogleGenAI embedding service.
type GoogleGenAI struct {
client GoogleGenAIClient
opts GoogleGenAIOptions
}

// NewGoogleGenAI creates a new instance of the GoogleGenAI client.
func NewGoogleGenAI(client GoogleGenAIClient, optFns ...func(o *GoogleGenAIOptions)) *GoogleGenAI {
opts := GoogleGenAIOptions{
ModelName: "models/embedding-001",
}

for _, fn := range optFns {
fn(&opts)
}

return &GoogleGenAI{
client: client,
opts: opts,
}
}

// EmbedDocuments embeds a list of documents and returns their embeddings.
func (e *GoogleGenAI) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) {
requests := make([]*generativelanguagepb.EmbedContentRequest, len(texts))

for i, t := range texts {
requests[i] = &generativelanguagepb.EmbedContentRequest{
Model: e.opts.ModelName,
Content: &generativelanguagepb.Content{Parts: []*generativelanguagepb.Part{{
Data: &generativelanguagepb.Part_Text{Text: t},
}}},
}
}

res, err := e.client.BatchEmbedContents(ctx, &generativelanguagepb.BatchEmbedContentsRequest{
Model: e.opts.ModelName,
Requests: requests,
})
if err != nil {
return nil, err
}

embeddings := make([][]float64, len(texts))

for i, e := range res.Embeddings {
embeddings[i] = util.Float32ToFloat64(e.Values)
}

return embeddings, nil
}

// EmbedQuery embeds a single query and returns its embedding.
func (e *GoogleGenAI) EmbedQuery(ctx context.Context, text string) ([]float64, error) {
res, err := e.client.EmbedContent(ctx, &generativelanguagepb.EmbedContentRequest{
Model: e.opts.ModelName,
Content: &generativelanguagepb.Content{Parts: []*generativelanguagepb.Part{{
Data: &generativelanguagepb.Part_Text{Text: text},
}}},
})
if err != nil {
return nil, err
}

return util.Float32ToFloat64(res.Embedding.Values), nil
}
132 changes: 132 additions & 0 deletions embedding/google_genai_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package embedding

import (
"context"
"errors"
"testing"

"github.com/stretchr/testify/assert"

"cloud.google.com/go/ai/generativelanguage/apiv1/generativelanguagepb"
"github.com/googleapis/gax-go/v2"
)

func TestGoogleGenAI(t *testing.T) {
// Create a new instance of the GoogleGenAI model with the custom mock client.
client := &mockGoogleGenAIClient{}

googleGenAIModel := NewGoogleGenAI(client)

// Test cases
t.Run("Test embedding of documents", func(t *testing.T) {
// Define a list of texts to embed.
texts := []string{"text1", "text2"}

// Define a mock response for EmbedText.
mockResponse := &generativelanguagepb.BatchEmbedContentsResponse{
Embeddings: []*generativelanguagepb.ContentEmbedding{
{
Values: []float32{1.0, 2.0, 3.0},
}, {
Values: []float32{4.0, 5.0, 6.0},
}},
}

// Set the mock response for EmbedText.
client.respEmbedBatchContents = mockResponse
client.errEmbed = nil

// Embed the documents.
embeddings, err := googleGenAIModel.EmbedDocuments(context.Background(), texts)

// Use assertions to check the results.
assert.NoError(t, err)
assert.NotNil(t, embeddings)
assert.Len(t, embeddings, 2)
assert.Len(t, embeddings[0], 3)
assert.Equal(t, float64(1.0), embeddings[0][0])
assert.Len(t, embeddings[1], 3)
assert.Equal(t, float64(4.0), embeddings[1][0])
})

t.Run("Test embedding error", func(t *testing.T) {
// Define a list of texts to embed.
texts := []string{"text1"}

// Set the mock error for EmbedText.
client.respEmbedBatchContents = nil
client.errEmbed = errors.New("Test error")

// Embed the documents.
embeddings, err := googleGenAIModel.EmbedDocuments(context.Background(), texts)

// Use assertions to check the error and embeddings.
assert.Error(t, err)
assert.Nil(t, embeddings)
})

t.Run("Test embedding of a single query", func(t *testing.T) {
// Define a query text.
query := "query text"

// Define a mock response for EmbedText.
mockResponse := &generativelanguagepb.EmbedContentResponse{
Embedding: &generativelanguagepb.ContentEmbedding{
Values: []float32{1.0, 2.0, 3.0},
},
}

// Set the mock response for EmbedText.
client.respEmbedContent = mockResponse
client.errEmbed = nil

// Embed the query.
embedding, err := googleGenAIModel.EmbedQuery(context.Background(), query)

// Use assertions to check the results.
assert.NoError(t, err)
assert.NotNil(t, embedding)
assert.Len(t, embedding, 3)
})

t.Run("Test embedding error for query", func(t *testing.T) {
// Define a query text.
query := "query text"

// Set the mock error for EmbedText.
client.respEmbedContent = nil
client.errEmbed = errors.New("Test error")

// Embed the query.
embedding, err := googleGenAIModel.EmbedQuery(context.Background(), query)

// Use assertions to check the error and embedding.
assert.Error(t, err)
assert.Nil(t, embedding)
})
}

// mockGoogleGenAIClient is a custom mock implementation of the GoogleGenAIClient interface.
type mockGoogleGenAIClient struct {
respEmbedContent *generativelanguagepb.EmbedContentResponse
respEmbedBatchContents *generativelanguagepb.BatchEmbedContentsResponse
errEmbed error
}

// EmbedContent mocks the EmbedContent method of the GoogleGenAIClient interface.
func (m *mockGoogleGenAIClient) EmbedContent(context.Context, *generativelanguagepb.EmbedContentRequest, ...gax.CallOption) (*generativelanguagepb.EmbedContentResponse, error) {
if m.errEmbed != nil {
return nil, m.errEmbed
}

return m.respEmbedContent, nil
}

// BatchEmbedContents mocks the BatchEmbedContents method of the GoogleGenAIClient interface.
func (m *mockGoogleGenAIClient) BatchEmbedContents(context.Context, *generativelanguagepb.BatchEmbedContentsRequest, ...gax.CallOption) (*generativelanguagepb.BatchEmbedContentsResponse, error) {
if m.errEmbed != nil {
return nil, m.errEmbed
}

return m.respEmbedBatchContents, nil
}
Loading

0 comments on commit 6048a5e

Please sign in to comment.