From 58a401501e5115d190ee701c59b4e2f58d029f57 Mon Sep 17 00:00:00 2001 From: hupe1980 Date: Tue, 19 Dec 2023 13:48:01 +0100 Subject: [PATCH] Fix cohere embedding --- embedding/cohere.go | 13 +++++++------ embedding/cohere_test.go | 18 +++++++++++++----- examples/cohere_embedding/main.go | 24 ++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 11 deletions(-) create mode 100644 examples/cohere_embedding/main.go diff --git a/embedding/cohere.go b/embedding/cohere.go index b098965..9f909e6 100644 --- a/embedding/cohere.go +++ b/embedding/cohere.go @@ -48,7 +48,7 @@ func NewCohere(apiKey string, optFns ...func(o *CohereOptions)) (*Cohere, error) // It returns the initialized Cohere instance. func NewCohereFromClient(client CohereClient, optFns ...func(o *CohereOptions)) (*Cohere, error) { opts := CohereOptions{ - Model: "embed-english-v3.0", + Model: "embed-english-v2.0", MaxRetries: 3, Truncate: "NONE", } @@ -71,16 +71,17 @@ func (e *Cohere) BatchEmbedText(ctx context.Context, texts []string) ([][]float3 } res, err := e.embedWithRetry(ctx, &cohere.EmbedRequest{ - Model: util.AddrOrNil(e.opts.Model), - Truncate: truncate.Ptr(), - Texts: texts, + Model: util.AddrOrNil(e.opts.Model), + Truncate: truncate.Ptr(), + Texts: texts, + EmbeddingTypes: []string{"float"}, }) if err != nil { return nil, err } - embeddings := make([][]float32, len(res.Embeddings)) - for i, r := range res.Embeddings { + embeddings := make([][]float32, len(res.EmbeddingsByType.Embeddings.Float)) + for i, r := range res.EmbeddingsByType.Embeddings.Float { embeddings[i] = util.Float64ToFloat32(r) } diff --git a/embedding/cohere_test.go b/embedding/cohere_test.go index d7055d5..9abdfcc 100644 --- a/embedding/cohere_test.go +++ b/embedding/cohere_test.go @@ -15,9 +15,13 @@ func TestCohere(t *testing.T) { // Create a new instance of the Cohere model with a mock client. client := &mockCohereClient{ response: &cohere.EmbedResponse{ - Embeddings: [][]float64{ - {1.0, 2.0, 3.0}, - {4.0, 5.0, 6.0}, + EmbeddingsByType: &cohere.EmbedByTypeResponse{ + Embeddings: &cohere.EmbedByTypeResponseEmbeddings{ + Float: [][]float64{ + {1.0, 2.0, 3.0}, + {4.0, 5.0, 6.0}, + }, + }, }, }, } @@ -43,8 +47,12 @@ func TestCohere(t *testing.T) { // Create a new instance of the Cohere model with a mock client. client := &mockCohereClient{ response: &cohere.EmbedResponse{ - Embeddings: [][]float64{ - {1.0, 2.0, 3.0}, + EmbeddingsByType: &cohere.EmbedByTypeResponse{ + Embeddings: &cohere.EmbedByTypeResponseEmbeddings{ + Float: [][]float64{ + {1.0, 2.0, 3.0}, + }, + }, }, }, } diff --git a/examples/cohere_embedding/main.go b/examples/cohere_embedding/main.go new file mode 100644 index 0000000..d4a8dae --- /dev/null +++ b/examples/cohere_embedding/main.go @@ -0,0 +1,24 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + + "github.com/hupe1980/golc/embedding" +) + +func main() { + embedder, err := embedding.NewCohere(os.Getenv("COHERE_API_KEY")) + if err != nil { + log.Fatal(err) + } + + e, err := embedder.EmbedText(context.Background(), "Hello llama2!") + if err != nil { + log.Fatal(err) + } + + fmt.Println(e) +}