Skip to content

Commit

Permalink
Validate model schema on build
Browse files Browse the repository at this point in the history
It's important that the OpenAPI schema generated by cog validates. This
adds a validation step into the `cog build` process.

This should at least help us reduce the number of invalid schemas pushed
to r8.im.
  • Loading branch information
nickstenning committed Jul 28, 2023
1 parent 02c4a3e commit bf767ea
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 10 deletions.
33 changes: 25 additions & 8 deletions pkg/image/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os/exec"
"path"

"github.com/getkin/kin-openapi/openapi3"
"github.com/replicate/cog/pkg/config"
"github.com/replicate/cog/pkg/docker"
"github.com/replicate/cog/pkg/dockerfile"
Expand Down Expand Up @@ -77,18 +78,38 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache,
}
}

console.Info("Adding labels to image...")
console.Info("Validating model schema...")
schema, err := GenerateOpenAPISchema(imageName, cfg.Build.GPU)
if err != nil {
return fmt.Errorf("Failed to get type signature: %w", err)
}
configJSON, err := json.Marshal(cfg)
schemaJSON, err := json.Marshal(schema)
if err != nil {
return fmt.Errorf("Failed to convert config to JSON: %w", err)
return fmt.Errorf("Failed to convert type signature to JSON: %w", err)
}
if len(schema) > 0 {
loader := openapi3.NewLoader()
loader.IsExternalRefsAllowed = true
doc, err := loader.LoadFromData(schemaJSON)
if err != nil {
return fmt.Errorf("Failed to load model schema JSON: %w", err)
}
err = doc.Validate(loader.Context)
if err != nil {
return err
}
}

console.Info("Adding labels to image...")

// We used to set the cog_version and config labels in the Dockerfile, because we didn't need to run the
// built image to get them. But escaping JSON inside a label inside a Dockerfile was gnarly, and it
// doesn't seem to be a problem here, so we set them here instead.
configJSON, err := json.Marshal(cfg)
if err != nil {
return fmt.Errorf("Failed to convert config to JSON: %w", err)
}

labels := map[string]string{
global.LabelNamespace + "version": global.Version,
global.LabelNamespace + "config": string(bytes.TrimSpace(configJSON)),
Expand All @@ -102,11 +123,7 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache,
}

// OpenAPI schema is not set if there is no predictor.
if len((*schema).(map[string]interface{})) != 0 {
schemaJSON, err := json.Marshal(schema)
if err != nil {
return fmt.Errorf("Failed to convert type signature to JSON: %w", err)
}
if len(schema) > 0 {
labels[global.LabelNamespace+"openapi_schema"] = string(schemaJSON)
labels["org.cogmodel.openapi_schema"] = string(schemaJSON)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/image/openapi_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

// GenerateOpenAPISchema generates the OpenAPI schema by running the image and executing Cog.
// It is run as part of the build process and its result is added as a label to the image. The schema can then be retrieved more efficiently from that label by using GetOpenAPISchema.
func GenerateOpenAPISchema(imageName string, enableGPU bool) (*interface{}, error) {
func GenerateOpenAPISchema(imageName string, enableGPU bool) (map[string]any, error) {
var stdout bytes.Buffer
var stderr bytes.Buffer

Expand Down Expand Up @@ -44,7 +44,7 @@ func GenerateOpenAPISchema(imageName string, enableGPU bool) (*interface{}, erro
console.Info(stderr.String())
return nil, err
}
var schema *interface{}
var schema map[string]any
if err := json.Unmarshal(stdout.Bytes(), &schema); err != nil {
// Exit code was 0, but JSON was not returned.
// This is verbose, but print so anything that gets printed in Python bubbles up here.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
build:
python_version: "3.8"
predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from cog import BasePredictor, Input


# Predictor fixture exposing a single bounded integer input.
# The `num` input declares default=1 together with ge=2 and le=10, so the
# declared default lies below the declared minimum.
# NOTE(review): presumably this is the deliberately invalid predictor of the
# "invalid-int-project" fixture exercised by test_build_invalid_schema — confirm.
class Predictor(BasePredictor):
def predict(
self, num: int = Input(description="Number of things", default=1, ge=2, le=10)
) -> int:
# Return twice the input value.
return num * 2
10 changes: 10 additions & 0 deletions test-integration/test_integration/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,16 @@ def test_build_with_model(docker_image):
}


def test_build_invalid_schema(docker_image):
project_dir = Path(__file__).parent / "fixtures/invalid-int-project"
subprocess.run(
["cog", "build", "-t", docker_image],
cwd=project_dir,
)
assert build_process.returncode > 0
assert "invalid default: number must be at least 2" in build_process.stderr.decode()


def test_build_gpu_model_on_cpu(tmpdir, docker_image):
if os.environ.get("CI") != "true":
pytest.skip("only runs on CI environment")
Expand Down

0 comments on commit bf767ea

Please sign in to comment.