Commit: Add better bedrock ai21 llm support
hupe1980 committed Nov 8, 2023
1 parent 14ad052 commit 53a596b
Showing 1 changed file with 69 additions and 0 deletions.
model/llm/bedrock.go — 69 additions, 0 deletions
@@ -10,6 +10,7 @@ import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/hupe1980/golc"
"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/integration/ai21"
"github.com/hupe1980/golc/schema"
"github.com/hupe1980/golc/tokenizer"
"github.com/hupe1980/golc/util"
@@ -125,6 +126,74 @@ type BedrockRuntimeClient interface {
	InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error)
}

type BedrockAI21Options struct {
	*schema.CallbackOptions `map:"-"`
	schema.Tokenizer        `map:"-"`

	// Model id to use.
	ModelID string `map:"model_id,omitempty"`

	// Temperature controls the randomness of text generation. Higher values make it more random.
	Temperature float64 `map:"temperature"`

	// TopP sets the nucleus sampling probability. Higher values result in more diverse text.
	TopP float64 `map:"topP"`

	// MaxTokens sets the maximum number of tokens in the generated text.
	MaxTokens int `map:"maxTokens"`

	// PresencePenalty specifies the penalty for repeating words in generated text.
	PresencePenalty ai21.Penalty `map:"presencePenalty"`

	// CountPenalty specifies the penalty for repeating tokens in generated text.
	CountPenalty ai21.Penalty `map:"countPenalty"`

	// FrequencyPenalty specifies the penalty for generating frequent words.
	FrequencyPenalty ai21.Penalty `map:"frequencyPenalty"`
}

func NewBedrockAI21(client BedrockRuntimeClient, optFns ...func(o *BedrockAI21Options)) (*Bedrock, error) {
	opts := BedrockAI21Options{
		CallbackOptions: &schema.CallbackOptions{
			Verbose: golc.Verbose,
		},
		ModelID:          "ai21.j2-ultra-v1",
		Temperature:      0.5,
		TopP:             0.5,
		MaxTokens:        200,
		PresencePenalty:  DefaultPenalty,
		CountPenalty:     DefaultPenalty,
		FrequencyPenalty: DefaultPenalty,
	}

	for _, fn := range optFns {
		fn(&opts)
	}

	if opts.Tokenizer == nil {
		var tErr error

		opts.Tokenizer, tErr = tokenizer.NewGPT2()
		if tErr != nil {
			return nil, tErr
		}
	}

	return NewBedrock(client, func(o *BedrockOptions) {
		o.CallbackOptions = opts.CallbackOptions
		o.Tokenizer = opts.Tokenizer
		o.ModelID = opts.ModelID
		o.ModelParams = map[string]any{
			"temperature":      opts.Temperature,
			"topP":             opts.TopP,
			"maxTokens":        opts.MaxTokens,
			"presencePenalty":  opts.PresencePenalty,
			"countPenalty":     opts.CountPenalty,
			"frequencyPenalty": opts.FrequencyPenalty,
		}
	})
}

type BedrockAnthropicOptions struct {
	*schema.CallbackOptions `map:"-"`
	schema.Tokenizer        `map:"-"`
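A minimal usage sketch for the new NewBedrockAI21 constructor (not part of this commit): it assumes credentials and region are resolved through the default AWS config chain and that the model/llm package is imported as llm; the option values are illustrative only, and anything left unset keeps the defaults shown in the diff.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/hupe1980/golc/model/llm"
)

func main() {
	ctx := context.Background()

	// Load region and credentials from the default AWS config chain
	// (environment variables, shared config files, IAM role, ...).
	cfg, err := config.LoadDefaultConfig(ctx)
	if err != nil {
		log.Fatal(err)
	}

	// *bedrockruntime.Client provides InvokeModel and therefore satisfies
	// the BedrockRuntimeClient interface expected by the constructor.
	client := bedrockruntime.NewFromConfig(cfg)

	// Override a couple of options for illustration; everything left unset
	// keeps the defaults from NewBedrockAI21 (ai21.j2-ultra-v1, temperature 0.5,
	// topP 0.5, maxTokens 200, default penalties).
	ai21LLM, err := llm.NewBedrockAI21(client, func(o *llm.BedrockAI21Options) {
		o.Temperature = 0.7
		o.MaxTokens = 256
	})
	if err != nil {
		log.Fatal(err)
	}

	// ai21LLM is a *llm.Bedrock and can be passed wherever golc expects an LLM,
	// e.g. to the generation helpers in the model package or into a chain.
	fmt.Printf("configured Bedrock AI21 model: %T\n", ai21LLM)
}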
