[WIP] Feat: create context caching for vertexai #1566

Open · wants to merge 9 commits into main

1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@
 node_modules
 lib
 dist
+.idea
 .vscode
 .DS_Store
 package-lock.json
15 changes: 9 additions & 6 deletions go/ai/gen.go
@@ -16,6 +16,8 @@

 package ai
 
+import "time"
+
 type dataPart struct {
 	Data     any            `json:"data,omitempty"`
 	Metadata map[string]any `json:"metadata,omitempty"`
@@ -36,12 +38,13 @@ const (

 // GenerationCommonConfig holds configuration for generation.
 type GenerationCommonConfig struct {
-	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"`
-	StopSequences   []string `json:"stopSequences,omitempty"`
-	Temperature     float64  `json:"temperature,omitempty"`
-	TopK            int      `json:"topK,omitempty"`
-	TopP            float64  `json:"topP,omitempty"`
-	Version         string   `json:"version,omitempty"`
+	MaxOutputTokens int           `json:"maxOutputTokens,omitempty"`
+	StopSequences   []string      `json:"stopSequences,omitempty"`
+	Temperature     float64       `json:"temperature,omitempty"`
+	TopK            int           `json:"topK,omitempty"`
+	TopP            float64       `json:"topP,omitempty"`
+	TTL             time.Duration `json:"ttl,omitempty"`
+	Version         string        `json:"version,omitempty"`
 }
 
 // GenerationUsage provides information about the generation process.
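For reference, a minimal sketch of how a caller might populate the new TTL field once this lands. It assumes the vertexai plugin will read GenerationCommonConfig.TTL as the lifetime of the cached context (that plugin-side wiring is not part of this diff); the package name, run helper, and prompt text are illustrative only.

package example

import (
	"context"
	"log"
	"time"

	"github.com/firebase/genkit/go/ai"
)

// run issues a generate call whose context the provider may cache for
// five minutes, via the TTL field added to GenerationCommonConfig above.
func run(ctx context.Context, m ai.Model) {
	cfg := &ai.GenerationCommonConfig{
		Temperature: 0.2,
		TTL:         5 * time.Minute,
	}
	resp, err := ai.Generate(ctx, m,
		ai.WithConfig(cfg),
		ai.WithTextPrompt("Summarize the attached contract."),
	)
	if err != nil {
		log.Fatal(err)
	}
	_ = resp // handle the model response as usual
}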
46 changes: 28 additions & 18 deletions go/ai/generate.go
@@ -22,6 +22,7 @@ import (
"slices"
"strconv"
"strings"
"time"

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/core/logger"
@@ -99,29 +100,33 @@ func LookupModel(provider, name string) Model {
 	return (*modelActionDef)(action)
 }
 
-// generateParams represents various params of the Generate call.
-type generateParams struct {
+// GenerateParams represents various params of the Generate call.
+type GenerateParams struct {
 	Request      *ModelRequest
 	Stream       ModelStreamingCallback
 	History      []*Message
 	SystemPrompt *Message
+	TTL          *time.Duration
 }
 
 // GenerateOption configures params of the Generate call.
-type GenerateOption func(req *generateParams) error
+type GenerateOption func(req *GenerateParams) error
 
 // WithTextPrompt adds a simple text user prompt to ModelRequest.
-func WithTextPrompt(prompt string) GenerateOption {
-	return func(req *generateParams) error {
+func WithTextPrompt(prompt string, ttl ...int) GenerateOption {
+	return func(req *GenerateParams) error {
 		req.Request.Messages = append(req.Request.Messages, NewUserTextMessage(prompt))
+		if len(ttl) > 0 {
+			*req.TTL = time.Duration(ttl[0]) * time.Second
+		}
 		return nil
 	}
 }
 
 // WithSystemPrompt adds a simple text system prompt as the first message in ModelRequest.
 // System prompt will always be put first in the list of messages.
 func WithSystemPrompt(prompt string) GenerateOption {
-	return func(req *generateParams) error {
+	return func(req *GenerateParams) error {
 		if req.SystemPrompt != nil {
 			return errors.New("cannot set system prompt (WithSystemPrompt) more than once")
 		}
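Because GenerateParams and GenerateOption are now exported, custom options can be defined outside the ai package. Below is a minimal sketch of a hypothetical WithCacheTTL helper (the name is illustrative, not part of this diff) that fills the new TTL field; it allocates the pointer itself, since nothing in GetParams initializes it.

package example

import (
	"time"

	"github.com/firebase/genkit/go/ai"
)

// WithCacheTTL sets the cache lifetime on the assembled request.
// It allocates req.TTL before writing through the pointer.
func WithCacheTTL(d time.Duration) ai.GenerateOption {
	return func(req *ai.GenerateParams) error {
		if req.TTL == nil {
			req.TTL = new(time.Duration)
		}
		*req.TTL = d
		return nil
	}
}

A call such as ai.Generate(ctx, m, WithCacheTTL(10*time.Minute), ai.WithTextPrompt("...")) would then carry the TTL alongside the usual request fields.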
@@ -132,7 +137,7 @@ func WithSystemPrompt(prompt string) GenerateOption {

 // WithMessages adds provided messages to ModelRequest.
 func WithMessages(messages ...*Message) GenerateOption {
-	return func(req *generateParams) error {
+	return func(req *GenerateParams) error {
 		req.Request.Messages = append(req.Request.Messages, messages...)
 		return nil
 	}
@@ -143,7 +148,7 @@ func WithMessages(messages ...*Message) GenerateOption {
 // exception of system prompt which will always be first.
 // [WithMessages] and [WithTextPrompt] will insert messages after system prompt and history.
 func WithHistory(history ...*Message) GenerateOption {
-	return func(req *generateParams) error {
+	return func(req *GenerateParams) error {
 		if req.History != nil {
 			return errors.New("cannot set history (WithHistory) more than once")
 		}
@@ -154,7 +159,7 @@ func WithHistory(history ...*Message) GenerateOption {

 // WithConfig adds provided config to ModelRequest.
 func WithConfig(config any) GenerateOption {
-	return func(req *generateParams) error {
+	return func(req *GenerateParams) error {
 		if req.Request.Config != nil {
 			return errors.New("cannot set Request.Config (WithConfig) more than once")
 		}
@@ -165,15 +170,15 @@ func WithConfig(config any) GenerateOption {

 // WithContext adds provided context to ModelRequest.
 func WithContext(c ...any) GenerateOption {
-	return func(req *generateParams) error {
+	return func(req *GenerateParams) error {
 		req.Request.Context = append(req.Request.Context, c...)
 		return nil
 	}
 }
 
 // WithTools adds provided tools to ModelRequest.
 func WithTools(tools ...Tool) GenerateOption {
-	return func(req *generateParams) error {
+	return func(req *GenerateParams) error {
 		var toolDefs []*ToolDefinition
 		for _, t := range tools {
 			toolDefs = append(toolDefs, t.Definition())
@@ -185,7 +190,7 @@ func WithTools(tools ...Tool) GenerateOption {

 // WithOutputSchema adds provided output schema to ModelRequest.
 func WithOutputSchema(schema any) GenerateOption {
-	return func(req *generateParams) error {
+	return func(req *GenerateParams) error {
 		if req.Request.Output != nil && req.Request.Output.Schema != nil {
 			return errors.New("cannot set Request.Output.Schema (WithOutputSchema) more than once")
 		}
@@ -200,7 +205,7 @@ func WithOutputSchema(schema any) GenerateOption {

 // WithOutputFormat adds provided output format to ModelRequest.
 func WithOutputFormat(format OutputFormat) GenerateOption {
-	return func(req *generateParams) error {
+	return func(req *GenerateParams) error {
 		if req.Request.Output == nil {
 			req.Request.Output = &ModelRequestOutput{}
 		}
@@ -211,7 +216,7 @@ func WithOutputFormat(format OutputFormat) GenerateOption {

 // WithStreaming adds a streaming callback to the generate request.
 func WithStreaming(cb ModelStreamingCallback) GenerateOption {
-	return func(req *generateParams) error {
+	return func(req *GenerateParams) error {
 		if req.Stream != nil {
 			return errors.New("cannot set streaming callback (WithStreaming) more than once")
 		}
@@ -220,17 +225,22 @@ func WithStreaming(cb ModelStreamingCallback) GenerateOption {
 	}
 }
 
-// Generate run generate request for this model. Returns ModelResponse struct.
-func Generate(ctx context.Context, m Model, opts ...GenerateOption) (*ModelResponse, error) {
-	req := &generateParams{
+func GetParams(opts ...GenerateOption) GenerateParams {
+	req := &GenerateParams{
 		Request: &ModelRequest{},
 	}
 	for _, with := range opts {
 		err := with(req)
 		if err != nil {
-			return nil, err
+			return GenerateParams{}
 		}
 	}
+	return *req
+}
+
+// Generate run generate request for this model. Returns ModelResponse struct.
+func Generate(ctx context.Context, m Model, opts ...GenerateOption) (*ModelResponse, error) {
+	req := GetParams(opts...)
 	if req.History != nil {
 		prev := req.Request.Messages
 		req.Request.Messages = req.History
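With GetParams exported, the request that Generate would send can be assembled and inspected without invoking a model. A minimal sketch, using only the APIs shown in this diff (package name and prompt strings are placeholders):

package example

import (
	"fmt"
	"time"

	"github.com/firebase/genkit/go/ai"
)

func inspect() {
	// Assemble the request exactly as Generate does internally.
	params := ai.GetParams(
		ai.WithSystemPrompt("You are a contract-review assistant."),
		ai.WithTextPrompt("List the termination clauses."),
		ai.WithConfig(&ai.GenerationCommonConfig{TTL: 10 * time.Minute}),
	)

	// One user message so far; the system prompt is held separately in
	// params.SystemPrompt, and history is merged later, inside Generate.
	fmt.Println(len(params.Request.Messages))
	fmt.Printf("%+v\n", params.Request.Config)
}

Per the diff, GetParams returns a zero GenerateParams when any option returns an error.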