Skip to content

Commit

Permalink
Fix cohere embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Dec 19, 2023
1 parent 6a9d65a commit 58a4015
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 11 deletions.
13 changes: 7 additions & 6 deletions embedding/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func NewCohere(apiKey string, optFns ...func(o *CohereOptions)) (*Cohere, error)
// It returns the initialized Cohere instance.
func NewCohereFromClient(client CohereClient, optFns ...func(o *CohereOptions)) (*Cohere, error) {
opts := CohereOptions{
Model: "embed-english-v3.0",
Model: "embed-english-v2.0",
MaxRetries: 3,
Truncate: "NONE",
}
Expand All @@ -71,16 +71,17 @@ func (e *Cohere) BatchEmbedText(ctx context.Context, texts []string) ([][]float3
}

res, err := e.embedWithRetry(ctx, &cohere.EmbedRequest{
Model: util.AddrOrNil(e.opts.Model),
Truncate: truncate.Ptr(),
Texts: texts,
Model: util.AddrOrNil(e.opts.Model),
Truncate: truncate.Ptr(),
Texts: texts,
EmbeddingTypes: []string{"float"},
})
if err != nil {
return nil, err
}

embeddings := make([][]float32, len(res.Embeddings))
for i, r := range res.Embeddings {
embeddings := make([][]float32, len(res.EmbeddingsByType.Embeddings.Float))
for i, r := range res.EmbeddingsByType.Embeddings.Float {
embeddings[i] = util.Float64ToFloat32(r)
}

Expand Down
18 changes: 13 additions & 5 deletions embedding/cohere_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@ func TestCohere(t *testing.T) {
// Create a new instance of the Cohere model with a mock client.
client := &mockCohereClient{
response: &cohere.EmbedResponse{
Embeddings: [][]float64{
{1.0, 2.0, 3.0},
{4.0, 5.0, 6.0},
EmbeddingsByType: &cohere.EmbedByTypeResponse{
Embeddings: &cohere.EmbedByTypeResponseEmbeddings{
Float: [][]float64{
{1.0, 2.0, 3.0},
{4.0, 5.0, 6.0},
},
},
},
},
}
Expand All @@ -43,8 +47,12 @@ func TestCohere(t *testing.T) {
// Create a new instance of the Cohere model with a mock client.
client := &mockCohereClient{
response: &cohere.EmbedResponse{
Embeddings: [][]float64{
{1.0, 2.0, 3.0},
EmbeddingsByType: &cohere.EmbedByTypeResponse{
Embeddings: &cohere.EmbedByTypeResponseEmbeddings{
Float: [][]float64{
{1.0, 2.0, 3.0},
},
},
},
},
}
Expand Down
24 changes: 24 additions & 0 deletions examples/cohere_embedding/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package main

import (
"context"
"fmt"
"log"
"os"

"github.com/hupe1980/golc/embedding"
)

func main() {
embedder, err := embedding.NewCohere(os.Getenv("COHERE_API_KEY"))
if err != nil {
log.Fatal(err)
}

e, err := embedder.EmbedText(context.Background(), "Hello llama2!")
if err != nil {
log.Fatal(err)
}

fmt.Println(e)
}

0 comments on commit 58a4015

Please sign in to comment.