Commit 8309cff: Add bedrock mistral llm support

hupe1980 committed Mar 28, 2024
1 parent 5a473bc
Showing 6 changed files with 221 additions and 2 deletions.
32 changes: 32 additions & 0 deletions examples/models/bedrock_amazon_llm/main.go
@@ -0,0 +1,32 @@
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/hupe1980/golc/model"
	"github.com/hupe1980/golc/model/llm"
	"github.com/hupe1980/golc/prompt"
)

func main() {
	cfg, _ := config.LoadDefaultConfig(context.Background(), config.WithRegion("us-east-1"))
	client := bedrockruntime.NewFromConfig(cfg)

	bedrock, err := llm.NewBedrockAmazon(client, func(o *llm.BedrockAmazonOptions) {
		o.Temperature = 0.3
	})
	if err != nil {
		log.Fatal(err)
	}

	res, err := model.GeneratePrompt(context.Background(), bedrock, prompt.StringPromptValue("Hello ai!"))
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(res.Generations[0].Text)
}
31 changes: 31 additions & 0 deletions examples/models/bedrock_amazon_llm_streaming/main.go
@@ -0,0 +1,31 @@
package main

import (
	"context"
	"log"

	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/hupe1980/golc/callback"
	"github.com/hupe1980/golc/model"
	"github.com/hupe1980/golc/model/llm"
	"github.com/hupe1980/golc/prompt"
	"github.com/hupe1980/golc/schema"
)

func main() {
	cfg, _ := config.LoadDefaultConfig(context.Background(), config.WithRegion("us-east-1"))
	client := bedrockruntime.NewFromConfig(cfg)

	bedrock, err := llm.NewBedrockAmazon(client, func(o *llm.BedrockAmazonOptions) {
		o.Callbacks = []schema.Callback{callback.NewStreamWriterHandler()}
		o.Stream = true
	})
	if err != nil {
		log.Fatal(err)
	}

	if _, err := model.GeneratePrompt(context.Background(), bedrock, prompt.StringPromptValue("Write me a song about sparkling water.")); err != nil {
		log.Fatal(err)
	}
}
30 changes: 30 additions & 0 deletions examples/models/bedrock_mistral_llm/main.go
@@ -0,0 +1,30 @@
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/hupe1980/golc/model"
	"github.com/hupe1980/golc/model/llm"
	"github.com/hupe1980/golc/prompt"
)

func main() {
	cfg, _ := config.LoadDefaultConfig(context.Background(), config.WithRegion("us-east-1"))
	client := bedrockruntime.NewFromConfig(cfg)

	bedrock, err := llm.NewBedrockMistral(client)
	if err != nil {
		log.Fatal(err)
	}

	res, err := model.GeneratePrompt(context.Background(), bedrock, prompt.StringPromptValue("Tell me a joke"))
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(res.Generations[0].Text)
}
31 changes: 31 additions & 0 deletions examples/models/bedrock_mistral_llm_streaming/main.go
@@ -0,0 +1,31 @@
package main

import (
	"context"
	"log"

	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
	"github.com/hupe1980/golc/callback"
	"github.com/hupe1980/golc/model"
	"github.com/hupe1980/golc/model/llm"
	"github.com/hupe1980/golc/prompt"
	"github.com/hupe1980/golc/schema"
)

func main() {
	cfg, _ := config.LoadDefaultConfig(context.Background(), config.WithRegion("us-east-1"))
	client := bedrockruntime.NewFromConfig(cfg)

	bedrock, err := llm.NewBedrockMistral(client, func(o *llm.BedrockMistralOptions) {
		o.Callbacks = []schema.Callback{callback.NewStreamWriterHandler()}
		o.Stream = true
	})
	if err != nil {
		log.Fatal(err)
	}

	if _, err := model.GeneratePrompt(context.Background(), bedrock, prompt.StringPromptValue("Write me a song about sparkling water.")); err != nil {
		log.Fatal(err)
	}
}
97 changes: 96 additions & 1 deletion model/llm/bedrock.go
@@ -30,6 +30,7 @@ var providerStopSequenceKeyMap = map[string]string{
 	"amazon":  "stopSequences",
 	"ai21":    "stop_sequences",
 	"cohere":  "stop_sequences",
+	"mistral": "stop",
 }
 
 // BedrockInputOutputAdapter is a helper struct for preparing input and handling output for Bedrock model.
@@ -60,7 +61,7 @@ func (bioa *BedrockInputOutputAdapter) PrepareInput(prompt string, modelParams m
 		body = modelParams
 
 		if _, ok := body["max_tokens_to_sample"]; !ok {
-			body["max_tokens_to_sample"] = 256
+			body["max_tokens_to_sample"] = 1024
 		}
 
 		body["prompt"] = fmt.Sprintf("\n\nHuman:%s\n\nAssistant:", prompt)
@@ -70,6 +71,9 @@ func (bioa *BedrockInputOutputAdapter) PrepareInput(prompt string, modelParams m
 	case "meta":
 		body = modelParams
 		body["prompt"] = prompt
+	case "mistral":
+		body = modelParams
+		body["prompt"] = fmt.Sprintf("<s>[INST] %s [/INST]", prompt)
 	default:
 		return nil, fmt.Errorf("unsupported provider: %s", bioa.provider)
 	}
@@ -115,6 +119,14 @@ type metaOutput struct {
 	Generation string `json:"generation"`
 }
 
+// mistralOutput is a struct representing the output structure for the "mistral" provider.
+type mistralOutput struct {
+	Outputs []struct {
+		Text       string `json:"text"`
+		StopReason string `json:"stop_reason"`
+	} `json:"outputs"`
+}
+
 // PrepareOutput prepares the output for the Bedrock model based on the specified provider.
 func (bioa *BedrockInputOutputAdapter) PrepareOutput(response []byte) (string, error) {
 	switch bioa.provider {
@@ -153,6 +165,13 @@ func (bioa *BedrockInputOutputAdapter) PrepareOutput(response []byte) (string, e
 		}
 
 		return output.Generation, nil
+	case "mistral":
+		output := &mistralOutput{}
+		if err := json.Unmarshal(response, output); err != nil {
+			return "", err
+		}
+
+		return output.Outputs[0].Text, nil
 	}
 
 	return "", fmt.Errorf("unsupported provider: %s", bioa.provider)
@@ -185,6 +204,14 @@ type metaStreamOutput struct {
 	Generation string `json:"generation"`
 }
 
+// mistralStreamOutput is a struct representing the stream output structure for the "mistral" provider.
+type mistralStreamOutput struct {
+	Outputs []struct {
+		Text       string `json:"text"`
+		StopReason string `json:"stop_reason"`
+	} `json:"outputs"`
+}
+
 // PrepareStreamOutput prepares the output for the Bedrock model based on the specified provider.
 func (bioa *BedrockInputOutputAdapter) PrepareStreamOutput(response []byte) (string, error) {
 	switch bioa.provider {
@@ -217,6 +244,13 @@ func (bioa *BedrockInputOutputAdapter) PrepareStreamOutput(response []byte) (str
 		}
 
 		return output.Generation, nil
+	case "mistral":
+		output := &mistralStreamOutput{}
+		if err := json.Unmarshal(response, output); err != nil {
+			return "", err
+		}
+
+		return output.Outputs[0].Text, nil
 	}
 
 	return "", fmt.Errorf("unsupported provider: %s", bioa.provider)
@@ -549,6 +583,67 @@ func NewBedrockMeta(client BedrockRuntimeClient, optFns ...func(o *BedrockMetaOp
 	})
 }
 
+type BedrockMistralOptions struct {
+	*schema.CallbackOptions `map:"-"`
+	schema.Tokenizer        `map:"-"`
+
+	// Model id to use.
+	ModelID string `map:"model_id,omitempty"`
+
+	// Temperature controls the randomness of text generation. Higher values make it more random.
+	Temperature float32 `map:"temperature"`
+
+	// TopP is the total probability mass of tokens to consider at each step.
+	TopP float32 `map:"top_p,omitempty"`
+
+	// TopK determines how the model selects tokens for output.
+	TopK int `map:"top_k"`
+
+	// MaxTokens sets the maximum number of tokens in the generated text.
+	MaxTokens int `map:"max_tokens,omitempty"`
+
+	// Stream indicates whether to stream the results or not.
+	Stream bool `map:"stream,omitempty"`
+}
+
+func NewBedrockMistral(client BedrockRuntimeClient, optFns ...func(o *BedrockMistralOptions)) (*Bedrock, error) {
+	opts := BedrockMistralOptions{
+		CallbackOptions: &schema.CallbackOptions{
+			Verbose: golc.Verbose,
+		},
+		ModelID:     "mistral.mistral-7b-instruct-v0:2", // https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
+		Temperature: 0.5,
+		TopP:        0.9,
+		TopK:        200,
+		MaxTokens:   512,
+	}
+
+	for _, fn := range optFns {
+		fn(&opts)
+	}
+
+	if opts.Tokenizer == nil {
+		var tErr error
+
+		opts.Tokenizer, tErr = tokenizer.NewGPT2()
+		if tErr != nil {
+			return nil, tErr
+		}
+	}
+
+	return NewBedrock(client, opts.ModelID, func(o *BedrockOptions) {
+		o.CallbackOptions = opts.CallbackOptions
+		o.Tokenizer = opts.Tokenizer
+		o.ModelParams = map[string]any{
+			"temperature": opts.Temperature,
+			"top_p":       opts.TopP,
+			"top_k":       opts.TopK,
+			"max_tokens":  opts.MaxTokens,
+		}
+		o.Stream = opts.Stream
+	})
+}
 
 // BedrockOptions contains options for configuring the Bedrock LLM model.
 type BedrockOptions struct {
 	*schema.CallbackOptions `map:"-"`
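For orientation, the sketch below traces the round trip the new code implies: the request body that PrepareInput assembles for the "mistral" provider with the defaults from NewBedrockMistral, and a response of the shape that mistralOutput is declared to parse. The prompt string and the response payload are illustrative assumptions, not values taken from this commit.

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Request body as PrepareInput would assemble it for provider "mistral",
	// using the defaults set in NewBedrockMistral (illustrative prompt text).
	body := map[string]any{
		"prompt":      "<s>[INST] Tell me a joke [/INST]",
		"temperature": 0.5,
		"top_p":       0.9,
		"top_k":       200,
		"max_tokens":  512,
	}
	b, err := json.Marshal(body)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b))

	// Response of the shape mistralOutput parses; the payload here is a
	// hypothetical example, and PrepareOutput reads only outputs[0].text.
	raw := []byte(`{"outputs":[{"text":"Here is a joke...","stop_reason":"stop"}]}`)

	var out struct {
		Outputs []struct {
			Text       string `json:"text"`
			StopReason string `json:"stop_reason"`
		} `json:"outputs"`
	}
	if err := json.Unmarshal(raw, &out); err != nil {
		panic(err)
	}
	fmt.Println(out.Outputs[0].Text)
}

One design note: PrepareOutput and PrepareStreamOutput index Outputs[0] without a length check, so an empty outputs array from the service would panic rather than surface an error.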
2 changes: 1 addition & 1 deletion model/llm/bedrock_test.go
@@ -47,7 +47,7 @@ func TestBedrockInputOutputAdapter(t *testing.T) {
 			modelParams: map[string]interface{}{
 				"param1": "value1",
 			},
-			expectedBody: `{"param1":"value1","max_tokens_to_sample":256,"prompt":"\n\nHuman:Test prompt\n\nAssistant:"}`,
+			expectedBody: `{"param1":"value1","max_tokens_to_sample":1024,"prompt":"\n\nHuman:Test prompt\n\nAssistant:"}`,
 			expectedErr: "",
 		},
 		{
