Skip to content

Commit

Permalink
Validate model schema on build (#1232)
Browse files Browse the repository at this point in the history
  • Loading branch information
nickstenning authored Jul 31, 2023
1 parent 02c4a3e commit 27ff168
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 10 deletions.
33 changes: 25 additions & 8 deletions pkg/image/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os/exec"
"path"

"github.com/getkin/kin-openapi/openapi3"
"github.com/replicate/cog/pkg/config"
"github.com/replicate/cog/pkg/docker"
"github.com/replicate/cog/pkg/dockerfile"
Expand Down Expand Up @@ -77,18 +78,38 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache,
}
}

console.Info("Adding labels to image...")
console.Info("Validating model schema...")
schema, err := GenerateOpenAPISchema(imageName, cfg.Build.GPU)
if err != nil {
return fmt.Errorf("Failed to get type signature: %w", err)
}
configJSON, err := json.Marshal(cfg)
schemaJSON, err := json.Marshal(schema)
if err != nil {
return fmt.Errorf("Failed to convert config to JSON: %w", err)
return fmt.Errorf("Failed to convert type signature to JSON: %w", err)
}
if len(schema) > 0 {
loader := openapi3.NewLoader()
loader.IsExternalRefsAllowed = true
doc, err := loader.LoadFromData(schemaJSON)
if err != nil {
return fmt.Errorf("Failed to load model schema JSON: %w", err)
}
err = doc.Validate(loader.Context)
if err != nil {
return err
}
}

console.Info("Adding labels to image...")

// We used to set the cog_version and config labels in Dockerfile, because we didn't require running the
// built image to get those. But, the escaping of JSON inside a label inside a Dockerfile was gnarly, and
// doesn't seem to be a problem here, so do it here instead.
configJSON, err := json.Marshal(cfg)
if err != nil {
return fmt.Errorf("Failed to convert config to JSON: %w", err)
}

labels := map[string]string{
global.LabelNamespace + "version": global.Version,
global.LabelNamespace + "config": string(bytes.TrimSpace(configJSON)),
Expand All @@ -102,11 +123,7 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache,
}

// OpenAPI schema is not set if there is no predictor.
if len((*schema).(map[string]interface{})) != 0 {
schemaJSON, err := json.Marshal(schema)
if err != nil {
return fmt.Errorf("Failed to convert type signature to JSON: %w", err)
}
if len(schema) > 0 {
labels[global.LabelNamespace+"openapi_schema"] = string(schemaJSON)
labels["org.cogmodel.openapi_schema"] = string(schemaJSON)
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/image/openapi_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

// GenerateOpenAPISchema by running the image and executing Cog
// This will be run as part of the build process then added as a label to the image. It can be retrieved more efficiently with the label by using GetOpenAPISchema
func GenerateOpenAPISchema(imageName string, enableGPU bool) (*interface{}, error) {
func GenerateOpenAPISchema(imageName string, enableGPU bool) (map[string]any, error) {
var stdout bytes.Buffer
var stderr bytes.Buffer

Expand Down Expand Up @@ -44,7 +44,7 @@ func GenerateOpenAPISchema(imageName string, enableGPU bool) (*interface{}, erro
console.Info(stderr.String())
return nil, err
}
var schema *interface{}
var schema map[string]any
if err := json.Unmarshal(stdout.Bytes(), &schema); err != nil {
// Exit code was 0, but JSON was not returned.
// This is verbose, but print so anything that gets printed in Python bubbles up here.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
build:
python_version: "3.8"
predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from cog import BasePredictor, Input


class Predictor(BasePredictor):
def predict(
self, num: int = Input(description="Number of things", default=1, ge=2, le=10)
) -> int:
return num * 2
11 changes: 11 additions & 0 deletions test-integration/test_integration/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,17 @@ def test_build_with_model(docker_image):
}


def test_build_invalid_schema(docker_image):
project_dir = Path(__file__).parent / "fixtures/invalid-int-project"
build_process = subprocess.run(
["cog", "build", "-t", docker_image],
cwd=project_dir,
capture_output=True,
)
assert build_process.returncode > 0
assert "invalid default: number must be at least 2" in build_process.stderr.decode()


def test_build_gpu_model_on_cpu(tmpdir, docker_image):
if os.environ.get("CI") != "true":
pytest.skip("only runs on CI environment")
Expand Down

0 comments on commit 27ff168

Please sign in to comment.