From 6e2bba7653bf7bcd95cf3af35f422536aeb87d2d Mon Sep 17 00:00:00 2001 From: technillogue Date: Mon, 7 Aug 2023 17:41:39 -0400 Subject: [PATCH] parse type signature without running container (#1155) * parse type signature from AST, without docker run Signed-off-by: technillogue * quote complex types for older python versions Signed-off-by: technillogue * add to tests and start fixing revealed problems Signed-off-by: technillogue * handle defaults more carefully and some other fixes Signed-off-by: technillogue * parse output type Signed-off-by: technillogue * add remaining test cases and make them pass Signed-off-by: technillogue * move static schema behind a flag Signed-off-by: technillogue * fix for py3.8 Signed-off-by: technillogue * another fix for py3.8 Signed-off-by: technillogue * more fixes for py3.7 AST Signed-off-by: technillogue * even more py3.7 fixes Signed-off-by: technillogue * Instead of the static-schema flag, just accept an openapi-schema flag Signed-off-by: technillogue * Write generated schema to a file Signed-off-by: technillogue * Remove debugging changes Signed-off-by: technillogue * fix gocritic Signed-off-by: technillogue * Formatting Signed-off-by: Mattt Zmuda * Fix linting warnings in types.py Signed-off-by: Mattt Zmuda * Log OpenAPI spec when validation fails Signed-off-by: Mattt Zmuda * Fix check for adding Cog labels Signed-off-by: Mattt Zmuda * Apply suggestions from code review Signed-off-by: Mattt --------- Signed-off-by: technillogue Signed-off-by: Mattt Zmuda Signed-off-by: Mattt Co-authored-by: Mattt Zmuda Co-authored-by: Mattt --- pkg/cli/build.go | 8 +- pkg/cli/push.go | 3 +- pkg/image/build.go | 41 +- python/cog/command/ast_openapi_schema.py | 535 +++++++++++++++++++++++ python/cog/schema.py | 9 +- python/cog/types.py | 9 +- python/tests/server/conftest.py | 9 + python/tests/server/test_http.py | 37 +- python/tests/test_types.py | 1 - 9 files changed, 619 insertions(+), 33 deletions(-) create mode 100644 python/cog/command/ast_openapi_schema.py diff --git a/pkg/cli/build.go b/pkg/cli/build.go index d9cd93be9a..bf57295c7d 100644 --- a/pkg/cli/build.go +++ b/pkg/cli/build.go @@ -15,6 +15,7 @@ var buildSeparateWeights bool var buildSecrets []string var buildNoCache bool var buildProgressOutput string +var buildSchemaFile string var buildUseCudaBaseImage string func newBuildCommand() *cobra.Command { @@ -28,6 +29,7 @@ func newBuildCommand() *cobra.Command { addSecretsFlag(cmd) addNoCacheFlag(cmd) addSeparateWeightsFlag(cmd) + addSchemaFlag(cmd) addUseCudaBaseImageFlag(cmd) cmd.Flags().StringVarP(&buildTag, "tag", "t", "", "A name for the built image in the form 'repository:tag'") return cmd @@ -47,7 +49,7 @@ func buildCommand(cmd *cobra.Command, args []string) error { imageName = config.DockerImageName(projectDir) } - if err := image.Build(cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput); err != nil { + if err := image.Build(cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile); err != nil { return err } @@ -76,6 +78,10 @@ func addSeparateWeightsFlag(cmd *cobra.Command) { cmd.Flags().BoolVar(&buildSeparateWeights, "separate-weights", false, "Separate model weights from code in image layers") } +func addSchemaFlag(cmd *cobra.Command) { + cmd.Flags().StringVar(&buildSchemaFile, "openapi-schema", "", "Load OpenAPI schema from a file") +} + func addUseCudaBaseImageFlag(cmd *cobra.Command) { cmd.Flags().StringVar(&buildUseCudaBaseImage, "use-cuda-base-image", "auto", "Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image). False results in a smaller image but may cause problems for non-torch projects") } diff --git a/pkg/cli/push.go b/pkg/cli/push.go index 13a018acac..b090ab7ce3 100644 --- a/pkg/cli/push.go +++ b/pkg/cli/push.go @@ -25,6 +25,7 @@ func newPushCommand() *cobra.Command { addSecretsFlag(cmd) addNoCacheFlag(cmd) addSeparateWeightsFlag(cmd) + addSchemaFlag(cmd) addUseCudaBaseImageFlag(cmd) addBuildProgressOutputFlag(cmd) @@ -46,7 +47,7 @@ func push(cmd *cobra.Command, args []string) error { return fmt.Errorf("To push images, you must either set the 'image' option in cog.yaml or pass an image name as an argument. For example, 'cog push registry.hooli.corp/hotdog-detector'") } - if err := image.Build(cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput); err != nil { + if err := image.Build(cfg, projectDir, imageName, buildSecrets, buildNoCache, buildSeparateWeights, buildUseCudaBaseImage, buildProgressOutput, buildSchemaFile); err != nil { return err } diff --git a/pkg/image/build.go b/pkg/image/build.go index a54d3f4e45..e5870d4751 100644 --- a/pkg/image/build.go +++ b/pkg/image/build.go @@ -9,6 +9,7 @@ import ( "path" "github.com/getkin/kin-openapi/openapi3" + "github.com/replicate/cog/pkg/config" "github.com/replicate/cog/pkg/docker" "github.com/replicate/cog/pkg/dockerfile" @@ -23,7 +24,7 @@ const weightsManifestPath = ".cog/cache/weights_manifest.json" // Build a Cog model from a config // // This is separated out from docker.Build(), so that can be as close as possible to the behavior of 'docker build'. -func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, separateWeights bool, useCudaBaseImage string, progressOutput string) error { +func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, separateWeights bool, useCudaBaseImage string, progressOutput string, schemaFile string) error { console.Infof("Building Docker image from environment in cog.yaml as %s...", imageName) generator, err := dockerfile.NewGenerator(cfg, dir) @@ -79,14 +80,34 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, } console.Info("Validating model schema...") - schema, err := GenerateOpenAPISchema(imageName, cfg.Build.GPU) - if err != nil { - return fmt.Errorf("Failed to get type signature: %w", err) - } - schemaJSON, err := json.Marshal(schema) - if err != nil { - return fmt.Errorf("Failed to convert type signature to JSON: %w", err) + + var schema map[string]interface{} + var schemaJSON []byte + + if schemaFile != "" { + // We were passed a schema file, so use that + schemaJSON, err = os.ReadFile(schemaFile) + if err != nil { + return fmt.Errorf("Failed to read schema file: %w", err) + } + + schema = make(map[string]interface{}) + err = json.Unmarshal(schemaJSON, &schema) + if err != nil { + return fmt.Errorf("Failed to parse schema file: %w", err) + } + } else { + schema, err = GenerateOpenAPISchema(imageName, cfg.Build.GPU) + if err != nil { + return fmt.Errorf("Failed to get type signature: %w", err) + } + + schemaJSON, err = json.Marshal(schema) + if err != nil { + return fmt.Errorf("Failed to convert type signature to JSON: %w", err) + } } + if len(schema) > 0 { loader := openapi3.NewLoader() loader.IsExternalRefsAllowed = true @@ -94,9 +115,10 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, if err != nil { return fmt.Errorf("Failed to load model schema JSON: %w", err) } + err = doc.Validate(loader.Context) if err != nil { - return err + return fmt.Errorf("Model schema is invalid: %w\n\n%s", err, string(schemaJSON)) } } @@ -122,7 +144,6 @@ func Build(cfg *config.Config, dir, imageName string, secrets []string, noCache, "org.cogmodel.config": string(bytes.TrimSpace(configJSON)), } - // OpenAPI schema is not set if there is no predictor. if len(schema) > 0 { labels[global.LabelNamespace+"openapi_schema"] = string(schemaJSON) labels["org.cogmodel.openapi_schema"] = string(schemaJSON) diff --git a/python/cog/command/ast_openapi_schema.py b/python/cog/command/ast_openapi_schema.py new file mode 100644 index 0000000000..bd21db5475 --- /dev/null +++ b/python/cog/command/ast_openapi_schema.py @@ -0,0 +1,535 @@ +import ast +import json +import sys +from pathlib import Path + +try: + assert ast.unparse +except (AssertionError, AttributeError): + # bad "compat" with python3.8 + ast.unparse = repr + +BASE_SCHEMA = """ +{ + "components": { + "schemas": { + "HTTPValidationError": { + "properties": { + "detail": { + "items": { "$ref": "#/components/schemas/ValidationError" }, + "title": "Detail", + "type": "array" + } + }, + "title": "HTTPValidationError", + "type": "object" + }, + "PredictionRequest": { + "properties": { + "created_at": { + "format": "date-time", + "title": "Created At", + "type": "string" + }, + "id": { "title": "Id", "type": "string" }, + "input": { "$ref": "#/components/schemas/Input" }, + "output_file_prefix": { + "title": "Output File Prefix", + "type": "string" + }, + "webhook": { + "format": "uri", + "maxLength": 65536, + "minLength": 1, + "title": "Webhook", + "type": "string" + }, + "webhook_events_filter": { + "default": ["start", "output", "logs", "completed"], + "items": { "$ref": "#/components/schemas/WebhookEvent" }, + "type": "array" + } + }, + "title": "PredictionRequest", + "type": "object" + }, + "PredictionResponse": { + "properties": { + "completed_at": { + "format": "date-time", + "title": "Completed At", + "type": "string" + }, + "created_at": { + "format": "date-time", + "title": "Created At", + "type": "string" + }, + "error": { "title": "Error", "type": "string" }, + "id": { "title": "Id", "type": "string" }, + "input": { "$ref": "#/components/schemas/Input" }, + "logs": { "default": "", "title": "Logs", "type": "string" }, + "metrics": { "title": "Metrics", "type": "object" }, + "output": { "$ref": "#/components/schemas/Output" }, + "started_at": { + "format": "date-time", + "title": "Started At", + "type": "string" + }, + "status": { "$ref": "#/components/schemas/Status" }, + "version": { "title": "Version", "type": "string" } + }, + "title": "PredictionResponse", + "type": "object" + }, + "Status": { + "description": "An enumeration.", + "enum": ["starting", "processing", "succeeded", "canceled", "failed"], + "title": "Status", + "type": "string" + }, + "ValidationError": { + "properties": { + "loc": { + "items": { "anyOf": [{ "type": "string" }, { "type": "integer" }] }, + "title": "Location", + "type": "array" + }, + "msg": { "title": "Message", "type": "string" }, + "type": { "title": "Error Type", "type": "string" } + }, + "required": ["loc", "msg", "type"], + "title": "ValidationError", + "type": "object" + }, + "WebhookEvent": { + "description": "An enumeration.", + "enum": ["start", "output", "logs", "completed"], + "title": "WebhookEvent", + "type": "string" + } + } + }, + "info": { "title": "Cog", "version": "0.1.0" }, + "openapi": "3.0.2", + "paths": { + "/": { + "get": { + "operationId": "root__get", + "responses": { + "200": { + "content": { + "application/json": { + "schema": { "title": "Response Root Get" } + } + }, + "description": "Successful Response" + } + }, + "summary": "Root" + } + }, + "/health-check": { + "get": { + "operationId": "healthcheck_health_check_get", + "responses": { + "200": { + "content": { + "application/json": { + "schema": { "title": "Response Healthcheck Health Check Get" } + } + }, + "description": "Successful Response" + } + }, + "summary": "Healthcheck" + } + }, + "/predictions": { + "post": { + "description": "Run a single prediction on the model", + "operationId": "predict_predictions_post", + "parameters": [ + { + "in": "header", + "name": "prefer", + "required": false, + "schema": { "title": "Prefer", "type": "string" } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/PredictionRequest" } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/PredictionResponse" } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + }, + "description": "Validation Error" + } + }, + "summary": "Predict" + } + }, + "/predictions/{prediction_id}": { + "put": { + "description": "Run a single prediction on the model (idempotent creation).", + "operationId": "predict_idempotent_predictions__prediction_id__put", + "parameters": [ + { + "in": "path", + "name": "prediction_id", + "required": true, + "schema": { "title": "Prediction ID", "type": "string" } + }, + { + "in": "header", + "name": "prefer", + "required": false, + "schema": { "title": "Prefer", "type": "string" } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "allOf": [{ "$ref": "#/components/schemas/PredictionRequest" }], + "title": "Prediction Request" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/PredictionResponse" } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + }, + "description": "Validation Error" + } + }, + "summary": "Predict Idempotent" + } + }, + "/predictions/{prediction_id}/cancel": { + "post": { + "description": "Cancel a running prediction", + "operationId": "cancel_predictions__prediction_id__cancel_post", + "parameters": [ + { + "in": "path", + "name": "prediction_id", + "required": true, + "schema": { "title": "Prediction ID", "type": "string" } + } + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "title": "Response Cancel Predictions Prediction Id Cancel Post" + } + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + }, + "description": "Validation Error" + } + }, + "summary": "Cancel" + } + }, + "/shutdown": { + "post": { + "operationId": "start_shutdown_shutdown_post", + "responses": { + "200": { + "content": { + "application/json": { + "schema": { "title": "Response Start Shutdown Shutdown Post" } + } + }, + "description": "Successful Response" + } + }, + "summary": "Start Shutdown" + } + } + } +} +""" + +OPENAPI_TYPES = { + "str": "string", # includes dates, files + "int": "integer", + "float": "number", + "bool": "boolean", + "list": "array", + "cog.Path": "string", + "cog.File": "string", + "Path": "string", + "File": "string", +} + + +def find(obj: ast.AST, name: str) -> ast.AST: + """Find a particular named node in a tree""" + return next(node for node in ast.walk(obj) if getattr(node, "name", "") == name) + + +def get_value(node: ast.AST) -> "int | float | complex | str | list": + """Return the value of constant or list of constants""" + if isinstance(node, ast.Constant): + return node.value + # for python3.7, were deprecated for Constant + if isinstance(node, (ast.Str, ast.Bytes)): + return node.s + if isinstance(node, ast.Num): + return node.n + if isinstance(node, (ast.List, ast.Tuple)): + return [get_value(e) for e in node.elts] + raise ValueError("Unexpected node type", type(node)) + + +def get_annotation(node: "ast.AST | None") -> str: + """Return the annotation as a string""" + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Constant): + return node.value # e.g. arg: "Path" + # ignore Subscript (Optional[str]), BinOp (str | int), and stuff like that + # except we may need to care about list/List[str] + raise ValueError("Unexpected annotation type", type(node)) + + +def get_call_name(call: ast.Call) -> str: + """Try to get the name of a Call""" + if isinstance(call.func, ast.Name): + return call.func.id + if isinstance(call.func, ast.Attribute): + return call.func.attr + raise ValueError("Unexpected node type", type(call), ast.unparse(call)) + + +def parse_args(tree: ast.AST) -> "list[tuple[ast.arg, ast.expr | Ellipsis]]": + """Parse argument, default pairs from a file with a predict function""" + predict = find(tree, "predict") + assert isinstance(predict, ast.FunctionDef) + args = predict.args.args # [-len(defaults) :] + # use Ellipsis instead of None here to distinguish a default of None + defaults = [...] * (len(args) - len(predict.args.defaults)) + predict.args.defaults + return list(zip(args, defaults)) + + +def parse_assignment(assignment: ast.AST) -> "tuple[str | None, dict | None]": + """Parse an assignment into an OpenAPI object property""" + if isinstance(assignment, ast.AnnAssign): + assert isinstance(assignment.target, ast.Name) # shouldn't be an Attribute + default = {"default": get_value(assignment.value)} if assignment.value else {} + return assignment.target.id, { + "title": assignment.target.id.replace("_", " ").title(), + "type": OPENAPI_TYPES[get_annotation(assignment.annotation)], + **default, + } + if isinstance(assignment, ast.Assign): + if len(assignment.targets) == 1 and isinstance(assignment.targets[0], ast.Name): + value = get_value(assignment.value) + return assignment.targets[0].id, { + "title": assignment.targets[0].id.replace("_", " ").title(), + "type": OPENAPI_TYPES[type(value).__name__], + "default": value, + } + raise ValueError("Unexpected assignment", assignment) + return None, None + + +def parse_class(classdef: ast.AST) -> dict: + """Parse a class definition into an OpenAPI object""" + assert isinstance(classdef, ast.ClassDef) + properties = { + key: property for key, property in map(parse_assignment, classdef.body) if key + } + return { + "title": classdef.name, + "type": "object", + "properties": properties, + } + + +# The supported types are: +# str: a string +# int: an integer +# float: a floating point number +# bool: a boolean +# cog.File: a file-like object representing a file +# cog.Path: a path to a file on disk + +BASE_TYPES = ["str", "int", "float", "bool", "File", "Path"] + + +def resolve_name(node: ast.expr) -> str: + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Index): + # depricated, but needed for py3.8 + return resolve_name(node.value) + if isinstance(node, ast.Attribute): + return node.attr + if isinstance(node, ast.Subscript): + return resolve_name(node.value) + raise ValueError("Unexpected node type", type(node), ast.unparse(node)) + + +def parse_return_annotation(tree: ast.AST, fn: str = "predict") -> "tuple[dict, dict]": + predict = find(tree, fn) + if not isinstance(predict, ast.FunctionDef): + raise ValueError("Could not find predict function") + annotation = predict.returns + if not annotation: + raise TypeError( + """You must set an output type. If your model can return multiple output types, you can explicitly set `Any` as the output type. + +For example: + + from typing import Any + + def predict( + self, + image: Path = Input(description="Input image"), + ) -> Any: + ... +""" + ) + # attributes should be resolved to names, maybe blindly + # subscript values are iterator or + name = resolve_name(annotation) + if isinstance(annotation, ast.Subscript): + # forget about other subscripts like Optional, and assume otherlib.File will still be an uri + slice = resolve_name(annotation.slice) + format = {"format": "uri"} if slice in ("Path", "File") else {} + array_type = {"x-cog-array-type": "iterator"} if "Iterator" in name else {} + display_type = ( + {"x-cog-array-display": "concatenate"} if "Concatenate" in name else {} + ) + return {}, { + "title": "Output", + "type": "array", + "items": { + "type": OPENAPI_TYPES.get(slice, slice), + **format, + }, + **array_type, + **display_type, + } + if name in BASE_TYPES: + # otherwise figure this out... + format = {"format": "uri"} if name in ("Path", "File") else {} + return {}, {"title": "Output", "type": OPENAPI_TYPES.get(name, name), **format} + # it must be a custom object + schema = {name: parse_class(find(tree, name))} + return schema, { + "title": "Output", + "$ref": f"#/components/schemas/{name}", + } + + +KEPT_ATTRS = ("description", "default", "ge", "le", "max_length", "min_length", "regex") + + +def extract_info(code: str) -> dict: + """Parse the schemas from a file with a predict function""" + tree = ast.parse(code) + inputs = {"title": "Input", "type": "object", "properties": {}} + required: "list[str]" = [] + schemas: "dict[str, dict]" = {} + for arg, default in parse_args(tree): + if arg.arg == "self": + continue + if isinstance(default, ast.Call) and get_call_name(default) == "Input": + kws = {kw.arg: get_value(kw.value) for kw in default.keywords} + elif isinstance(default, (ast.Constant, ast.List, ast.Tuple, ast.Str, ast.Num)): + kws = {"default": get_value(default)} # could be None + elif default == ...: # no default + kws = {} + else: + raise ValueError("Unexpected default value", default) + input: dict = {"x-order": len(inputs["properties"])} + # need to handle other types? + arg_type = OPENAPI_TYPES.get(get_annotation(arg.annotation), "string") + if get_annotation(arg.annotation) in ("Path", "File"): + input["format"] = "uri" + for attr in KEPT_ATTRS: + if attr in kws: + input[attr] = kws[attr] + if "default" not in input: + required.append(arg.arg) + if "choices" in kws and isinstance(kws["choices"], list): + input["allOf"] = [{"$ref": f"#/components/schemas/{arg.arg}"}] + # could use type(kws["choices"][0]).__name__ + schemas[arg.arg] = { + "title": arg.arg, + "enum": kws["choices"], + "type": arg_type, + "description": "An enumeration.", + } + else: + input["title"] = arg.arg.replace("_", " ").title() + input["type"] = arg_type + inputs["properties"][arg.arg] = input # type: ignore + if required: + inputs["required"] = required + # List[Path], list[Path], str, Iterator[str], MyOutput, Output + return_schema, output = parse_return_annotation(tree, "predict") + schema = json.loads(BASE_SCHEMA) + components = { + "Input": inputs, + "Output": output, + **schemas, + **return_schema, + } + schema["components"]["schemas"].update(components) + return schema + + +def extract_file(fname: "str | Path") -> dict: + return extract_info(open(fname, encoding="utf-8").read()) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + p = Path(sys.argv[1]) + if p.exists(): + print(json.dumps(extract_file(p))) + else: + print(json.dumps(extract_info(sys.stdin.read()))) diff --git a/python/cog/schema.py b/python/cog/schema.py index 446fe42feb..22948bd6a8 100644 --- a/python/cog/schema.py +++ b/python/cog/schema.py @@ -24,8 +24,11 @@ class WebhookEvent(str, Enum): COMPLETED = "completed" @classmethod - def default_events(cls) -> t.Set["WebhookEvent"]: - return {cls.START, cls.OUTPUT, cls.LOGS, cls.COMPLETED} + def default_events(cls) -> t.List["WebhookEvent"]: + # if this is a set, it gets serialized to an array with an unstable ordering + # so even though it's logically a set, have it as a list for deterministic schemas + # note: this change removes "uniqueItems":true + return [cls.START, cls.OUTPUT, cls.LOGS, cls.COMPLETED] class PredictionBaseModel(pydantic.BaseModel, extra=pydantic.Extra.allow): @@ -41,7 +44,7 @@ class PredictionRequest(PredictionBaseModel): webhook: t.Optional[pydantic.AnyHttpUrl] webhook_events_filter: t.Optional[ - t.Set[WebhookEvent] + t.List[WebhookEvent] ] = WebhookEvent.default_events() @classmethod diff --git a/python/cog/types.py b/python/cog/types.py index 91d80b1b1c..aa61275b3b 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -1,4 +1,3 @@ -import base64 import io import mimetypes import os @@ -56,7 +55,7 @@ def validate(cls, value: Any) -> io.IOBase: parsed_url = urllib.parse.urlparse(value) if parsed_url.scheme == "data": - res = urllib.request.urlopen(value) + res = urllib.request.urlopen(value) # noqa: S310 return io.BytesIO(res.read()) elif parsed_url.scheme == "http" or parsed_url.scheme == "https": return URLFile(value) @@ -211,7 +210,7 @@ def get_filename(url: str) -> str: parsed_url = urllib.parse.urlparse(url) if parsed_url.scheme == "data": - resp = urllib.request.urlopen(url) + resp = urllib.request.urlopen(url) # noqa: S310 mime_type = resp.headers.get_content_type() extension = mimetypes.guess_extension(mime_type) if extension is None: @@ -261,11 +260,11 @@ def validate(cls, value: Any) -> Iterator: return value -def _len_bytes(s, encoding="utf-8"): +def _len_bytes(s, encoding="utf-8") -> int: return len(s.encode(encoding)) -def _truncate_filename_bytes(s, length, encoding="utf-8"): +def _truncate_filename_bytes(s, length, encoding="utf-8") -> str: """ Truncate a filename to at most `length` bytes, preserving file extension and avoiding text encoding corruption from truncation. diff --git a/python/tests/server/conftest.py b/python/tests/server/conftest.py index 5f4c5b71ae..e7c6b18d31 100644 --- a/python/tests/server/conftest.py +++ b/python/tests/server/conftest.py @@ -7,6 +7,7 @@ import pytest from attrs import define +from cog.command import ast_openapi_schema from cog.server.http import create_app from fastapi.testclient import TestClient @@ -74,4 +75,12 @@ def client(request): c = make_client(fixture_name=fixture_name, **options) stack.enter_context(c) wait_for_setup(c) + c.ref = fixture_name yield c + + +@pytest.fixture +def static_schema(client) -> dict: + ref = _fixture_path(client.ref) + module_path = ref.split(":", 1)[0] + return ast_openapi_schema.extract_file(module_path) diff --git a/python/tests/server/test_http.py b/python/tests/server/test_http.py index 518760ec26..9e55e83d84 100644 --- a/python/tests/server/test_http.py +++ b/python/tests/server/test_http.py @@ -25,11 +25,12 @@ def test_predict_works_with_functions(client, match): @uses_predictor("openapi_complex_input") -def test_openapi_specification(client): +def test_openapi_specification(client, static_schema): resp = client.get("/openapi.json") assert resp.status_code == 200 schema = resp.json() + assert schema == static_schema assert schema["openapi"] == "3.0.2" assert schema["info"] == {"title": "Cog", "version": "0.1.0"} assert schema["paths"]["/"] == { @@ -190,11 +191,14 @@ def test_openapi_specification(client): @uses_predictor("openapi_custom_output_type") -def test_openapi_specification_with_custom_user_defined_output_type(client): +def test_openapi_specification_with_custom_user_defined_output_type( + client, static_schema +): resp = client.get("/openapi.json") assert resp.status_code == 200 schema = resp.json() + assert schema == static_schema assert schema["components"]["schemas"]["Output"] == { "$ref": "#/components/schemas/MyOutput", "title": "Output", @@ -219,11 +223,12 @@ def test_openapi_specification_with_custom_user_defined_output_type(client): @uses_predictor("openapi_output_type") def test_openapi_specification_with_custom_user_defined_output_type_called_output( - client, + client, static_schema ): resp = client.get("/openapi.json") assert resp.status_code == 200 - + schema = resp.json() + assert schema == static_schema assert resp.json()["components"]["schemas"]["Output"] == { "properties": { "foo_number": {"default": "42", "title": "Foo Number", "type": "integer"}, @@ -239,11 +244,12 @@ def test_openapi_specification_with_custom_user_defined_output_type_called_outpu @uses_predictor("openapi_output_yield") -def test_openapi_specification_with_yield(client): +def test_openapi_specification_with_yield(client, static_schema): resp = client.get("/openapi.json") assert resp.status_code == 200 - - assert resp.json()["components"]["schemas"]["Output"] == { + schema = resp.json() + assert schema == static_schema + assert schema["components"]["schemas"]["Output"] == { "title": "Output", "type": "array", "items": { @@ -254,11 +260,15 @@ def test_openapi_specification_with_yield(client): @uses_predictor("yield_concatenate_iterator") -def test_openapi_specification_with_yield_with_concatenate_iterator(client): +def test_openapi_specification_with_yield_with_concatenate_iterator( + client, static_schema +): resp = client.get("/openapi.json") assert resp.status_code == 200 - assert resp.json()["components"]["schemas"]["Output"] == { + schema = resp.json() + assert schema == static_schema + assert schema["components"]["schemas"]["Output"] == { "title": "Output", "type": "array", "items": { @@ -270,11 +280,13 @@ def test_openapi_specification_with_yield_with_concatenate_iterator(client): @uses_predictor("openapi_output_list") -def test_openapi_specification_with_list(client): +def test_openapi_specification_with_list(client, static_schema): resp = client.get("/openapi.json") assert resp.status_code == 200 - assert resp.json()["components"]["schemas"]["Output"] == { + schema = resp.json() + assert schema == static_schema + assert schema["components"]["schemas"]["Output"] == { "title": "Output", "type": "array", "items": { @@ -284,11 +296,12 @@ def test_openapi_specification_with_list(client): @uses_predictor("openapi_input_int_choices") -def test_openapi_specification_with_int_choices(client): +def test_openapi_specification_with_int_choices(client, static_schema): resp = client.get("/openapi.json") assert resp.status_code == 200 schema = resp.json() + assert schema == static_schema schemas = schema["components"]["schemas"] assert schemas["Input"]["properties"]["pick_a_number_any_number"] == { diff --git a/python/tests/test_types.py b/python/tests/test_types.py index 3e29eb7419..72efdca091 100644 --- a/python/tests/test_types.py +++ b/python/tests/test_types.py @@ -3,7 +3,6 @@ import pytest import responses - from cog.types import URLFile, get_filename