diff --git a/docs/content/en/docs/llms_and_prompts/models/chatmodels/google_genai.md b/docs/content/en/docs/llms_and_prompts/models/chatmodels/google_genai.md new file mode 100644 index 0000000..305e1fb --- /dev/null +++ b/docs/content/en/docs/llms_and_prompts/models/chatmodels/google_genai.md @@ -0,0 +1,21 @@ +--- +title: Google GenAI +description: All about Google GenAI. +weight: 40 +--- + +```go +ctx := context.Background() + +client, err := generativelanguage.NewGenerativeClient(ctx) +if err != nil { + // Error handling +} + +defer client.Close() + +llm, err := chatmodel.NewGoogleGenAI(client) +if err != nil { + // Error handling +} +``` \ No newline at end of file diff --git a/docs/content/en/docs/llms_and_prompts/models/chatmodels/ollama.md b/docs/content/en/docs/llms_and_prompts/models/chatmodels/ollama.md index 913ebeb..fc5b8bc 100644 --- a/docs/content/en/docs/llms_and_prompts/models/chatmodels/ollama.md +++ b/docs/content/en/docs/llms_and_prompts/models/chatmodels/ollama.md @@ -1,7 +1,7 @@ --- title: Ollama description: All about Ollama. -weight: 40 +weight: 50 --- ```go diff --git a/docs/content/en/docs/llms_and_prompts/models/chatmodels/openai.md b/docs/content/en/docs/llms_and_prompts/models/chatmodels/openai.md index cdd1306..15c70a2 100644 --- a/docs/content/en/docs/llms_and_prompts/models/chatmodels/openai.md +++ b/docs/content/en/docs/llms_and_prompts/models/chatmodels/openai.md @@ -1,7 +1,7 @@ --- title: OpenAI description: All about OpenAI. -weight: 50 +weight: 60 --- ```go diff --git a/docs/content/en/docs/llms_and_prompts/models/chatmodels/palm.md b/docs/content/en/docs/llms_and_prompts/models/chatmodels/palm.md deleted file mode 100644 index ce49d10..0000000 --- a/docs/content/en/docs/llms_and_prompts/models/chatmodels/palm.md +++ /dev/null @@ -1,21 +0,0 @@ ---- -title: Palm -description: All about Palm. 
-weight: 60 ---- - -```go -ctx := context.Background() - -// see https://pkg.go.dev/cloud.google.com/go/ai@v0.1.1/generativelanguage/apiv1beta2 -c, err := generativelanguage.NewDiscussClient(ctx) -if err != nil { - // Error handling -} -defer c.Close() - -palm, err := chatmodel.NewPalm(c) -if err != nil { - // Error handling -} -``` \ No newline at end of file diff --git a/docs/content/en/docs/llms_and_prompts/models/llms/gemini.md b/docs/content/en/docs/llms_and_prompts/models/llms/gemini.md deleted file mode 100644 index 9161f26..0000000 --- a/docs/content/en/docs/llms_and_prompts/models/llms/gemini.md +++ /dev/null @@ -1,21 +0,0 @@ ---- -title: Gemini -description: All about Gemini. -weight: 40 ---- - -```go -ctx := context.Background() - -client, err := genai.NewClient(ctx, option.WithAPIKey(os.Getenv("GEMINI_API_KEY"))) -if err != nil { - // Error handling -} - -defer client.Close() - -llm, err := llm.NewGemini(client) -if err != nil { - // Error handling -} -``` \ No newline at end of file diff --git a/docs/content/en/docs/llms_and_prompts/models/llms/google_genai.md b/docs/content/en/docs/llms_and_prompts/models/llms/google_genai.md new file mode 100644 index 0000000..46d00f5 --- /dev/null +++ b/docs/content/en/docs/llms_and_prompts/models/llms/google_genai.md @@ -0,0 +1,21 @@ +--- +title: Google GenAI +description: All about Google GenAI. +weight: 40 +--- + +```go +ctx := context.Background() + +client, err := generativelanguage.NewGenerativeClient(ctx) +if err != nil { + // Error handling +} + +defer client.Close() + +llm, err := llm.NewGoogleGenAI(client) +if err != nil { + // Error handling +} +``` \ No newline at end of file diff --git a/docs/content/en/docs/llms_and_prompts/models/llms/palm.md b/docs/content/en/docs/llms_and_prompts/models/llms/palm.md deleted file mode 100644 index 81d1a6a..0000000 --- a/docs/content/en/docs/llms_and_prompts/models/llms/palm.md +++ /dev/null @@ -1,21 +0,0 @@ ---- -title: Palm -description: All about Palm. 
-weight: 70 ---- - -```go -ctx := context.Background() - -client, err := generativelanguage.NewTextClient(ctx) -if err != nil { - // Error handling -} - -defer client.Close() - -palm, err := llm.NewPalm(client) -if err != nil { - // Error handling -} -``` \ No newline at end of file diff --git a/docs/content/en/docs/llms_and_prompts/models/llms/sagemaker_endpoint.md b/docs/content/en/docs/llms_and_prompts/models/llms/sagemaker_endpoint.md index 23c9bda..37a2749 100644 --- a/docs/content/en/docs/llms_and_prompts/models/llms/sagemaker_endpoint.md +++ b/docs/content/en/docs/llms_and_prompts/models/llms/sagemaker_endpoint.md @@ -1,7 +1,7 @@ --- title: Sagemaker Endpoint description: All about Sagemaker Endpoint. -weight: 80 +weight: 70 --- 1. Create a ContentHandler for Input/Output Transformation diff --git a/embedding/google_genai.go b/embedding/google_genai.go new file mode 100644 index 0000000..46c61fd --- /dev/null +++ b/embedding/google_genai.go @@ -0,0 +1,91 @@ +package embedding + +import ( + "context" + + "cloud.google.com/go/ai/generativelanguage/apiv1/generativelanguagepb" + "github.com/googleapis/gax-go/v2" + "github.com/hupe1980/golc/schema" + "github.com/hupe1980/golc/util" +) + +// Compile time check to ensure GoogleGenAI satisfies the Embedder interface. +var _ schema.Embedder = (*GoogleGenAI)(nil) + +// GoogleGenAIClient is an interface for the GoogleGenAI client. +type GoogleGenAIClient interface { + EmbedContent(context.Context, *generativelanguagepb.EmbedContentRequest, ...gax.CallOption) (*generativelanguagepb.EmbedContentResponse, error) + BatchEmbedContents(context.Context, *generativelanguagepb.BatchEmbedContentsRequest, ...gax.CallOption) (*generativelanguagepb.BatchEmbedContentsResponse, error) +} + +// GoogleGenAIOptions contains options for configuring the GoogleGenAI client. +type GoogleGenAIOptions struct { + ModelName string +} + +// GoogleGenAI is a client for the GoogleGenAI embedding service. 
+type GoogleGenAI struct { + client GoogleGenAIClient + opts GoogleGenAIOptions +} + +// NewGoogleGenAI creates a new instance of the GoogleGenAI client. +func NewGoogleGenAI(client GoogleGenAIClient, optFns ...func(o *GoogleGenAIOptions)) *GoogleGenAI { + opts := GoogleGenAIOptions{ + ModelName: "models/embedding-001", + } + + for _, fn := range optFns { + fn(&opts) + } + + return &GoogleGenAI{ + client: client, + opts: opts, + } +} + +// EmbedDocuments embeds a list of documents and returns their embeddings. +func (e *GoogleGenAI) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) { + requests := make([]*generativelanguagepb.EmbedContentRequest, len(texts)) + + for i, t := range texts { + requests[i] = &generativelanguagepb.EmbedContentRequest{ + Model: e.opts.ModelName, + Content: &generativelanguagepb.Content{Parts: []*generativelanguagepb.Part{{ + Data: &generativelanguagepb.Part_Text{Text: t}, + }}}, + } + } + + res, err := e.client.BatchEmbedContents(ctx, &generativelanguagepb.BatchEmbedContentsRequest{ + Model: e.opts.ModelName, + Requests: requests, + }) + if err != nil { + return nil, err + } + + embeddings := make([][]float64, len(texts)) + + for i, e := range res.Embeddings { + embeddings[i] = util.Float32ToFloat64(e.Values) + } + + return embeddings, nil +} + +// EmbedQuery embeds a single query and returns its embedding. 
+func (e *GoogleGenAI) EmbedQuery(ctx context.Context, text string) ([]float64, error) { + res, err := e.client.EmbedContent(ctx, &generativelanguagepb.EmbedContentRequest{ + Model: e.opts.ModelName, + Content: &generativelanguagepb.Content{Parts: []*generativelanguagepb.Part{{ + Data: &generativelanguagepb.Part_Text{Text: text}, + }}}, + }) + if err != nil { + return nil, err + } + + return util.Float32ToFloat64(res.Embedding.Values), nil +} diff --git a/embedding/google_genai_test.go b/embedding/google_genai_test.go new file mode 100644 index 0000000..095efaa --- /dev/null +++ b/embedding/google_genai_test.go @@ -0,0 +1,132 @@ +package embedding + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + "cloud.google.com/go/ai/generativelanguage/apiv1/generativelanguagepb" + "github.com/googleapis/gax-go/v2" +) + +func TestGoogleGenAI(t *testing.T) { + // Create a new instance of the GoogleGenAI model with the custom mock client. + client := &mockGoogleGenAIClient{} + + googleGenAIModel := NewGoogleGenAI(client) + + // Test cases + t.Run("Test embedding of documents", func(t *testing.T) { + // Define a list of texts to embed. + texts := []string{"text1", "text2"} + + // Define a mock response for EmbedText. + mockResponse := &generativelanguagepb.BatchEmbedContentsResponse{ + Embeddings: []*generativelanguagepb.ContentEmbedding{ + { + Values: []float32{1.0, 2.0, 3.0}, + }, { + Values: []float32{4.0, 5.0, 6.0}, + }}, + } + + // Set the mock response for EmbedText. + client.respEmbedBatchContents = mockResponse + client.errEmbed = nil + + // Embed the documents. + embeddings, err := googleGenAIModel.EmbedDocuments(context.Background(), texts) + + // Use assertions to check the results. 
+ assert.NoError(t, err) + assert.NotNil(t, embeddings) + assert.Len(t, embeddings, 2) + assert.Len(t, embeddings[0], 3) + assert.Equal(t, float64(1.0), embeddings[0][0]) + assert.Len(t, embeddings[1], 3) + assert.Equal(t, float64(4.0), embeddings[1][0]) + }) + + t.Run("Test embedding error", func(t *testing.T) { + // Define a list of texts to embed. + texts := []string{"text1"} + + // Set the mock error for EmbedText. + client.respEmbedBatchContents = nil + client.errEmbed = errors.New("Test error") + + // Embed the documents. + embeddings, err := googleGenAIModel.EmbedDocuments(context.Background(), texts) + + // Use assertions to check the error and embeddings. + assert.Error(t, err) + assert.Nil(t, embeddings) + }) + + t.Run("Test embedding of a single query", func(t *testing.T) { + // Define a query text. + query := "query text" + + // Define a mock response for EmbedText. + mockResponse := &generativelanguagepb.EmbedContentResponse{ + Embedding: &generativelanguagepb.ContentEmbedding{ + Values: []float32{1.0, 2.0, 3.0}, + }, + } + + // Set the mock response for EmbedText. + client.respEmbedContent = mockResponse + client.errEmbed = nil + + // Embed the query. + embedding, err := googleGenAIModel.EmbedQuery(context.Background(), query) + + // Use assertions to check the results. + assert.NoError(t, err) + assert.NotNil(t, embedding) + assert.Len(t, embedding, 3) + }) + + t.Run("Test embedding error for query", func(t *testing.T) { + // Define a query text. + query := "query text" + + // Set the mock error for EmbedText. + client.respEmbedContent = nil + client.errEmbed = errors.New("Test error") + + // Embed the query. + embedding, err := googleGenAIModel.EmbedQuery(context.Background(), query) + + // Use assertions to check the error and embedding. + assert.Error(t, err) + assert.Nil(t, embedding) + }) +} + +// mockGoogleGenAIClient is a custom mock implementation of the GoogleGenAIClient interface. 
+type mockGoogleGenAIClient struct { + respEmbedContent *generativelanguagepb.EmbedContentResponse + respEmbedBatchContents *generativelanguagepb.BatchEmbedContentsResponse + errEmbed error +} + +// EmbedContent mocks the EmbedContent method of the GoogleGenAIClient interface. +func (m *mockGoogleGenAIClient) EmbedContent(context.Context, *generativelanguagepb.EmbedContentRequest, ...gax.CallOption) (*generativelanguagepb.EmbedContentResponse, error) { + if m.errEmbed != nil { + return nil, m.errEmbed + } + + return m.respEmbedContent, nil +} + +// BatchEmbedContents mocks the BatchEmbedContents method of the GoogleGenAIClient interface. +func (m *mockGoogleGenAIClient) BatchEmbedContents(context.Context, *generativelanguagepb.BatchEmbedContentsRequest, ...gax.CallOption) (*generativelanguagepb.BatchEmbedContentsResponse, error) { + if m.errEmbed != nil { + return nil, m.errEmbed + } + + return m.respEmbedBatchContents, nil +} diff --git a/embedding/palm.go b/embedding/palm.go deleted file mode 100644 index 96b4870..0000000 --- a/embedding/palm.go +++ /dev/null @@ -1,80 +0,0 @@ -package embedding - -import ( - "context" - - "cloud.google.com/go/ai/generativelanguage/apiv1beta2/generativelanguagepb" - "github.com/googleapis/gax-go/v2" - "github.com/hupe1980/golc/schema" -) - -// Compile time check to ensure Palm satisfies the Embedder interface. -var _ schema.Embedder = (*Palm)(nil) - -// PalmClient is an interface for the Palm client. -type PalmClient interface { - EmbedText(context.Context, *generativelanguagepb.EmbedTextRequest, ...gax.CallOption) (*generativelanguagepb.EmbedTextResponse, error) -} - -// PalmOptions contains options for configuring the Palm client. -type PalmOptions struct { - ModelName string -} - -// Palm is a client for the Palm embedding service. -type Palm struct { - client PalmClient - opts PalmOptions -} - -// NewPalm creates a new instance of the Palm client. 
-func NewPalm(client PalmClient, optFns ...func(o *PalmOptions)) *Palm { - opts := PalmOptions{ - ModelName: "models/embedding-gecko-001", - } - - for _, fn := range optFns { - fn(&opts) - } - - return &Palm{ - client: client, - opts: opts, - } -} - -// EmbedDocuments embeds a list of documents and returns their embeddings. -func (e *Palm) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) { - embeddings := make([][]float64, len(texts)) - - for i, text := range texts { - v, err := e.EmbedQuery(ctx, text) - if err != nil { - return nil, err - } - - embeddings[i] = v - } - - return embeddings, nil -} - -// EmbedQuery embeds a single query and returns its embedding. -func (e *Palm) EmbedQuery(ctx context.Context, text string) ([]float64, error) { - res, err := e.client.EmbedText(ctx, &generativelanguagepb.EmbedTextRequest{ - Model: e.opts.ModelName, - Text: text, - }) - if err != nil { - return nil, err - } - - values := res.GetEmbedding().GetValue() - - embedding := make([]float64, len(values)) - for i, v := range values { - embedding[i] = float64(v) - } - - return embedding, nil -} diff --git a/embedding/palm_test.go b/embedding/palm_test.go deleted file mode 100644 index daeb838..0000000 --- a/embedding/palm_test.go +++ /dev/null @@ -1,115 +0,0 @@ -package embedding - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/assert" - - "cloud.google.com/go/ai/generativelanguage/apiv1beta2/generativelanguagepb" - "github.com/googleapis/gax-go/v2" -) - -func TestPalm(t *testing.T) { - // Create a new instance of the Palm model with the custom mock client. - client := &mockPalmClient{} - palmModel := NewPalm(client) - - // Test cases - t.Run("Test embedding of documents", func(t *testing.T) { - // Define a list of texts to embed. - texts := []string{"text1", "text2"} - - // Define a mock response for EmbedText. 
- mockResponse := &generativelanguagepb.EmbedTextResponse{ - Embedding: &generativelanguagepb.Embedding{ - Value: []float32{1.0, 2.0, 3.0}, - }, - } - - // Set the mock response for EmbedText. - client.respEmbedText = mockResponse - client.errEmbedText = nil - - // Embed the documents. - embeddings, err := palmModel.EmbedDocuments(context.Background(), texts) - - // Use assertions to check the results. - assert.NoError(t, err) - assert.NotNil(t, embeddings) - assert.Len(t, embeddings, 2) - assert.Len(t, embeddings[0], 3) - }) - - t.Run("Test embedding error", func(t *testing.T) { - // Define a list of texts to embed. - texts := []string{"text1"} - - // Set the mock error for EmbedText. - client.respEmbedText = nil - client.errEmbedText = errors.New("Test error") - - // Embed the documents. - embeddings, err := palmModel.EmbedDocuments(context.Background(), texts) - - // Use assertions to check the error and embeddings. - assert.Error(t, err) - assert.Nil(t, embeddings) - }) - - t.Run("Test embedding of a single query", func(t *testing.T) { - // Define a query text. - query := "query text" - - // Define a mock response for EmbedText. - mockResponse := &generativelanguagepb.EmbedTextResponse{ - Embedding: &generativelanguagepb.Embedding{ - Value: []float32{1.0, 2.0, 3.0}, - }, - } - - // Set the mock response for EmbedText. - client.respEmbedText = mockResponse - client.errEmbedText = nil - - // Embed the query. - embedding, err := palmModel.EmbedQuery(context.Background(), query) - - // Use assertions to check the results. - assert.NoError(t, err) - assert.NotNil(t, embedding) - assert.Len(t, embedding, 3) - }) - - t.Run("Test embedding error for query", func(t *testing.T) { - // Define a query text. - query := "query text" - - // Set the mock error for EmbedText. - client.respEmbedText = nil - client.errEmbedText = errors.New("Test error") - - // Embed the query. 
- embedding, err := palmModel.EmbedQuery(context.Background(), query) - - // Use assertions to check the error and embedding. - assert.Error(t, err) - assert.Nil(t, embedding) - }) -} - -// mockPalmClient is a custom mock implementation of the PalmClient interface. -type mockPalmClient struct { - respEmbedText *generativelanguagepb.EmbedTextResponse - errEmbedText error -} - -// EmbedText mocks the EmbedText method of the PalmClient interface. -func (m *mockPalmClient) EmbedText(ctx context.Context, req *generativelanguagepb.EmbedTextRequest, opts ...gax.CallOption) (*generativelanguagepb.EmbedTextResponse, error) { - if m.errEmbedText != nil { - return nil, m.errEmbedText - } - - return m.respEmbedText, nil -} diff --git a/go.mod b/go.mod index 8e87e69..a854e4a 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/sashabaranov/go-openai v1.17.9 github.com/stretchr/testify v1.8.4 github.com/weaviate/weaviate v1.22.6 + golang.org/x/net v0.19.0 google.golang.org/grpc v1.59.0 google.golang.org/protobuf v1.31.0 ) @@ -92,7 +93,6 @@ require ( go.opencensus.io v0.24.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/mod v0.14.0 // indirect - golang.org/x/net v0.19.0 // indirect golang.org/x/oauth2 v0.15.0 // indirect golang.org/x/sys v0.15.0 // indirect golang.org/x/text v0.14.0 // indirect diff --git a/model/chatmodel/google_genai.go b/model/chatmodel/google_genai.go new file mode 100644 index 0000000..1fa5462 --- /dev/null +++ b/model/chatmodel/google_genai.go @@ -0,0 +1,166 @@ +package chatmodel + +import ( + "context" + "fmt" + "strings" + + "cloud.google.com/go/ai/generativelanguage/apiv1/generativelanguagepb" + "github.com/googleapis/gax-go/v2" + "github.com/hupe1980/golc" + "github.com/hupe1980/golc/callback" + "github.com/hupe1980/golc/schema" + "github.com/hupe1980/golc/tokenizer" + "github.com/hupe1980/golc/util" +) + +// Compile time check to ensure GoogleGenAI satisfies the ChatModel interface. 
+var _ schema.ChatModel = (*GoogleGenAI)(nil) + +// GoogleGenAIClient is an interface for the GoogleGenAI model client. +type GoogleGenAIClient interface { + GenerateContent(context.Context, *generativelanguagepb.GenerateContentRequest, ...gax.CallOption) (*generativelanguagepb.GenerateContentResponse, error) + CountTokens(context.Context, *generativelanguagepb.CountTokensRequest, ...gax.CallOption) (*generativelanguagepb.CountTokensResponse, error) +} + +const ( + roleUser = "user" + roleModel = "model" +) + +type GoogleGenAIOptions struct { + // CallbackOptions specify options for handling callbacks during text generation. + *schema.CallbackOptions `map:"-"` + // Tokenizer represents the tokenizer to be used with the LLM model. + schema.Tokenizer `map:"-"` + // ModelName is the name of the GoogleGenAI model to use. + ModelName string `map:"model_name,omitempty"` + // CandidateCount is the number of candidate generations to consider. + CandidateCount int32 `map:"candidate_count,omitempty"` + // MaxOutputTokens is the maximum number of tokens to generate in the output. + MaxOutputTokens int32 `map:"max_output_tokens,omitempty"` + // Temperature controls the randomness of the generation. Higher values make the output more random. + Temperature float32 `map:"temperature,omitempty"` + // TopP is the nucleus sampling parameter. It controls the cumulative probability of the most likely tokens to sample from. + TopP float32 `map:"top_p,omitempty"` + // TopK is the number of top tokens to consider for sampling. 
+ TopK int32 `map:"top_k,omitempty"` +} + +type GoogleGenAI struct { + schema.Tokenizer + client GoogleGenAIClient + opts GoogleGenAIOptions +} + +func NewGoogleGenAI(client GoogleGenAIClient, optFns ...func(o *GoogleGenAIOptions)) (*GoogleGenAI, error) { + opts := GoogleGenAIOptions{ + CallbackOptions: &schema.CallbackOptions{ + Verbose: golc.Verbose, + }, + ModelName: "models/gemini-pro", + CandidateCount: 1, + MaxOutputTokens: 2048, + TopK: 3, + } + + for _, fn := range optFns { + fn(&opts) + } + + if !strings.HasPrefix(opts.ModelName, "models/") { + opts.ModelName = fmt.Sprintf("models/%s", opts.ModelName) + } + + if opts.Tokenizer == nil { + opts.Tokenizer = tokenizer.NewGoogleGenAI(client, opts.ModelName) + } + + return &GoogleGenAI{ + Tokenizer: opts.Tokenizer, + client: client, + opts: opts, + }, nil +} + +// Generate generates text based on the provided chat messages and options. +func (cm *GoogleGenAI) Generate(ctx context.Context, messages schema.ChatMessages, optFns ...func(o *schema.GenerateOptions)) (*schema.ModelResult, error) { + opts := schema.GenerateOptions{ + CallbackManger: &callback.NoopManager{}, + } + + for _, fn := range optFns { + fn(&opts) + } + + contents := []*generativelanguagepb.Content{} + + for _, message := range messages { + switch message.Type() { + case schema.ChatMessageTypeAI: + contents = append(contents, &generativelanguagepb.Content{Role: roleModel, Parts: []*generativelanguagepb.Part{{ + Data: &generativelanguagepb.Part_Text{Text: message.Content()}, + }}}) + case schema.ChatMessageTypeHuman: + contents = append(contents, &generativelanguagepb.Content{Role: roleUser, Parts: []*generativelanguagepb.Part{{ + Data: &generativelanguagepb.Part_Text{Text: message.Content()}, + }}}) + default: + return nil, fmt.Errorf("unsupported message type: %s", message.Type()) + } + } + + res, err := cm.client.GenerateContent(ctx, &generativelanguagepb.GenerateContentRequest{ + Model: cm.opts.ModelName, + Contents: contents, + 
GenerationConfig: &generativelanguagepb.GenerationConfig{ + CandidateCount: util.AddrOrNil(cm.opts.CandidateCount), + MaxOutputTokens: util.AddrOrNil(cm.opts.MaxOutputTokens), + Temperature: util.AddrOrNil(cm.opts.Temperature), + TopP: util.AddrOrNil(cm.opts.TopP), + TopK: util.AddrOrNil(cm.opts.TopK), + StopSequences: opts.Stop, + }, + }) + if err != nil { + return nil, err + } + + generations := make([]schema.Generation, len(res.Candidates)) + + for i, c := range res.Candidates { + var b strings.Builder + for _, p := range c.Content.Parts { + fmt.Fprintf(&b, "%v", p) + } + + generations[i] = newChatGeneraton(b.String()) + } + + return &schema.ModelResult{ + Generations: generations, + LLMOutput: map[string]any{ + "BlockReason": res.PromptFeedback.BlockReason.String(), + }, + }, nil +} + +// Type returns the type of the model. +func (cm *GoogleGenAI) Type() string { + return "chatmodel.GoogleGenAI" +} + +// Verbose returns the verbosity setting of the model. +func (cm *GoogleGenAI) Verbose() bool { + return cm.opts.Verbose +} + +// Callbacks returns the registered callbacks of the model. +func (cm *GoogleGenAI) Callbacks() []schema.Callback { + return cm.opts.Callbacks +} + +// InvocationParams returns the parameters used in the model invocation. 
+func (cm *GoogleGenAI) InvocationParams() map[string]any { + return util.StructToMap(cm.opts) +} diff --git a/model/chatmodel/ollama.go b/model/chatmodel/ollama.go index b358f2c..7b21e3a 100644 --- a/model/chatmodel/ollama.go +++ b/model/chatmodel/ollama.go @@ -112,7 +112,7 @@ func (cm *Ollama) Generate(ctx context.Context, messages schema.ChatMessages, op res, err := cm.client.GenerateChat(ctx, &ollama.ChatRequest{ Model: cm.opts.ModelName, Messages: ollamaMessages, - Stream: util.PTR(false), + Stream: util.AddrOrNil(false), Options: ollama.Options{ Temperature: cm.opts.Temperature, NumPredict: cm.opts.MaxTokens, diff --git a/model/chatmodel/palm.go b/model/chatmodel/palm.go deleted file mode 100644 index 16fda30..0000000 --- a/model/chatmodel/palm.go +++ /dev/null @@ -1,175 +0,0 @@ -package chatmodel - -import ( - "context" - "errors" - "fmt" - - generativelanguagepb "cloud.google.com/go/ai/generativelanguage/apiv1beta2/generativelanguagepb" - "github.com/googleapis/gax-go/v2" - "github.com/hupe1980/golc" - "github.com/hupe1980/golc/callback" - "github.com/hupe1980/golc/schema" - "github.com/hupe1980/golc/tokenizer" - "github.com/hupe1980/golc/util" -) - -// Compile time check to ensure Palm satisfies the ChatModel interface. -var _ schema.ChatModel = (*Palm)(nil) - -// PalmClient is the interface for the PALM client. -type PalmClient interface { - GenerateMessage(ctx context.Context, req *generativelanguagepb.GenerateMessageRequest, opts ...gax.CallOption) (*generativelanguagepb.GenerateMessageResponse, error) -} - -// PalmOptions is the options struct for the PALM chat model. -type PalmOptions struct { - *schema.CallbackOptions `map:"-"` - schema.Tokenizer `map:"-"` - - // ModelName is the name of the Palm chat model to use. - ModelName string `map:"model_name,omitempty"` - - // Temperature is the sampling temperature to use during text generation. 
- Temperature float32 `map:"temperature,omitempty"` - - // TopP is the total probability mass of tokens to consider at each step. - TopP float32 `map:"top_p,omitempty"` - - // TopK determines how the model selects tokens for output. - TopK int32 `map:"top_k"` - - // CandidateCount specifies the number of candidates to generate during text completion. - CandidateCount int32 `map:"candidate_count"` -} - -// Palm is a struct representing the PALM language model. -type Palm struct { - schema.Tokenizer - client PalmClient - opts PalmOptions -} - -// NewPalm creates a new instance of the PALM language model. -func NewPalm(client PalmClient, optFns ...func(o *PalmOptions)) (*Palm, error) { - opts := PalmOptions{ - CallbackOptions: &schema.CallbackOptions{ - Verbose: golc.Verbose, - }, - ModelName: "models/chat-bison-001", - Temperature: 0.7, - CandidateCount: 1, - } - - for _, fn := range optFns { - fn(&opts) - } - - if opts.Tokenizer == nil { - var tErr error - - opts.Tokenizer, tErr = tokenizer.NewGPT2() - if tErr != nil { - return nil, tErr - } - } - - return &Palm{ - Tokenizer: opts.Tokenizer, - client: client, - opts: opts, - }, nil -} - -// Generate generates text based on the provided chat messages and options. 
-func (cm *Palm) Generate(ctx context.Context, messages schema.ChatMessages, optFns ...func(o *schema.GenerateOptions)) (*schema.ModelResult, error) { - opts := schema.GenerateOptions{ - CallbackManger: &callback.NoopManager{}, - } - - for _, fn := range optFns { - fn(&opts) - } - - prompt := &generativelanguagepb.MessagePrompt{} - - for i, message := range messages { - switch message.Type() { - case schema.ChatMessageTypeSystem: - if i != 0 { - return nil, errors.New("system message must be first input message") - } - - prompt.Context = message.Content() - case schema.ChatMessageTypeAI: - prompt.Messages = append(prompt.Messages, &generativelanguagepb.Message{ - Author: "ai", - Content: message.Content(), - }) - case schema.ChatMessageTypeHuman: - prompt.Messages = append(prompt.Messages, &generativelanguagepb.Message{ - Author: "human", - Content: message.Content(), - }) - default: - return nil, fmt.Errorf("unsupported message type: %s", message.Type()) - } - } - - res, err := cm.client.GenerateMessage(ctx, &generativelanguagepb.GenerateMessageRequest{ - Prompt: prompt, - Model: cm.opts.ModelName, - Temperature: &cm.opts.Temperature, - TopP: &cm.opts.TopP, - TopK: &cm.opts.TopK, - CandidateCount: &cm.opts.CandidateCount, - }) - if err != nil { - return nil, err - } - - generations := util.Map(res.GetCandidates(), func(m *generativelanguagepb.Message, _ int) schema.Generation { - switch m.GetAuthor() { - case "ai": - return schema.Generation{ - Message: schema.NewAIChatMessage(m.GetContent()), - Text: m.GetContent(), - } - case "human": - return schema.Generation{ - Message: schema.NewHumanChatMessage(m.GetContent()), - Text: m.GetContent(), - } - default: - return schema.Generation{ - Message: schema.NewGenericChatMessage(m.GetContent(), m.GetAuthor()), - Text: m.GetContent(), - } - } - }) - - return &schema.ModelResult{ - Generations: generations, - LLMOutput: map[string]any{}, - }, nil -} - -// Type returns the type of the model. 
-func (cm *Palm) Type() string { - return "chatmodel.Palm" -} - -// Verbose returns the verbosity setting of the model. -func (cm *Palm) Verbose() bool { - return cm.opts.Verbose -} - -// Callbacks returns the registered callbacks of the model. -func (cm *Palm) Callbacks() []schema.Callback { - return cm.opts.Callbacks -} - -// InvocationParams returns the parameters used in the model invocation. -func (cm *Palm) InvocationParams() map[string]any { - return util.StructToMap(cm.opts) -} diff --git a/model/chatmodel/palm_test.go b/model/chatmodel/palm_test.go deleted file mode 100644 index 5be752f..0000000 --- a/model/chatmodel/palm_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package chatmodel - -import ( - "context" - "testing" - - generativelanguagepb "cloud.google.com/go/ai/generativelanguage/apiv1beta2/generativelanguagepb" - "github.com/googleapis/gax-go/v2" - "github.com/hupe1980/golc/schema" - "github.com/stretchr/testify/assert" -) - -func TestPalm(t *testing.T) { - // Create a mock PalmClient - mockClient := &mockPalmClient{} - - // Create a Palm instance with the mock client - palm, err := NewPalm(mockClient) - assert.NoError(t, err) - - // Run the test case - t.Run("SuccessfulGeneration", func(t *testing.T) { - mockClient.GenerateResponse = &generativelanguagepb.GenerateMessageResponse{ - Candidates: []*generativelanguagepb.Message{{ - Author: "ai", - Content: "World", - }}, - } - - // Invoke the Generate method - result, err := palm.Generate(context.Background(), schema.ChatMessages{ - schema.NewHumanChatMessage("Hello"), - }) - - // Assert the result and error - assert.NoError(t, err) - assert.Equal(t, "World", result.Generations[0].Message.Content()) - }) - - t.Run("Type", func(t *testing.T) { - // Create a Palm instance - llm, err := NewPalm(&mockPalmClient{}) - assert.NoError(t, err) - - // Call the Type method - typ := llm.Type() - - // Assert the result - assert.Equal(t, "chatmodel.Palm", typ) - }) - - t.Run("Verbose", func(t *testing.T) { - // Create 
a Palm instance - llm, err := NewPalm(&mockPalmClient{}) - assert.NoError(t, err) - - // Call the Verbose method - verbose := llm.Verbose() - - // Assert the result - assert.False(t, verbose) - }) - - t.Run("Callbacks", func(t *testing.T) { - // Create a Palm instance - llm, err := NewPalm(&mockPalmClient{}) - assert.NoError(t, err) - - // Call the Callbacks method - callbacks := llm.Callbacks() - - // Assert the result - assert.Empty(t, callbacks) - }) - - t.Run("InvocationParams", func(t *testing.T) { - // Create a Palm instance - llm, err := NewPalm(&mockPalmClient{}, func(o *PalmOptions) { - o.Temperature = 0.7 - }) - assert.NoError(t, err) - - // Call the InvocationParams method - params := llm.InvocationParams() - - // Assert the result - assert.Equal(t, float32(0.7), params["temperature"]) - }) -} - -// mockPalmClient is a mock implementation of the PalmClient interface for testing. -type mockPalmClient struct { - GenerateResponse *generativelanguagepb.GenerateMessageResponse - GenerateError error -} - -// GenerateMessage is a mock implementation of the GenerateMessage method. -func (m *mockPalmClient) GenerateMessage(ctx context.Context, req *generativelanguagepb.GenerateMessageRequest, opts ...gax.CallOption) (*generativelanguagepb.GenerateMessageResponse, error) { - return m.GenerateResponse, m.GenerateError -} diff --git a/model/llm/google_genai.go b/model/llm/google_genai.go index 2b2566c..7b5a74b 100644 --- a/model/llm/google_genai.go +++ b/model/llm/google_genai.go @@ -5,13 +5,13 @@ import ( "fmt" "strings" + "cloud.google.com/go/ai/generativelanguage/apiv1/generativelanguagepb" + "github.com/googleapis/gax-go/v2" "github.com/hupe1980/golc" "github.com/hupe1980/golc/callback" "github.com/hupe1980/golc/schema" "github.com/hupe1980/golc/tokenizer" "github.com/hupe1980/golc/util" - - "github.com/google/generative-ai-go/genai" ) // Compile time check to ensure GoogleGenAI satisfies the LLM interface. 
@@ -19,7 +19,8 @@ var _ schema.LLM = (*GoogleGenAI)(nil) // GoogleGenAIClient is an interface for the GoogleGenAI model client. type GoogleGenAIClient interface { - GenerativeModel(name string) *genai.GenerativeModel + GenerateContent(context.Context, *generativelanguagepb.GenerateContentRequest, ...gax.CallOption) (*generativelanguagepb.GenerateContentResponse, error) + CountTokens(context.Context, *generativelanguagepb.CountTokensRequest, ...gax.CallOption) (*generativelanguagepb.CountTokensResponse, error) } // GoogleGenAIOptions contains options for the GoogleGenAI Language Model. @@ -45,8 +46,8 @@ type GoogleGenAIOptions struct { // GoogleGenAI represents the GoogleGenAI Language Model. type GoogleGenAI struct { schema.Tokenizer - model *genai.GenerativeModel - opts GoogleGenAIOptions + client GoogleGenAIClient + opts GoogleGenAIOptions } // NewGoogleGenAI creates a new instance of the GoogleGenAI Language Model. @@ -55,7 +56,7 @@ func NewGoogleGenAI(client GoogleGenAIClient, optFns ...func(o *GoogleGenAIOptio CallbackOptions: &schema.CallbackOptions{ Verbose: golc.Verbose, }, - ModelName: "gemini-pro", + ModelName: "models/gemini-pro", CandidateCount: 1, MaxOutputTokens: 2048, TopK: 3, @@ -65,15 +66,17 @@ func NewGoogleGenAI(client GoogleGenAIClient, optFns ...func(o *GoogleGenAIOptio fn(&opts) } - model := client.GenerativeModel(opts.ModelName) + if !strings.HasPrefix(opts.ModelName, "models/") { + opts.ModelName = fmt.Sprintf("models/%s", opts.ModelName) + } if opts.Tokenizer == nil { - opts.Tokenizer = tokenizer.NewGoogleGenAITokenizer(model) + opts.Tokenizer = tokenizer.NewGoogleGenAI(client, opts.ModelName) } return &GoogleGenAI{ Tokenizer: opts.Tokenizer, - model: model, + client: client, opts: opts, }, nil } @@ -88,16 +91,20 @@ func (l *GoogleGenAI) Generate(ctx context.Context, prompt string, optFns ...fun fn(&opts) } - l.model.GenerationConfig = genai.GenerationConfig{ - CandidateCount: l.opts.CandidateCount, - MaxOutputTokens: 
l.opts.MaxOutputTokens, - Temperature: l.opts.Temperature, - TopP: l.opts.TopP, - TopK: l.opts.TopK, - StopSequences: opts.Stop, - } - - res, err := l.model.GenerateContent(ctx, genai.Text(prompt)) + res, err := l.client.GenerateContent(ctx, &generativelanguagepb.GenerateContentRequest{ + Model: l.opts.ModelName, + Contents: []*generativelanguagepb.Content{{Parts: []*generativelanguagepb.Part{{ + Data: &generativelanguagepb.Part_Text{Text: prompt}, + }}}}, + GenerationConfig: &generativelanguagepb.GenerationConfig{ + CandidateCount: util.AddrOrNil(l.opts.CandidateCount), + MaxOutputTokens: util.AddrOrNil(l.opts.MaxOutputTokens), + Temperature: util.AddrOrNil(l.opts.Temperature), + TopP: util.AddrOrNil(l.opts.TopP), + TopK: util.AddrOrNil(l.opts.TopK), + StopSequences: opts.Stop, + }, + }) if err != nil { return nil, err } diff --git a/model/llm/ollama.go b/model/llm/ollama.go index f487fcd..e8a7492 100644 --- a/model/llm/ollama.go +++ b/model/llm/ollama.go @@ -96,7 +96,7 @@ func (l *Ollama) Generate(ctx context.Context, prompt string, optFns ...func(o * res, err := l.client.Generate(ctx, &ollama.GenerateRequest{ Model: l.opts.ModelName, Prompt: prompt, - Stream: util.PTR(false), + Stream: new(bool), // explicit &false: AddrOrNil(false) would yield nil and re-enable streaming Options: ollama.Options{ Temperature: l.opts.Temperature, NumPredict: l.opts.MaxTokens, diff --git a/model/llm/palm.go b/model/llm/palm.go deleted file mode 100644 index 4f4c854..0000000 --- a/model/llm/palm.go +++ /dev/null @@ -1,136 +0,0 @@ -package llm - -import ( - "context" - - generativelanguagepb "cloud.google.com/go/ai/generativelanguage/apiv1beta2/generativelanguagepb" - "github.com/googleapis/gax-go/v2" - "github.com/hupe1980/golc" - "github.com/hupe1980/golc/callback" - "github.com/hupe1980/golc/schema" - "github.com/hupe1980/golc/tokenizer" - "github.com/hupe1980/golc/util" -) - -// PalmClient is the interface for the PALM client. 
-type PalmClient interface { - GenerateText(ctx context.Context, req *generativelanguagepb.GenerateTextRequest, opts ...gax.CallOption) (*generativelanguagepb.GenerateTextResponse, error) -} - -// PalmOptions is the options struct for the PALM language model. -type PalmOptions struct { - *schema.CallbackOptions `map:"-"` - schema.Tokenizer `map:"-"` - - // ModelName is the name of the Palm language model to use. - ModelName string `map:"model_name,omitempty"` - - // Temperature is the sampling temperature to use during text generation. - Temperature float32 `map:"temperature,omitempty"` - - // TopP is the total probability mass of tokens to consider at each step. - TopP float32 `map:"top_p,omitempty"` - - // TopK determines how the model selects tokens for output. - TopK int32 `map:"top_k"` - - // MaxOutputTokens specifies the maximum number of output tokens for text generation. - MaxOutputTokens int32 `map:"max_output_tokens"` - - // CandidateCount specifies the number of candidates to generate during text completion. - CandidateCount int32 `map:"candidate_count"` -} - -// Palm is a struct representing the PALM language model. -type Palm struct { - client PalmClient - opts PalmOptions -} - -// NewPalm creates a new instance of the PALM language model. -func NewPalm(client PalmClient, optFns ...func(o *PalmOptions)) (*Palm, error) { - opts := PalmOptions{ - CallbackOptions: &schema.CallbackOptions{ - Verbose: golc.Verbose, - }, - ModelName: "models/text-bison-001", - Temperature: 0.7, - CandidateCount: 1, - } - - for _, fn := range optFns { - fn(&opts) - } - - if opts.Tokenizer == nil { - var tErr error - - opts.Tokenizer, tErr = tokenizer.NewGPT2() - if tErr != nil { - return nil, tErr - } - } - - return &Palm{ - client: client, - opts: opts, - }, nil -} - -// Generate generates text based on the provided prompt and options. 
-func (l *Palm) Generate(ctx context.Context, prompt string, optFns ...func(o *schema.GenerateOptions)) (*schema.ModelResult, error) { - opts := schema.GenerateOptions{ - CallbackManger: &callback.NoopManager{}, - } - - for _, fn := range optFns { - fn(&opts) - } - - res, err := l.client.GenerateText(ctx, &generativelanguagepb.GenerateTextRequest{ - Prompt: &generativelanguagepb.TextPrompt{ - Text: prompt, - }, - Model: l.opts.ModelName, - Temperature: &l.opts.Temperature, - TopP: &l.opts.TopP, - TopK: &l.opts.TopK, - MaxOutputTokens: &l.opts.MaxOutputTokens, - CandidateCount: &l.opts.CandidateCount, - StopSequences: opts.Stop, - }) - if err != nil { - return nil, err - } - - generations := util.Map(res.GetCandidates(), func(p *generativelanguagepb.TextCompletion, _ int) schema.Generation { - return schema.Generation{ - Text: p.GetOutput(), - } - }) - - return &schema.ModelResult{ - Generations: generations, - LLMOutput: map[string]any{}, - }, nil -} - -// Type returns the type of the model. -func (l *Palm) Type() string { - return "llm.Palm" -} - -// Verbose returns the verbosity setting of the model. -func (l *Palm) Verbose() bool { - return l.opts.Verbose -} - -// Callbacks returns the registered callbacks of the model. -func (l *Palm) Callbacks() []schema.Callback { - return l.opts.Callbacks -} - -// InvocationParams returns the parameters used in the model invocation. 
-func (l *Palm) InvocationParams() map[string]any { - return util.StructToMap(l.opts) -} diff --git a/model/llm/palm_test.go b/model/llm/palm_test.go deleted file mode 100644 index 8b99a99..0000000 --- a/model/llm/palm_test.go +++ /dev/null @@ -1,98 +0,0 @@ -package llm - -import ( - "context" - "testing" - - generativelanguagepb "cloud.google.com/go/ai/generativelanguage/apiv1beta2/generativelanguagepb" - "github.com/googleapis/gax-go/v2" - "github.com/stretchr/testify/assert" -) - -func TestPalm(t *testing.T) { - // Create a mock PalmClient - mockClient := &mockPalmClient{} - - // Create a Palm instance with the mock client - palm, err := NewPalm(mockClient) - assert.NoError(t, err) - - // Run the test case - t.Run("SuccessfulGeneration", func(t *testing.T) { - mockClient.GenerateResponse = &generativelanguagepb.GenerateTextResponse{ - Candidates: []*generativelanguagepb.TextCompletion{{ - Output: "World", - }}, - } - - // Invoke the Generate method - result, err := palm.Generate(context.Background(), "Hello") - - // Assert the result and error - assert.NoError(t, err) - assert.Equal(t, "World", result.Generations[0].Text) - }) - - t.Run("Type", func(t *testing.T) { - // Create a Palm instance - llm, err := NewPalm(&mockPalmClient{}) - assert.NoError(t, err) - - // Call the Type method - typ := llm.Type() - - // Assert the result - assert.Equal(t, "llm.Palm", typ) - }) - - t.Run("Verbose", func(t *testing.T) { - // Create a Palm instance - llm, err := NewPalm(&mockPalmClient{}) - assert.NoError(t, err) - - // Call the Verbose method - verbose := llm.Verbose() - - // Assert the result - assert.False(t, verbose) - }) - - t.Run("Callbacks", func(t *testing.T) { - // Create a Palm instance - llm, err := NewPalm(&mockPalmClient{}) - assert.NoError(t, err) - - // Call the Callbacks method - callbacks := llm.Callbacks() - - // Assert the result - assert.Empty(t, callbacks) - }) - - t.Run("InvocationParams", func(t *testing.T) { - // Create a Palm instance - llm, err := 
NewPalm(&mockPalmClient{}, func(o *PalmOptions) { - o.Temperature = 0.7 - o.MaxOutputTokens = 4711 - }) - assert.NoError(t, err) - - // Call the InvocationParams method - params := llm.InvocationParams() - - // Assert the result - assert.Equal(t, float32(0.7), params["temperature"]) - assert.Equal(t, int32(4711), params["max_output_tokens"]) - }) -} - -// mockPalmClient is a mock implementation of the PalmClient interface for testing. -type mockPalmClient struct { - GenerateResponse *generativelanguagepb.GenerateTextResponse - GenerateError error -} - -// GenerateText is a mock implementation of the GenerateText method. -func (m *mockPalmClient) GenerateText(ctx context.Context, req *generativelanguagepb.GenerateTextRequest, opts ...gax.CallOption) (*generativelanguagepb.GenerateTextResponse, error) { - return m.GenerateResponse, m.GenerateError -} diff --git a/tokenizer/google_genai.go b/tokenizer/google_genai.go index e34ec89..cd8a51c 100644 --- a/tokenizer/google_genai.go +++ b/tokenizer/google_genai.go @@ -3,30 +3,39 @@ package tokenizer import ( "context" - "github.com/google/generative-ai-go/genai" + "cloud.google.com/go/ai/generativelanguage/apiv1/generativelanguagepb" + "github.com/googleapis/gax-go/v2" "github.com/hupe1980/golc/schema" ) // Compile time check to ensure GoogleGenAI satisfies the Tokenizer interface. var _ schema.Tokenizer = (*GoogleGenAI)(nil) -type GoogleGenAIModel interface { - CountTokens(ctx context.Context, parts ...genai.Part) (*genai.CountTokensResponse, error) +// GoogleGenAIClient is an interface for the GoogleGenAI model client. 
+type GoogleGenAIClient interface { + CountTokens(context.Context, *generativelanguagepb.CountTokensRequest, ...gax.CallOption) (*generativelanguagepb.CountTokensResponse, error) } type GoogleGenAI struct { - model GoogleGenAIModel + client GoogleGenAIClient + model string } -func NewGoogleGenAITokenizer(model GoogleGenAIModel) *GoogleGenAI { +func NewGoogleGenAI(client GoogleGenAIClient, model string) *GoogleGenAI { return &GoogleGenAI{ - model: model, + client: client, + model: model, } } // GetNumTokens returns the number of tokens in the provided text. func (t *GoogleGenAI) GetNumTokens(ctx context.Context, text string) (uint, error) { - res, err := t.model.CountTokens(ctx, genai.Text(text)) + res, err := t.client.CountTokens(ctx, &generativelanguagepb.CountTokensRequest{ + Model: t.model, + Contents: []*generativelanguagepb.Content{{Parts: []*generativelanguagepb.Part{{ + Data: &generativelanguagepb.Part_Text{Text: text}, + }}}}, + }) if err != nil { return 0, err } diff --git a/tokenizer/google_genai_test.go b/tokenizer/google_genai_test.go new file mode 100644 index 0000000..4d85c93 --- /dev/null +++ b/tokenizer/google_genai_test.go @@ -0,0 +1,69 @@ +package tokenizer + +import ( + "context" + "testing" + + "cloud.google.com/go/ai/generativelanguage/apiv1/generativelanguagepb" + "github.com/googleapis/gax-go/v2" + "github.com/hupe1980/golc/schema" + "github.com/stretchr/testify/require" +) + +func TestGoogleGenAI(t *testing.T) { + // Create a new instance of the GoogleGenAI model with the custom mock client. + client := &mockGoogleGenAIClient{} + + // Create an instance of the GoogleGenAI tokenizer. + GoogleGenAI := NewGoogleGenAI(client, "model") + + // Test GetNumTokens. + t.Run("GetNumTokens", func(t *testing.T) { + // Set the mock response for CountTokens. + client.respCountTokens = &generativelanguagepb.CountTokensResponse{ + TotalTokens: 6, + } + client.errCount = nil + + // Test case with a sample input. + text := "This is a sample text." 
+ numTokens, err := GoogleGenAI.GetNumTokens(context.TODO(), text) + require.NoError(t, err) + require.Equal(t, 6, int(numTokens)) + }) + + // Test GetNumTokensFromMessage. + t.Run("GetNumTokensFromMessage", func(t *testing.T) { + // Set the mock response for CountTokens. + client.respCountTokens = &generativelanguagepb.CountTokensResponse{ + TotalTokens: 27, + } + client.errCount = nil + + // Test case with sample chat messages. + messages := schema.ChatMessages{ + schema.NewSystemChatMessage("Welcome to the chat!"), + schema.NewHumanChatMessage("Hi, how are you?"), + schema.NewSystemChatMessage("I'm doing well, thank you!"), + } + + numTokens, err := GoogleGenAI.GetNumTokensFromMessage(context.TODO(), messages) + require.NoError(t, err) + require.Equal(t, 27, int(numTokens)) + }) +} + +// mockGoogleGenAIClient is a custom mock implementation of the GoogleGenAIClient interface. +type mockGoogleGenAIClient struct { + respCountTokens *generativelanguagepb.CountTokensResponse + errCount error +} + +// CountTokens mocks the CountTokens method of the GoogleGenAIClient interface. +func (m *mockGoogleGenAIClient) CountTokens(context.Context, *generativelanguagepb.CountTokensRequest, ...gax.CallOption) (*generativelanguagepb.CountTokensResponse, error) { + if m.errCount != nil { + return nil, m.errCount + } + + return m.respCountTokens, nil +} diff --git a/util/float.go b/util/float.go new file mode 100644 index 0000000..4ee096c --- /dev/null +++ b/util/float.go @@ -0,0 +1,23 @@ +package util + +// Float64ToFloat32 converts a slice of float64 values to a slice of float32 values. +// It creates a new slice and populates it with the corresponding float32 values. +func Float64ToFloat32(v []float64) []float32 { + v32 := make([]float32, len(v)) + for i, f := range v { + v32[i] = float32(f) + } + + return v32 +} + +// Float32ToFloat64 converts a slice of float32 values to a slice of float64 values. +// It creates a new slice and populates it with the corresponding float64 values. 
+func Float32ToFloat64(v []float32) []float64 { + v64 := make([]float64, len(v)) + for i, f := range v { + v64[i] = float64(f) + } + + return v64 +} diff --git a/util/float_test.go b/util/float_test.go new file mode 100644 index 0000000..74361cb --- /dev/null +++ b/util/float_test.go @@ -0,0 +1,41 @@ +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFloat64ToFloat32(t *testing.T) { + // Test case 1: Convert positive float64 values + input1 := []float64{1.5, 2.8, 3.2} + expected1 := []float32{1.5, 2.8, 3.2} + assert.Equal(t, expected1, Float64ToFloat32(input1)) + + // Test case 2: Convert negative float64 values + input2 := []float64{-1.5, -2.8, -3.2} + expected2 := []float32{-1.5, -2.8, -3.2} + assert.Equal(t, expected2, Float64ToFloat32(input2)) + + // Test case 3: Convert empty slice + input3 := []float64{} + expected3 := []float32{} + assert.Equal(t, expected3, Float64ToFloat32(input3)) +} + +func TestFloat32ToFloat64(t *testing.T) { + // Test case 1: Convert positive float32 values + input1 := []float32{1.5, 2.8, 3.2} + expected1 := []float64{1.5, 2.8, 3.2} + assert.InEpsilonSlice(t, expected1, Float32ToFloat64(input1), 1e-07) + + // Test case 2: Convert negative float32 values + input2 := []float32{-1.5, -2.8, -3.2} + expected2 := []float64{-1.5, -2.8, -3.2} + assert.InEpsilonSlice(t, expected2, Float32ToFloat64(input2), 1e-07) + + // Test case 3: Convert empty slice + input3 := []float32{} + expected3 := []float64{} + assert.Equal(t, expected3, Float32ToFloat64(input3)) +} diff --git a/util/ptr.go b/util/ptr.go new file mode 100644 index 0000000..2c7a7e0 --- /dev/null +++ b/util/ptr.go @@ -0,0 +1,12 @@ +package util + +// AddrOrNil returns nil if x is the zero value for T, +// or &x otherwise. 
+func AddrOrNil[T comparable](x T) *T { + var z T + if x == z { + return nil + } + + return &x +} diff --git a/util/ptr_test.go b/util/ptr_test.go new file mode 100644 index 0000000..235eae0 --- /dev/null +++ b/util/ptr_test.go @@ -0,0 +1,35 @@ +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAddrOrNil(t *testing.T) { + t.Run("ZeroValue", func(t *testing.T) { + var zeroInt int + result := AddrOrNil(zeroInt) + assert.Nil(t, result, "Expected nil for zero value") + }) + + t.Run("NonZeroValue", func(t *testing.T) { + nonZeroInt := 42 + result := AddrOrNil(nonZeroInt) + assert.NotNil(t, result, "Expected non-nil for non-zero value") + assert.Equal(t, nonZeroInt, *result, "Unexpected value for non-zero value") + }) + + t.Run("ZeroString", func(t *testing.T) { + var zeroString string + result := AddrOrNil(zeroString) + assert.Nil(t, result, "Expected nil for zero value") + }) + + t.Run("NonZeroString", func(t *testing.T) { + nonZeroString := "test" + result := AddrOrNil(nonZeroString) + assert.NotNil(t, result, "Expected non-nil for non-zero value") + assert.Equal(t, nonZeroString, *result, "Unexpected value for non-zero value") + }) +} diff --git a/util/util.go b/util/util.go index 48d6893..91b2254 100644 --- a/util/util.go +++ b/util/util.go @@ -1,7 +1,2 @@ // Package util provides utility functions and helpers. package util - -// PTR returns a pointer to the given value. -func PTR[T any](v T) *T { - return &v -} diff --git a/vectorstore/vectorstore.go b/vectorstore/vectorstore.go index 79d3953..e4fa436 100644 --- a/vectorstore/vectorstore.go +++ b/vectorstore/vectorstore.go @@ -10,12 +10,3 @@ import ( func ToRetriever(vectorStore schema.VectorStore, optFns ...func(o *retriever.VectorStoreOptions)) schema.Retriever { return retriever.NewVectorStore(vectorStore, optFns...) 
} - -func float64ToFloat32(v []float64) []float32 { - v32 := make([]float32, len(v)) - for i, f := range v { - v32[i] = float32(f) - } - - return v32 -} diff --git a/vectorstore/weaviate.go b/vectorstore/weaviate.go index 1dfd006..faa054b 100644 --- a/vectorstore/weaviate.go +++ b/vectorstore/weaviate.go @@ -8,6 +8,7 @@ import ( "github.com/go-openapi/strfmt" "github.com/google/uuid" "github.com/hupe1980/golc/schema" + "github.com/hupe1980/golc/util" "github.com/weaviate/weaviate-go-client/v4/weaviate" "github.com/weaviate/weaviate-go-client/v4/weaviate/graphql" "github.com/weaviate/weaviate/entities/models" @@ -106,7 +107,7 @@ func (vs *Weaviate) AddDocuments(ctx context.Context, docs []schema.Document) er objects = append(objects, &models.Object{ Class: vs.opts.IndexName, ID: strfmt.UUID(uuid.New().String()), - Vector: float64ToFloat32(vectors[i]), + Vector: util.Float64ToFloat32(vectors[i]), Properties: metadata, }) } @@ -125,7 +126,7 @@ func (vs *Weaviate) SimilaritySearch(ctx context.Context, query string) ([]schem return nil, err } - nearVector := vs.client.GraphQL().NearVectorArgBuilder().WithVector(float64ToFloat32(vector)) + nearVector := vs.client.GraphQL().NearVectorArgBuilder().WithVector(util.Float64ToFloat32(vector)) fields := []graphql.Field{ {Name: vs.opts.TextKey},