diff --git a/openapi3/schema.go b/openapi3/schema.go index 45350eced..54d084c03 100644 --- a/openapi3/schema.go +++ b/openapi3/schema.go @@ -547,7 +547,11 @@ func (schema *Schema) WithAdditionalProperties(v *Schema) *Schema { return schema } -func (schema *Schema) IsEmpty() bool { +func (schema *Schema) IsEmpty(settings *schemaValidationSettings) bool { + if (schema.ReadOnly && settings.writeEp) || (schema.WriteOnly && settings.readEp) { + return true + } + if schema.Type != "" || schema.Format != "" || len(schema.Enum) != 0 || schema.UniqueItems || schema.ExclusiveMin || schema.ExclusiveMax || schema.Nullable || schema.ReadOnly || schema.WriteOnly || schema.AllowEmptyValue || @@ -558,35 +562,35 @@ func (schema *Schema) IsEmpty() bool { schema.MinProps != 0 || schema.MaxProps != nil { return false } - if n := schema.Not; n != nil && !n.Value.IsEmpty() { + if n := schema.Not; n != nil && !n.Value.IsEmpty(settings) { return false } - if ap := schema.AdditionalProperties; ap != nil && !ap.Value.IsEmpty() { + if ap := schema.AdditionalProperties; ap != nil && !ap.Value.IsEmpty(settings) { return false } if apa := schema.AdditionalPropertiesAllowed; apa != nil && !*apa { return false } - if items := schema.Items; items != nil && !items.Value.IsEmpty() { + if items := schema.Items; items != nil && !items.Value.IsEmpty(settings) { return false } for _, s := range schema.Properties { - if !s.Value.IsEmpty() { + if !s.Value.IsEmpty(settings) { return false } } for _, s := range schema.OneOf { - if !s.Value.IsEmpty() { + if !s.Value.IsEmpty(settings) { return false } } for _, s := range schema.AnyOf { - if !s.Value.IsEmpty() { + if !s.Value.IsEmpty(settings) { return false } } for _, s := range schema.AllOf { - if !s.Value.IsEmpty() { + if !s.Value.IsEmpty(settings) { return false } } @@ -797,7 +801,7 @@ func (schema *Schema) visitJSON(settings *schemaValidationSettings, value interf } } - if schema.IsEmpty() { + if schema.IsEmpty(settings) { return } if err = schema.visitSetOperations(settings, value); err != nil { diff --git a/openapi3/schema_validation_settings.go b/openapi3/schema_validation_settings.go index 71db5f237..83446d7a3 100644 --- a/openapi3/schema_validation_settings.go +++ b/openapi3/schema_validation_settings.go @@ -4,9 +4,10 @@ package openapi3 type SchemaValidationOption func(*schemaValidationSettings) type schemaValidationSettings struct { - failfast bool - multiError bool - asreq, asrep bool // exclusive (XOR) fields + failfast bool + multiError bool + asreq, asrep bool // exclusive (XOR) fields + writeEp, readEp bool } // FailFast returns schema validation errors quicker. @@ -21,10 +22,19 @@ func MultiErrors() SchemaValidationOption { func VisitAsRequest() SchemaValidationOption { return func(s *schemaValidationSettings) { s.asreq, s.asrep = true, false } } + func VisitAsResponse() SchemaValidationOption { return func(s *schemaValidationSettings) { s.asreq, s.asrep = false, true } } +func ReadEndpoint() SchemaValidationOption { + return func(s *schemaValidationSettings) { s.writeEp, s.readEp = false, true } +} + +func WriteEndpoint() SchemaValidationOption { + return func(s *schemaValidationSettings) { s.readEp, s.writeEp = false, true } +} + func newSchemaValidationSettings(opts ...SchemaValidationOption) *schemaValidationSettings { settings := &schemaValidationSettings{} for _, opt := range opts { diff --git a/openapi3filter/options.go b/openapi3filter/options.go index 1622339e2..e4cbd329d 100644 --- a/openapi3filter/options.go +++ b/openapi3filter/options.go @@ -21,4 +21,7 @@ type Options struct { // See NoopAuthenticationFunc AuthenticationFunc AuthenticationFunc + + // See EndpointType + EndpointType EndpointType } diff --git a/openapi3filter/validate_readonly_test.go b/openapi3filter/validate_readonly_test.go index 454a927e9..a90c8e900 100644 --- a/openapi3filter/validate_readonly_test.go +++ b/openapi3filter/validate_readonly_test.go @@ -11,7 +11,184 @@ import ( "github.com/stretchr/testify/require" ) -func TestValidatingRequestBodyWithReadOnlyProperty(t *testing.T) { +func TestValidatingReadRequestBodyWithReadOnlyProperty(t *testing.T) { + const spec = `{ + "openapi": "3.0.3", + "info": { + "version": "1.0.0", + "title": "title", + "description": "desc", + "contact": { + "email": "email" + } + }, + "paths": { + "/accounts": { + "get": { + "description": "Get an account", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": ["_id"], + "properties": { + "_id": { + "type": "string", + "description": "Unique identifier for this object.", + "pattern": "[0-9a-v]+$", + "minLength": 20, + "maxLength": 20, + "readOnly": true + } + } + } + } + } + }, + "responses": { + "201": { + "description": "Successfully got an account" + }, + "400": { + "description": "The server could not understand the request due to invalid syntax", + } + } + } + } + } +} +` + + type Request struct { + ID string `json:"_id"` + } + + sl := openapi3.NewLoader() + doc, err := sl.LoadFromData([]byte(spec)) + require.NoError(t, err) + err = doc.Validate(sl.Context) + require.NoError(t, err) + router, err := legacyrouter.NewRouter(doc) + require.NoError(t, err) + + b, err := json.Marshal(Request{ID: "bt6kdc3d0cvp6u8u3ft0"}) + require.NoError(t, err) + + httpReq, err := http.NewRequest(http.MethodGet, "/accounts", bytes.NewReader(b)) + require.NoError(t, err) + httpReq.Header.Add(headerCT, "application/json") + + route, pathParams, err := router.FindRoute(httpReq) + require.NoError(t, err) + + err = ValidateRequest(sl.Context, &RequestValidationInput{ + Request: httpReq, + PathParams: pathParams, + Route: route, + Options: &Options{EndpointType: ReadEndpoint}, + }) + require.NoError(t, err) + + // Try again with an insufficient length ID + b, err = json.Marshal(Request{ID: "0cvp6u8u3ft0"}) + require.NoError(t, err) + + httpReq, err = http.NewRequest(http.MethodGet, "/accounts", bytes.NewReader(b)) + require.NoError(t, err) + httpReq.Header.Add(headerCT, "application/json") + + err = ValidateRequest(sl.Context, &RequestValidationInput{ + Request: httpReq, + PathParams: pathParams, + Route: route, + Options: &Options{EndpointType: ReadEndpoint}, + }) + require.Error(t, err) +} + +func TestValidatingIfReadOrWriteEndpointIsNotKnownDefaultValidatesProperty(t *testing.T) { + const spec = `{ + "openapi": "3.0.3", + "info": { + "version": "1.0.0", + "title": "title", + "description": "desc", + "contact": { + "email": "email" + } + }, + "paths": { + "/accounts": { + "post": { + "description": "Get an account", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": ["_id"], + "properties": { + "_id": { + "type": "string", + "description": "Unique identifier for this object.", + "pattern": "[0-9a-v]+$", + "minLength": 20, + "maxLength": 20, + "readOnly": true + } + } + } + } + } + }, + "responses": { + "201": { + "description": "Successfully got an account" + }, + "400": { + "description": "The server could not understand the request due to invalid syntax", + } + } + } + } + } +} +` + + type Request struct { + ID string `json:"_id"` + } + + sl := openapi3.NewLoader() + doc, err := sl.LoadFromData([]byte(spec)) + require.NoError(t, err) + err = doc.Validate(sl.Context) + require.NoError(t, err) + router, err := legacyrouter.NewRouter(doc) + require.NoError(t, err) + + b, err := json.Marshal(Request{ID: ""}) + require.NoError(t, err) + + httpReq, err := http.NewRequest(http.MethodPost, "/accounts", bytes.NewReader(b)) + require.NoError(t, err) + httpReq.Header.Add(headerCT, "application/json") + + route, pathParams, err := router.FindRoute(httpReq) + require.NoError(t, err) + + err = ValidateRequest(sl.Context, &RequestValidationInput{ + Request: httpReq, + PathParams: pathParams, + Route: route, + }) + require.Error(t, err) +} + +func TestValidatingWriteRequestOnRequiredReadOnlyProperty(t *testing.T) { const spec = `{ "openapi": "3.0.3", "info": { @@ -49,7 +226,7 @@ func TestValidatingRequestBodyWithReadOnlyProperty(t *testing.T) { }, "responses": { "201": { - "description": "Successfully created a new account" + "description": "Successfully created an account" }, "400": { "description": "The server could not understand the request due to invalid syntax", @@ -73,7 +250,8 @@ func TestValidatingRequestBodyWithReadOnlyProperty(t *testing.T) { router, err := legacyrouter.NewRouter(doc) require.NoError(t, err) - b, err := json.Marshal(Request{ID: "bt6kdc3d0cvp6u8u3ft0"}) + // Set no id because id is a required readonly field, but this is a write request + b, err := json.Marshal(Request{ID: ""}) require.NoError(t, err) httpReq, err := http.NewRequest(http.MethodPost, "/accounts", bytes.NewReader(b)) @@ -87,6 +265,7 @@ func TestValidatingRequestBodyWithReadOnlyProperty(t *testing.T) { Request: httpReq, PathParams: pathParams, Route: route, + Options: &Options{EndpointType: WriteEndpoint}, }) require.NoError(t, err) } diff --git a/openapi3filter/validate_request.go b/openapi3filter/validate_request.go index fae6b09f9..be47ccf41 100644 --- a/openapi3filter/validate_request.go +++ b/openapi3filter/validate_request.go @@ -245,6 +245,12 @@ func ValidateRequestBody(ctx context.Context, input *RequestValidationInput, req if options.MultiError { opts = append(opts, openapi3.MultiErrors()) } + if options.EndpointType == ReadEndpoint { + opts = append(opts, openapi3.ReadEndpoint()) + } + if options.EndpointType == WriteEndpoint { + opts = append(opts, openapi3.WriteEndpoint()) + } // Validate JSON with the schema if err := contentType.Schema.Value.VisitJSON(value, opts...); err != nil { diff --git a/openapi3filter/validate_writeonly_test.go b/openapi3filter/validate_writeonly_test.go new file mode 100644 index 000000000..55e947d55 --- /dev/null +++ b/openapi3filter/validate_writeonly_test.go @@ -0,0 +1,191 @@ +package openapi3filter + +import ( + "bytes" + "encoding/json" + "net/http" + "testing" + + "github.com/getkin/kin-openapi/openapi3" + legacyrouter "github.com/getkin/kin-openapi/routers/legacy" + "github.com/stretchr/testify/require" +) + +func TestValidatingWriteRequestBodyWithWriteOnlyProperty(t *testing.T) { + const spec = `{ + "openapi": "3.0.3", + "info": { + "version": "1.0.0", + "title": "title", + "description": "desc", + "contact": { + "email": "email" + } + }, + "paths": { + "/accounts": { + "post": { + "description": "Create a new account", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": ["_id"], + "properties": { + "_id": { + "type": "string", + "description": "Unique identifier for this object.", + "pattern": "[0-9a-v]+$", + "minLength": 20, + "maxLength": 20, + "writeOnly": true + } + } + } + } + } + }, + "responses": { + "201": { + "description": "Successfully got an account" + }, + "400": { + "description": "The server could not understand the request due to invalid syntax", + } + } + } + } + } +} +` + + type Request struct { + ID string `json:"_id"` + } + + sl := openapi3.NewLoader() + doc, err := sl.LoadFromData([]byte(spec)) + require.NoError(t, err) + err = doc.Validate(sl.Context) + require.NoError(t, err) + router, err := legacyrouter.NewRouter(doc) + require.NoError(t, err) + + b, err := json.Marshal(Request{ID: "bt6kdc3d0cvp6u8u3ft0"}) + require.NoError(t, err) + + httpReq, err := http.NewRequest(http.MethodPost, "/accounts", bytes.NewReader(b)) + require.NoError(t, err) + httpReq.Header.Add(headerCT, "application/json") + + route, pathParams, err := router.FindRoute(httpReq) + require.NoError(t, err) + + err = ValidateRequest(sl.Context, &RequestValidationInput{ + Request: httpReq, + PathParams: pathParams, + Route: route, + Options: &Options{EndpointType: WriteEndpoint}, + }) + require.NoError(t, err) + + // Try again with an insufficient length ID + b, err = json.Marshal(Request{ID: "0cvp6u8u3ft0"}) + require.NoError(t, err) + + httpReq, err = http.NewRequest(http.MethodPost, "/accounts", bytes.NewReader(b)) + require.NoError(t, err) + httpReq.Header.Add(headerCT, "application/json") + + err = ValidateRequest(sl.Context, &RequestValidationInput{ + Request: httpReq, + PathParams: pathParams, + Route: route, + Options: &Options{EndpointType: WriteEndpoint}, + }) + require.Error(t, err) +} + +func TestValidatingReadRequestOnRequiredWriteOnlyProperty(t *testing.T) { + const spec = `{ + "openapi": "3.0.3", + "info": { + "version": "1.0.0", + "title": "title", + "description": "desc", + "contact": { + "email": "email" + } + }, + "paths": { + "/accounts": { + "get": { + "description": "Get an account", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "required": ["_id"], + "properties": { + "_id": { + "type": "string", + "description": "Unique identifier for this object.", + "pattern": "[0-9a-v]+$", + "minLength": 20, + "maxLength": 20, + "writeOnly": true + } + } + } + } + } + }, + "responses": { + "201": { + "description": "Successfully got an account" + }, + "400": { + "description": "The server could not understand the request due to invalid syntax", + } + } + } + } + } +} +` + + type Request struct { + ID string `json:"_id"` + } + + sl := openapi3.NewLoader() + doc, err := sl.LoadFromData([]byte(spec)) + require.NoError(t, err) + err = doc.Validate(sl.Context) + require.NoError(t, err) + router, err := legacyrouter.NewRouter(doc) + require.NoError(t, err) + + // Set no id because id is a required readonly field, but this is a write request + b, err := json.Marshal(Request{ID: ""}) + require.NoError(t, err) + + httpReq, err := http.NewRequest(http.MethodGet, "/accounts", bytes.NewReader(b)) + require.NoError(t, err) + httpReq.Header.Add(headerCT, "application/json") + + route, pathParams, err := router.FindRoute(httpReq) + require.NoError(t, err) + + err = ValidateRequest(sl.Context, &RequestValidationInput{ + Request: httpReq, + PathParams: pathParams, + Route: route, + Options: &Options{EndpointType: ReadEndpoint}, + }) + require.NoError(t, err) +} diff --git a/openapi3filter/validation_handler.go b/openapi3filter/validation_handler.go index eeb1ca1ea..9f40c03fc 100644 --- a/openapi3filter/validation_handler.go +++ b/openapi3filter/validation_handler.go @@ -15,6 +15,15 @@ func NoopAuthenticationFunc(context.Context, *AuthenticationInput) error { retur var _ AuthenticationFunc = NoopAuthenticationFunc +// EndpointType Represents what type of endpoint we'll be running validation on +type EndpointType int32 + +const ( + None EndpointType = iota + WriteEndpoint + ReadEndpoint +) + type ValidationHandler struct { Handler http.Handler AuthenticationFunc AuthenticationFunc