Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(go): allow multiple schemas using picoschema #1400

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 76 additions & 36 deletions go/plugins/dotprompt/picoschema.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,16 @@ func parsePico(val any) (*jsonschema.Schema, error) {
case string:
typ, desc, found := strings.Cut(val, ",")
switch typ {
case "string", "boolean", "null", "number", "integer", "any":
case "string", "boolean", "null", "number", "integer":
case "any":
typ = ""
default:
return nil, fmt.Errorf("picoschema: unsupported scalar type %q", typ)
}
if typ == "any" {
typ = ""
}
ret := &jsonschema.Schema{
Type: typ,

ret := &jsonschema.Schema{}
if typ != "" {
ret.Type = typ
}
if found {
ret.Description = strings.TrimSpace(desc)
Expand Down Expand Up @@ -100,40 +101,41 @@ func parsePico(val any) (*jsonschema.Schema, error) {
return nil, err
}

if !found {
ret.Properties.Set(propertyName, property)
continue
}

typ = strings.TrimSuffix(typ, ")")
typ, desc, found := strings.Cut(strings.TrimSuffix(typ, ")"), ",")
switch typ {
case "array":
property = &jsonschema.Schema{
Type: "array",
Items: property,
}
case "object":
// Use property unchanged.
case "enum":
if property.Enum == nil {
return nil, fmt.Errorf("picoschema: enum value %v is not an array", property)
}
if isOptional {
property.Enum = append(property.Enum, nil)
if found {
typ = strings.TrimSuffix(typ, ")")
typ, desc, found := strings.Cut(strings.TrimSuffix(typ, ")"), ",")
switch typ {
case "array":
property = &jsonschema.Schema{
Type: "array",
Items: property,
}
case "object":
// Use property unchanged.
case "enum":
if property.Enum == nil {
return nil, fmt.Errorf("picoschema: enum value %v is not an array", property)
}

if isOptional {
property.Enum = append(property.Enum, nil)
}

case "*":
ret.AdditionalProperties = property
continue
default:
return nil, fmt.Errorf("picoschema: parenthetical type %q is none of %q", typ,
[]string{"object", "array", "enum", "*"})
}

case "*":
ret.AdditionalProperties = property
continue
default:
return nil, fmt.Errorf("picoschema: parenthetical type %q is none of %q", typ,
[]string{"object", "array", "enum", "*"})

if found {
property.Description = strings.TrimSpace(desc)
}
}

if found {
property.Description = strings.TrimSpace(desc)
if isOptional {
property = makeNullable(property)
}

ret.Properties.Set(propertyName, property)
Expand Down Expand Up @@ -268,3 +270,41 @@ func mapToJSONSchema(m map[string]any) (*jsonschema.Schema, error) {

return &ret, nil
}

func makeNullable(schema *jsonschema.Schema) *jsonschema.Schema {
// Do not wrap enums in anyOf
if len(schema.Enum) > 0 {
return schema
}
// If the schema is empty (represents 'any'), do not wrap it
if schema.Type == "" &&
(schema.Properties == nil || schema.Properties.Len() == 0) &&
schema.Items == nil &&
len(schema.AnyOf) == 0 &&
len(schema.AllOf) == 0 &&
len(schema.OneOf) == 0 &&
len(schema.Enum) == 0 {
return schema
}
// Check if the schema already allows null
if schema.Type == "null" {
return schema
}
if len(schema.AnyOf) > 0 {
for _, s := range schema.AnyOf {
if s.Type == "null" {
return schema
}
}
}
// Wrap the original schema to allow null
return &jsonschema.Schema{
AnyOf: []*jsonschema.Schema{
schema,
{Type: "null"},
},
Description: schema.Description,
AdditionalProperties: schema.AdditionalProperties,
Properties: schema.Properties,
}
}
72 changes: 63 additions & 9 deletions go/plugins/dotprompt/picoschema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
package dotprompt

import (
"encoding/json"
"os"
"path/filepath"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
Expand All @@ -41,16 +43,8 @@ func TestPicoschema(t *testing.T) {
t.Fatal(err)
}

skip := map[string]bool{
"required field": true,
"nested object in array and out": true,
}

for _, test := range tests {
t.Run(test.Description, func(t *testing.T) {
if skip[test.Description] {
t.Skip("no support for type as an array")
}
var val any
if err := yaml.Unmarshal([]byte(test.YAML), &val); err != nil {
t.Fatalf("YAML unmarshal failure: %v", err)
Expand All @@ -67,8 +61,17 @@ func TestPicoschema(t *testing.T) {
if err != nil {
t.Fatal(err)
}
gotData, err := json.Marshal(got)
if err != nil {
t.Fatal(err)
}
var gotMap map[string]any
if err := json.Unmarshal(gotData, &gotMap); err != nil {
t.Fatal(err)
}
replaceAnyOfWithTypeArray(gotMap)
want := replaceEmptySchemas(test.Want)
if diff := cmp.Diff(want, got); diff != "" {
if diff := cmp.Diff(want, gotMap); diff != "" {
t.Errorf("mismatch (-want, +got):\n%s", diff)
}
})
Expand Down Expand Up @@ -97,3 +100,54 @@ func replaceEmptySchemas(m map[string]any) any {
}
return m
}

func replaceAnyOfWithTypeArray(schema map[string]any) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you help me understand what this is for? Is it only needed for the test or everywhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that it is a helper function, to match exactly the test cases with the string expected, without it we have error because the serializate func outputs something like string("red", "blue") and we want AnyOf("red", "blue"), but the important refactor I think it is about: picoschema.go https://github.com/firebase/genkit/pull/1400/files#diff-d643807a9a75f3e4805fd792bbd253ad99819a506fcfb5bdf9ffa20660d87a4eR109

Copy link
Collaborator

@apascal07 apascal07 Dec 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a sample with an actual dotprompt file that uses these features that weren't supported before? The tests are good but we're doing quite a bit of extra logic here so I want to make sure that it actually works in practice. There's a lot happening here so it's hard for me to tell if this is correct or not.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like dotpromot have now his own project,

// Check if 'anyOf' is present
if anyOf, ok := schema["anyOf"].([]any); ok && len(anyOf) > 0 {
types := []any{}
descriptions := []string{}
otherKeysExist := false

for _, item := range anyOf {
if subSchema, ok := item.(map[string]any); ok {
// Collect 'type' and 'description' from sub-schemas
if t, hasType := subSchema["type"]; hasType {
types = append(types, t)
} else {
otherKeysExist = true
break
}
if desc, hasDesc := subSchema["description"]; hasDesc {
descriptions = append(descriptions, desc.(string))
}
} else {
otherKeysExist = true
break
}
}

// Replace 'anyOf' with 'type' array if no other keys exist
if !otherKeysExist && len(types) > 0 {
schema["type"] = types
delete(schema, "anyOf")
// Combine descriptions if necessary
if len(descriptions) > 0 && schema["description"] == nil {
schema["description"] = strings.Join(descriptions, "; ")
}
}
}

// Recursively process nested schemas
for _, value := range schema {
switch v := value.(type) {
case map[string]any:
replaceAnyOfWithTypeArray(v)
case []any:
for _, item := range v {
if m, ok := item.(map[string]any); ok {
replaceAnyOfWithTypeArray(m)
}
}
}
}
}
Loading