diff --git a/mux.go b/mux.go index f8dc38f..b8d3d1f 100644 --- a/mux.go +++ b/mux.go @@ -186,6 +186,7 @@ func registerFuegoController[T, B any, Contexted ctx[B]](s *Server, method, path if route.MainRouter == nil { route.MainRouter = s } + route.AcceptedContentTypes = route.MainRouter.acceptedContentTypes for _, o := range options { o(&route) diff --git a/option/option.go b/option/option.go index 2191fc9..0e9c6dc 100644 --- a/option/option.go +++ b/option/option.go @@ -237,6 +237,7 @@ func AddError(code int, description string, errorType ...any) func(*fuego.BaseRo // RequestContentType sets the accepted content types for the route. // By default, the accepted content types is */*. +// This will override any options set at the server level. func RequestContentType(consumes ...string) func(*fuego.BaseRoute) { return func(r *fuego.BaseRoute) { r.AcceptedContentTypes = consumes diff --git a/option/option_test.go b/option/option_test.go index 3ec9846..84ae8d6 100644 --- a/option/option_test.go +++ b/option/option_test.go @@ -338,6 +338,22 @@ func TestRequestContentType(t *testing.T) { _, ok := s.OpenApiSpec.Components.RequestBodies["ReqBody"] require.False(t, ok) }) + + t.Run("override server", func(t *testing.T) { + s := fuego.NewServer(fuego.WithRequestContentType("application/json", "application/xml")) + route := fuego.Post( + s, "/test", dummyController, + RequestContentType("my/content-type"), + ) + + content := route.Operation.RequestBody.Value.Content + require.Nil(t, content.Get("application/json")) + require.Nil(t, content.Get("application/xml")) + require.NotNil(t, content.Get("my/content-type")) + require.Equal(t, "#/components/schemas/ReqBody", content.Get("my/content-type").Schema.Ref) + _, ok := s.OpenApiSpec.Components.RequestBodies["ReqBody"] + require.False(t, ok) + }) } func TestAddError(t *testing.T) { diff --git a/options.go b/options.go index a629bc5..fda1a15 100644 --- a/options.go +++ b/options.go @@ -73,6 +73,8 @@ type Server struct { fs fs.FS template *template.Template // TODO: use preparsed templates + acceptedContentTypes []string + DisallowUnknownFields bool // If true, the server will return an error if the request body contains unknown fields. Useful for quick debugging in development. DisableOpenapi bool // If true, the routes within the server will not generate an OpenAPI spec. maxBodySize int64 @@ -307,6 +309,12 @@ func WithLogHandler(handler slog.Handler) func(*Server) { } } +// WithRequestContentType sets the accepted content types for the server. +// By default, the accepted content types is */*. +func WithRequestContentType(consumes ...string) func(*Server) { + return func(s *Server) { s.acceptedContentTypes = consumes } +} + // WithSerializer sets a custom serializer of type Sender that overrides the default one. // Please send a PR if you think the default serializer should be improved, instead of jumping to this option. func WithSerializer(serializer Sender) func(*Server) { diff --git a/options_test.go b/options_test.go index f8f63c3..645ed14 100644 --- a/options_test.go +++ b/options_test.go @@ -375,6 +375,45 @@ func TestServerTags(t *testing.T) { }) } +type ReqBody struct { + A string + B int +} + +type Resp struct { + Message string `json:"message"` +} + +func dummyController(_ *ContextWithBody[ReqBody]) (Resp, error) { + return Resp{Message: "hello world"}, nil +} + +func TestWithRequestContentType(t *testing.T) { + t.Run("base", func(t *testing.T) { + s := NewServer() + require.Nil(t, s.acceptedContentTypes) + }) + + t.Run("input", func(t *testing.T) { + arr := []string{"application/json", "application/xml"} + s := NewServer(WithRequestContentType("application/json", "application/xml")) + require.ElementsMatch(t, arr, s.acceptedContentTypes) + }) + + t.Run("ensure applied to route", func(t *testing.T) { + s := NewServer(WithRequestContentType("application/json", "application/xml")) + route := Post(s, "/test", dummyController) + + content := route.Operation.RequestBody.Value.Content + require.NotNil(t, content.Get("application/json")) + require.NotNil(t, content.Get("application/xml")) + require.Equal(t, "#/components/schemas/ReqBody", content.Get("application/json").Schema.Ref) + require.Equal(t, "#/components/schemas/ReqBody", content.Get("application/xml").Schema.Ref) + _, ok := s.OpenApiSpec.Components.RequestBodies["ReqBody"] + require.False(t, ok) + }) +} + func TestCustomSerialization(t *testing.T) { s := NewServer( WithSerializer(func(w http.ResponseWriter, r *http.Request, a any) error {