From 4b249823a95836550924af0fa964262684473da6 Mon Sep 17 00:00:00 2001 From: Margus Kerma Date: Wed, 10 Jul 2024 16:11:01 +0300 Subject: [PATCH] Fixed EncodeJSONResponse body writing #1291 --- transport/http/server.go | 14 +++++++--- transport/http/server_test.go | 51 ++++++++++++++++++++--------------- 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/transport/http/server.go b/transport/http/server.go index ab87d4ad0..b98375064 100644 --- a/transport/http/server.go +++ b/transport/http/server.go @@ -159,7 +159,8 @@ func NopRequestDecoder(ctx context.Context, r *http.Request) (interface{}, error // JSON object to the ResponseWriter. Many JSON-over-HTTP services can use it as // a sensible default. If the response implements Headerer, the provided headers // will be applied to the response. If the response implements StatusCoder, the -// provided StatusCode will be used instead of 200. +// provided StatusCode will be used instead of 200. If the StatusCode is between 100-199, +// or equal to 204 or 304, the response body will not be written. func EncodeJSONResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { w.Header().Set("Content-Type", "application/json; charset=utf-8") if headerer, ok := response.(Headerer); ok { @@ -174,10 +175,17 @@ func EncodeJSONResponse(_ context.Context, w http.ResponseWriter, response inter code = sc.StatusCode() } w.WriteHeader(code) - if code == http.StatusNoContent { + + switch { + case code >= 100 && code <= 199: + return nil + case code == http.StatusNoContent: + return nil + case code == http.StatusNotModified: return nil + default: + return json.NewEncoder(w).Encode(response) } - return json.NewEncoder(w).Encode(response) } // DefaultErrorEncoder writes the error to the ResponseWriter, by default a diff --git a/transport/http/server_test.go b/transport/http/server_test.go index 5c0fadb29..ddee7b56b 100644 --- a/transport/http/server_test.go +++ b/transport/http/server_test.go @@ -3,6 +3,7 @@ package http_test import ( "context" "errors" + "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -332,30 +333,38 @@ func TestAddMultipleHeadersErrorEncoder(t *testing.T) { } } -type noContentResponse struct{} +type noBodyResponse struct{ Code int } -func (e noContentResponse) StatusCode() int { return http.StatusNoContent } +func (e noBodyResponse) StatusCode() int { return e.Code } -func TestEncodeNoContent(t *testing.T) { - handler := httptransport.NewServer( - func(context.Context, interface{}) (interface{}, error) { return noContentResponse{}, nil }, - func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, - httptransport.EncodeJSONResponse, - ) - - server := httptest.NewServer(handler) - defer server.Close() - - resp, err := http.Get(server.URL) - if err != nil { - t.Fatal(err) +func TestEncodeNoBody(t *testing.T) { + testCases := []int{ + http.StatusNoContent, + http.StatusNotModified, } - if want, have := http.StatusNoContent, resp.StatusCode; want != have { - t.Errorf("StatusCode: want %d, have %d", want, have) - } - buf, _ := ioutil.ReadAll(resp.Body) - if want, have := 0, len(buf); want != have { - t.Errorf("Body: want no content, have %d bytes", have) + for _, code := range testCases { + t.Run(fmt.Sprint(code), func(t *testing.T) { + handler := httptransport.NewServer( + func(context.Context, interface{}) (interface{}, error) { return noBodyResponse{code}, nil }, + func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil }, + httptransport.EncodeJSONResponse, + ) + + server := httptest.NewServer(handler) + defer server.Close() + + resp, err := http.Get(server.URL) + if err != nil { + t.Fatal(err) + } + if want, have := code, resp.StatusCode; want != have { + t.Errorf("StatusCode: want %d, have %d", want, have) + } + buf, _ := ioutil.ReadAll(resp.Body) + if want, have := 0, len(buf); want != have { + t.Errorf("Body: want no content, have %d bytes", have) + } + }) } }