Skip to content

Commit

Permalink
Feat: add endOfCachedContents
Browse files Browse the repository at this point in the history
  • Loading branch information
alonsopec89 committed Dec 20, 2024
1 parent 2b6f573 commit 5f560e7
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 1 deletion.
69 changes: 68 additions & 1 deletion go/plugins/vertexai/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,21 @@ func getContentForCache(
modelVersion string,
cacheConfig *CacheConfigDetails,
) (*genai.CachedContent, error) {
endOfCachedContents, extractedCacheConfig, err := extractCacheConfig(request)
if err != nil {
return nil, err
}

// If no cache metadata found, return nil
if extractedCacheConfig == nil {
return nil, nil
}
if endOfCachedContents < 0 || endOfCachedContents >= len(request.Messages) {
return nil, fmt.Errorf("invalid endOfCachedContents index")
}
if cacheConfig == nil {
cacheConfig = extractedCacheConfig
}
var systemInstruction string
var userParts []*genai.Content

Expand Down Expand Up @@ -177,6 +192,58 @@ func validateContextCacheRequest(request *ai.ModelRequest, modelVersion string)
return nil
}

func extractCacheConfig(request *ai.ModelRequest) (int, *CacheConfigDetails, error) {
endOfCachedContents := -1
var cacheConfig *CacheConfigDetails

for i := len(request.Messages) - 1; i >= 0; i-- {
m := request.Messages[i]
if m.Metadata != nil {
if c, ok := m.Metadata["cache"]; ok && c != nil {
// Found a message with `metadata.cache`
endOfCachedContents = i

// Parse the cache config. The TS code uses zod schema;
// here we assume `cache` can be either a boolean or a map with `ttlSeconds`.
switch val := c.(type) {
case bool:
// If it's just a boolean, true = default TTL, false = no cache
if val {
cacheConfig = &CacheConfigDetails{
TTLSeconds: 0, // use default if 0
}
} else {
// false means no caching
cacheConfig = &CacheConfigDetails{
TTLSeconds: 0,
}
}

case map[string]interface{}:
ttlSeconds := time.Duration(0)
if ttlVal, ok := val["ttlSeconds"].(float64); ok {
ttlSeconds = time.Duration(ttlVal)
}
cacheConfig = &CacheConfigDetails{
TTLSeconds: ttlSeconds,
}

default:
return -1, nil, fmt.Errorf("invalid cache config type: %T", val)
}
break
}
}
}

if endOfCachedContents == -1 {
// No cache metadata found
return -1, nil, nil
}

return endOfCachedContents, cacheConfig, nil
}

// handleCacheIfNeeded checks if caching should be used, attempts to find or create the cache,
// and returns the cached content if applicable.
func handleCacheIfNeeded(
Expand All @@ -191,7 +258,7 @@ func handleCacheIfNeeded(
return nil, nil
}
cachedContent, err := getContentForCache(request, modelVersion, cacheConfig)
if err != nil {
if err != nil || cachedContent == nil {
return nil, nil
}

Expand Down
3 changes: 3 additions & 0 deletions go/samples/cache/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ import (
"github.com/firebase/genkit/go/ai"
"github.com/firebase/genkit/go/plugins/vertexai"
"log"
"os"
"time"
)

func main() {
os.Setenv("GCLOUD_PROJECT", "404021582266")
os.Setenv("GOOGLE_GENAI_API_KEY", "AIzaSyBGeUMp-a8EPKkm_OEww0rjX4sMMti8eEo")
ctx := context.Background()
if err := vertexai.Init(ctx, nil); err != nil {
log.Fatal(err)
Expand Down

0 comments on commit 5f560e7

Please sign in to comment.