Skip to content

Commit

Permalink
Misc
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jun 16, 2023
1 parent 37354e7 commit 82dd290
Show file tree
Hide file tree
Showing 20 changed files with 1,205 additions and 91 deletions.
8 changes: 4 additions & 4 deletions chatmodel/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ var _ golc.LLM = (*OpenAI)(nil)

type OpenAIOptions struct {
// Model name to use.
Model string
ModelName string
// Sampling temperature to use.
Temperatur float32
// The maximum number of tokens to generate in the completion.
Expand All @@ -42,7 +42,7 @@ type OpenAI struct {

func NewOpenAI(apiKey string) (*OpenAI, error) {
opts := OpenAIOptions{
Model: "gpt-3.5-turbo",
ModelName: "gpt-3.5-turbo",
Temperatur: 1,
TopP: 1,
PresencePenalty: 0,
Expand Down Expand Up @@ -75,7 +75,7 @@ func (o *OpenAI) generate(ctx context.Context, messages []golc.ChatMessage) (*go
}

res, err := o.client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
Model: o.opts.Model,
Model: o.opts.ModelName,
Messages: openAIMessages,
})
if err != nil {
Expand Down Expand Up @@ -122,7 +122,7 @@ func (o *OpenAI) GetNumTokensFromMessage(messages []golc.ChatMessage) (int, erro
}

func (o *OpenAI) getEncodingForModel() (string, *tiktoken.Tiktoken, error) {
model := o.opts.Model
model := o.opts.ModelName
if model == "gpt-3.5-turbo" {
model = "gpt-3.5-turbo-0301"
} else if model == "gpt-4" {
Expand Down
1 change: 1 addition & 0 deletions embedding/cohere.go
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
package embedding
12 changes: 6 additions & 6 deletions embedding/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@ func NewFake(size int) *Fake {
return &Fake{Size: size}
}

func (f *Fake) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) {
func (e *Fake) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) {
embeddings := make([][]float64, len(texts))
for i := range texts {
embeddings[i] = f.getEmbedding()
embeddings[i] = e.getEmbedding()
}

return embeddings, nil
}

func (f *Fake) EmbedQuery(ctx context.Context, text string) ([]float64, error) {
return f.getEmbedding(), nil
func (e *Fake) EmbedQuery(ctx context.Context, text string) ([]float64, error) {
return e.getEmbedding(), nil
}

func (f *Fake) getEmbedding() []float64 {
embedding := make([]float64, f.Size)
func (e *Fake) getEmbedding() []float64 {
embedding := make([]float64, e.Size)
for i := range embedding {
embedding[i] = rand.NormFloat64()
}
Expand Down
213 changes: 213 additions & 0 deletions embedding/openai.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
package embedding

import (
"context"
"fmt"
"math"
"strings"

"github.com/hupe1980/golc/util"
"github.com/pkoukk/tiktoken-go"
"github.com/sashabaranov/go-openai"
)

var nameToOpenAIModel = map[string]openai.EmbeddingModel{
"text-similarity-ada-001": openai.AdaSimilarity,
"text-similarity-babbage-001": openai.BabbageSimilarity,
"text-similarity-curie-001": openai.CurieSimilarity,
"text-similarity-davinci-001": openai.DavinciSimilarity,
"text-search-ada-doc-001": openai.AdaSearchDocument,
"text-search-ada-query-001": openai.AdaSearchQuery,
"text-search-babbage-doc-001": openai.BabbageSearchDocument,
"text-search-babbage-query-001": openai.BabbageSearchQuery,
"text-search-curie-doc-001": openai.CurieSearchDocument,
"text-search-curie-query-001": openai.CurieSearchQuery,
"text-search-davinci-doc-001": openai.DavinciSearchDocument,
"text-search-davinci-query-001": openai.DavinciSearchQuery,
"code-search-ada-code-001": openai.AdaCodeSearchCode,
"code-search-ada-text-001": openai.AdaCodeSearchText,
"code-search-babbage-code-001": openai.BabbageCodeSearchCode,
"code-search-babbage-text-001": openai.BabbageCodeSearchText,
"text-embedding-ada-002": openai.AdaEmbeddingV2,
}

type OpenAIOptions struct {
// Model name to use.
ModelName string
EmbeddingContextLength int
// Maximum number of texts to embed in each batch
ChunkSize int
}

type OpenAI struct {
client *openai.Client
opts OpenAIOptions
}

func NewOpenAI(apiKey string, optFns ...func(o *OpenAIOptions)) (*OpenAI, error) {
opts := OpenAIOptions{
ModelName: "text-embedding-ada-002",
EmbeddingContextLength: 8191,
ChunkSize: 1000,
}

for _, fn := range optFns {
fn(&opts)
}

return &OpenAI{
client: openai.NewClient(apiKey),
opts: opts,
}, nil
}

func (e *OpenAI) EmbedDocuments(ctx context.Context, texts []string) ([][]float64, error) {
return e.getLenSafeEmbeddings(ctx, texts)
}

func (e *OpenAI) EmbedQuery(ctx context.Context, text string) ([]float64, error) {
if len(text) > e.opts.EmbeddingContextLength {
embeddings, err := e.getLenSafeEmbeddings(ctx, []string{text})
if err != nil {
return nil, err
}

return embeddings[0], nil
}

if strings.HasSuffix(e.opts.ModelName, "001") {
// See: https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
// replace newlines, which can negatively affect performance.
text = strings.ReplaceAll(text, "\n", " ")
}

res, err := e.client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
Model: nameToOpenAIModel[e.opts.ModelName],
Input: []string{text},
})
if err != nil {
return nil, err
}

return util.Map(res.Data[0].Embedding, func(e float32, i int) float64 {
return float64(e)
}), nil
}

func (e *OpenAI) getLenSafeEmbeddings(ctx context.Context, texts []string) ([][]float64, error) {
// please refer to
// https://github.com/openai/openai-cookbook/blob/main/examples/Embedding_long_inputs.ipynb
tokens := []string{}
indices := []int{}

encoding, err := tiktoken.EncodingForModel(e.opts.ModelName)
if err != nil {
return nil, err
}

for i, text := range texts {
if strings.HasSuffix(e.opts.ModelName, "001") {
// Replace newlines, which can negatively affect performance.
text = strings.ReplaceAll(text, "\n", " ")
}

token := encoding.Encode(text, nil, nil)

for j := 0; j < len(token); j += e.opts.EmbeddingContextLength {
limit := j + e.opts.EmbeddingContextLength
if limit > len(token) {
limit = len(token)
}

tokens = append(tokens, util.Map(token[j:limit], func(e int, _ int) string {
return fmt.Sprintf("%d", e)
})...)

indices = append(indices, i)
}
}

batchedEmbeddings := [][]float64{}

for i := 0; i < len(tokens); i += e.opts.ChunkSize {
limit := i + e.opts.ChunkSize
if limit > len(tokens) {
limit = len(tokens)
}

res, err := e.client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
Model: nameToOpenAIModel[e.opts.ModelName],
Input: tokens[i:limit],
})
if err != nil {
return nil, err
}

for _, d := range res.Data {
batchedEmbeddings = append(batchedEmbeddings, util.Map(d.Embedding, func(e float32, _ int) float64 {
return float64(e)
}))
}
}

results := make([][][]float64, len(texts))
numTokensInBatch := make([][]int, len(texts))

for i := 0; i < len(indices); i++ {
index := indices[i]
results[index] = append(results[index], batchedEmbeddings[i])
numTokensInBatch[index] = append(numTokensInBatch[index], len(tokens[i]))
}

embeddings := make([][]float64, len(texts))

for i := 0; i < len(texts); i++ {
var average []float64

result := results[i]

if len(result) == 0 {
res, err := e.client.CreateEmbeddings(ctx, openai.EmbeddingRequest{
Model: nameToOpenAIModel[e.opts.ModelName],
Input: []string{""},
})
if err != nil {
return nil, err
}

average = util.Map(res.Data[0].Embedding, func(e float32, i int) float64 {
return float64(e)
})
} else {
sum := make([]float64, len(result[0]))

weights := numTokensInBatch[i]

for j := 0; j < len(result); j++ {
embedding := result[j]
for k := 0; k < len(embedding); k++ {
sum[k] += embedding[k] * float64(weights[j])
}
}

average = make([]float64, len(sum))
for j := 0; j < len(sum); j++ {
average[j] = sum[j] / float64(util.SumInt(weights))
}
}

norm := 0.0
for _, value := range average {
norm += value * value
}

norm = math.Sqrt(norm)
for j := 0; j < len(average); j++ {
average[j] /= norm
}

embeddings[i] = average
}

return embeddings, nil
}
12 changes: 9 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,31 @@ go 1.20

require (
github.com/Masterminds/sprig/v3 v3.2.3
github.com/aws/aws-sdk-go-v2 v1.18.0
github.com/aws/aws-sdk-go-v2 v1.18.1
github.com/aws/aws-sdk-go-v2/service/sagemakerruntime v1.19.6
github.com/sashabaranov/go-openai v1.10.1
github.com/stretchr/testify v1.8.2
)

require (
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.33 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.27 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.34 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.28 // indirect
github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.14.14 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.7.28 // indirect
github.com/aws/smithy-go v1.13.5 // indirect
github.com/cohere-ai/tokenizer v1.1.1 // indirect
github.com/dlclark/regexp2 v1.8.1 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

require (
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Masterminds/semver/v3 v3.2.0 // indirect
github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.10.28
github.com/aws/aws-sdk-go-v2/service/dynamodb v1.19.10
github.com/aws/aws-sdk-go-v2/service/kendra v1.40.2
github.com/cohere-ai/cohere-go v1.2.2
github.com/davecgh/go-spew v1.1.1 // indirect
Expand Down
21 changes: 18 additions & 3 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,25 @@ github.com/Masterminds/semver/v3 v3.2.0 h1:3MEsd0SM6jqZojhjLWWeBY+Kcjy9i6MQAeY7Y
github.com/Masterminds/semver/v3 v3.2.0/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj9n6YA=
github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM=
github.com/aws/aws-sdk-go-v2 v1.18.0 h1:882kkTpSFhdgYRKVZ/VCgf7sd0ru57p2JCxz4/oN5RY=
github.com/aws/aws-sdk-go-v2 v1.18.0/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.33 h1:kG5eQilShqmJbv11XL1VpyDbaEJzWxd4zRiCG30GSn4=
github.com/aws/aws-sdk-go-v2 v1.18.1 h1:+tefE750oAb7ZQGzla6bLkOwfcQCEtC5y2RqoqCeqKo=
github.com/aws/aws-sdk-go-v2 v1.18.1/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw=
github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.10.28 h1:0v/4ueonxdvfGwDIZf/85C6sl5TWWVY3oL3W686f52c=
github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.10.28/go.mod h1:xO5xY7M+f11S4/LDWYlJfO9ljCQNzjlLtsolMzL3fsw=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.33/go.mod h1:7i0PF1ME/2eUPFcjkVIwq+DOygHEoK92t5cDqNgYbIw=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.27 h1:vFQlirhuM8lLlpI7imKOMsjdQLuN9CPi+k44F/OFVsk=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.34 h1:A5UqQEmPaCFpedKouS4v+dHCTUo2sKqhoKO9U5kxyWo=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.34/go.mod h1:wZpTEecJe0Btj3IYnDx/VlUzor9wm3fJHyvLpQF0VwY=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.27/go.mod h1:UrHnn3QV/d0pBZ6QBAEQcqFLf8FAzLmoUfPVIueOvoM=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.28 h1:srIVS45eQuewqz6fKKu6ZGXaq6FuFg5NzgQBAM6g8Y4=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.28/go.mod h1:7VRpKQQedkfIEXb4k52I7swUnZP0wohVajJMRn3vsUw=
github.com/aws/aws-sdk-go-v2/service/dynamodb v1.19.10 h1:7hcsca97GMqYPd8BrhZckWY/ljAhPli6L2MY2MZ+eVQ=
github.com/aws/aws-sdk-go-v2/service/dynamodb v1.19.10/go.mod h1:W1oiFegjVosgjIwb2Vv45jiCQT1ee8x85u8EyZRYLes=
github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.14.14 h1:T9FMVvefm8TWwyVYpFVohP2iLM1QnqAB0m/qksVqs+w=
github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.14.14/go.mod h1:31kKOlv+a+XLCu0wDK8BeeCOjdcZihEoQcLiPIZoyw4=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11 h1:y2+VQzC6Zh2ojtV2LoC0MNwHWc6qXv/j2vrQtlftkdA=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11/go.mod h1:iV4q2hsqtNECrfmlXyord9u4zyuFEJX9eLgLpSPzWA8=
github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.7.28 h1:/D994rtMQd1jQ2OY+7tvUlMlrv1L1c7Xtma/FhkbVtY=
github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.7.28/go.mod h1:3bJI2pLY3ilrqO5EclusI1GbjFJh1iXYrhOItf2sjKw=
github.com/aws/aws-sdk-go-v2/service/kendra v1.40.2 h1:4oiWp0Y9BnBh0x7V4/h3u/qnagKgl5eofYi3bANQWbk=
github.com/aws/aws-sdk-go-v2/service/kendra v1.40.2/go.mod h1:00b/aokrZ0r4fUsMP9RSOL9bvxTCCRCOeUy5o0lyqrA=
github.com/aws/aws-sdk-go-v2/service/sagemakerruntime v1.19.6 h1:1pLDLpx4bTonQo/yYFfHdMUiFT5XWDoQr/GwtK+np5Q=
Expand All @@ -34,7 +47,9 @@ github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4
github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE=
github.com/imdario/mergo v0.3.11 h1:3tnifQM4i+fbajXKBHXWEH+KvNHqojZ778UH75j3bGA=
github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ=
github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw=
Expand Down
13 changes: 13 additions & 0 deletions golc.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,19 @@ type Memory interface {
Clear() error
}

type ChatMessageHistory interface {
// Messages returns the messages stored in the store.
Messages() ([]ChatMessage, error)
// Add a user message to the store.
AddUserMessage(text string) error
// Add an AI message to the store.
AddAIMessage(text string) error
// Add a self-created message to the store.
AddMessage(message ChatMessage) error
// Remove all messages from the store.
Clear() error
}

type PromptValue interface {
String() string
Messages() []ChatMessage
Expand Down
Loading

0 comments on commit 82dd290

Please sign in to comment.