Skip to content

Commit

Permalink
Add bedrock llm model
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Sep 30, 2023
1 parent fa647f2 commit fbd27a3
Showing 1 changed file with 230 additions and 0 deletions.
230 changes: 230 additions & 0 deletions model/llm/bedrock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
package llm

import (
"context"
"encoding/json"
"fmt"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/hupe1980/golc"
"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/schema"
"github.com/hupe1980/golc/tokenizer"
"github.com/hupe1980/golc/util"
)

// Compile time check to ensure Bedrock satisfies the LLM interface.
var _ schema.LLM = (*Bedrock)(nil)

// providerStopSequenceKeyMap is a mapping between language model (LLM) providers
// and the corresponding key names used for stop sequences. Stop sequences are sets
// of words that, when encountered in the generated text, signal the language model
// to stop generating further content. Different LLM providers might use different
// key names to specify these stop sequences in the input parameters.
var providerStopSequenceKeyMap = map[string]string{
"anthropic": "stop_sequences",
"amazon": "stopSequences",
"ai21": "stop_sequences",
}

// BedrockInputOutputAdapter is a helper struct for preparing input and handling output for Bedrock model.
type BedrockInputOutputAdapter struct {
provider string
}

// NewBedrockInputOutputAdpter creates a new instance of BedrockInputOutputAdpter.
func NewBedrockInputOutputAdapter(provider string) *BedrockInputOutputAdapter {
return &BedrockInputOutputAdapter{
provider: provider,
}
}

// PrepareInput prepares the input for the Bedrock model based on the specified provider.
func (bioa *BedrockInputOutputAdapter) PrepareInput(prompt string, modelParams map[string]any) ([]byte, error) {
var body map[string]any

switch bioa.provider {
case "ai21":
body = modelParams
body["prompt"] = prompt
case "amazon":
body = make(map[string]any)
body["inputText"] = prompt
body["textGenerationConfig"] = modelParams
default:
return nil, fmt.Errorf("unsupported provider: %s", bioa.provider)
}

return json.Marshal(body)
}

// ai21Output represents the structure of the output from the AI21 language model.
// It is used for unmarshaling JSON responses from the language model's API.
type ai21Output struct {
Completions []struct {
Data struct {
Text string `json:"text"`
} `json:"data"`
} `json:"completions"`
}

// amazonOutput represents the structure of the output from the Amazon language model.
// It is used for unmarshaling JSON responses from the language model's API.
type amazonOutput struct {
Results []struct {
OutputText string `json:"outputText"`
} `json:"results"`
}

// PrepareOutput prepares the output for the Bedrock model based on the specified provider.
func (bioa *BedrockInputOutputAdapter) PrepareOutput(response *bedrockruntime.InvokeModelOutput) (string, error) {
switch bioa.provider {
case "ai21":
output := &ai21Output{}
if err := json.Unmarshal(response.Body, output); err != nil {
return "", err
}

return output.Completions[0].Data.Text, nil
case "amazon":
output := &amazonOutput{}
if err := json.Unmarshal(response.Body, output); err != nil {
return "", err
}

return output.Results[0].OutputText, nil
}

return "", fmt.Errorf("unsupported provider: %s", bioa.provider)
}

// BedrockRuntimeClient is an interface for the Bedrock model runtime client.
type BedrockRuntimeClient interface {
InvokeModel(ctx context.Context, params *bedrockruntime.InvokeModelInput, optFns ...func(*bedrockruntime.Options)) (*bedrockruntime.InvokeModelOutput, error)
}

// BedrockOptions contains options for configuring the Bedrock LLM model.
type BedrockOptions struct {
*schema.CallbackOptions `map:"-"`
schema.Tokenizer `map:"-"`

// Model id to use.
ModelID string `map:"model_id,omitempty"`

// Model params to use.
ModelParams map[string]any `map:"model_params,omitempty"`
}

// Bedrock is a Bedrock LLM model that generates text based on a provided response function.
type Bedrock struct {
schema.Tokenizer
client BedrockRuntimeClient
opts BedrockOptions
}

// NewBedrock creates a new instance of the Bedrock LLM model with the provided response function and options.
func NewBedrock(client BedrockRuntimeClient, optFns ...func(o *BedrockOptions)) (*Bedrock, error) {
opts := BedrockOptions{
CallbackOptions: &schema.CallbackOptions{
Verbose: golc.Verbose,
},
ModelParams: make(map[string]any),
}

for _, fn := range optFns {
fn(&opts)
}

if opts.Tokenizer == nil {
var tErr error

opts.Tokenizer, tErr = tokenizer.NewGPT2()
if tErr != nil {
return nil, tErr
}
}

return &Bedrock{
Tokenizer: opts.Tokenizer,
client: client,
opts: opts,
}, nil
}

// Generate generates text based on the provided prompt and options.
func (l *Bedrock) Generate(ctx context.Context, prompt string, optFns ...func(o *schema.GenerateOptions)) (*schema.ModelResult, error) {
opts := schema.GenerateOptions{
CallbackManger: &callback.NoopManager{},
}

for _, fn := range optFns {
fn(&opts)
}

provider := l.getProvider()

params := util.CopyMap(l.opts.ModelParams)

if len(opts.Stop) > 0 {
key, ok := providerStopSequenceKeyMap[provider]
if !ok {
return nil, fmt.Errorf("stop sequence key name for provider %s is not supported", provider)
}

params[key] = opts.Stop
}

bioa := NewBedrockInputOutputAdapter(provider)

body, err := bioa.PrepareInput(prompt, params)
if err != nil {
return nil, err
}

res, err := l.client.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{
ModelId: aws.String(l.opts.ModelID),
Body: body,
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
})
if err != nil {
return nil, err
}

completion, err := bioa.PrepareOutput(res)
if err != nil {
return nil, err
}

return &schema.ModelResult{
Generations: []schema.Generation{{Text: completion}},
LLMOutput: map[string]any{},
}, nil
}

// Type returns the type of the model.
func (l *Bedrock) Type() string {
return "llm.Bedrock"
}

// Verbose returns the verbosity setting of the model.
func (l *Bedrock) Verbose() bool {
return l.opts.Verbose
}

// Callbacks returns the registered callbacks of the model.
func (l *Bedrock) Callbacks() []schema.Callback {
return l.opts.Callbacks
}

// InvocationParams returns the parameters used in the model invocation.
func (l *Bedrock) InvocationParams() map[string]any {
return util.StructToMap(l.opts)
}

// getProvider returns the provider of the model based on the model ID.
func (l *Bedrock) getProvider() string {
return strings.Split(l.opts.ModelID, ".")[0]
}

0 comments on commit fbd27a3

Please sign in to comment.