Skip to content

Commit

Permalink
Fixed EncodeJSONResponse body writing go-kit#1291
Browse files Browse the repository at this point in the history
  • Loading branch information
kerma committed Jul 10, 2024
1 parent 78fbbce commit 0cdcd43
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 27 deletions.
14 changes: 11 additions & 3 deletions transport/http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ func NopRequestDecoder(ctx context.Context, r *http.Request) (interface{}, error
// JSON object to the ResponseWriter. Many JSON-over-HTTP services can use it as
// a sensible default. If the response implements Headerer, the provided headers
// will be applied to the response. If the response implements StatusCoder, the
// provided StatusCode will be used instead of 200.
// provided StatusCode will be used instead of 200. If the StatusCode is between 100-199,
// or equal to 204 or 304, the response body will not be written.
func EncodeJSONResponse(_ context.Context, w http.ResponseWriter, response interface{}) error {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
if headerer, ok := response.(Headerer); ok {
Expand All @@ -174,10 +175,17 @@ func EncodeJSONResponse(_ context.Context, w http.ResponseWriter, response inter
code = sc.StatusCode()
}
w.WriteHeader(code)
if code == http.StatusNoContent {

switch {
case code >= 100 && code <= 199:
return nil
case code == http.StatusNoContent:
return nil
case code == http.StatusNotModified:
return nil
default:
return json.NewEncoder(w).Encode(response)
}
return json.NewEncoder(w).Encode(response)
}

// DefaultErrorEncoder writes the error to the ResponseWriter, by default a
Expand Down
59 changes: 35 additions & 24 deletions transport/http/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package http_test
import (
"context"
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -68,7 +69,9 @@ func TestServerErrorEncoder(t *testing.T) {
func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errTeapot },
func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
func(context.Context, http.ResponseWriter, interface{}) error { return nil },
httptransport.ServerErrorEncoder(func(_ context.Context, err error, w http.ResponseWriter) { w.WriteHeader(code(err)) }),
httptransport.ServerErrorEncoder(
func(_ context.Context, err error, w http.ResponseWriter) { w.WriteHeader(code(err)) },
),
)
server := httptest.NewServer(handler)
defer server.Close()
Expand Down Expand Up @@ -281,7 +284,7 @@ func TestAddMultipleHeaders(t *testing.T) {
if err != nil {
t.Fatal(err)
}
expect := map[string]map[string]struct{}{"Vary": map[string]struct{}{"Origin": struct{}{}, "User-Agent": struct{}{}}}
expect := map[string]map[string]struct{}{"Vary": {"Origin": {}, "User-Agent": {}}}
for k, vls := range resp.Header {
for _, v := range vls {
delete((expect[k]), v)
Expand Down Expand Up @@ -318,7 +321,7 @@ func TestAddMultipleHeadersErrorEncoder(t *testing.T) {
if err != nil {
t.Fatal(err)
}
expect := map[string]map[string]struct{}{"Vary": map[string]struct{}{"Origin": struct{}{}, "User-Agent": struct{}{}}}
expect := map[string]map[string]struct{}{"Vary": {"Origin": {}, "User-Agent": {}}}
for k, vls := range resp.Header {
for _, v := range vls {
delete((expect[k]), v)
Expand All @@ -332,30 +335,38 @@ func TestAddMultipleHeadersErrorEncoder(t *testing.T) {
}
}

type noContentResponse struct{}
type noBodyResponse struct{ Code int }

func (e noContentResponse) StatusCode() int { return http.StatusNoContent }
func (e noBodyResponse) StatusCode() int { return e.Code }

func TestEncodeNoContent(t *testing.T) {
handler := httptransport.NewServer(
func(context.Context, interface{}) (interface{}, error) { return noContentResponse{}, nil },
func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
httptransport.EncodeJSONResponse,
)

server := httptest.NewServer(handler)
defer server.Close()

resp, err := http.Get(server.URL)
if err != nil {
t.Fatal(err)
func TestEncodeNoBody(t *testing.T) {
testCases := []int{
http.StatusNoContent,
http.StatusNotModified,
}
if want, have := http.StatusNoContent, resp.StatusCode; want != have {
t.Errorf("StatusCode: want %d, have %d", want, have)
}
buf, _ := ioutil.ReadAll(resp.Body)
if want, have := 0, len(buf); want != have {
t.Errorf("Body: want no content, have %d bytes", have)
for _, code := range testCases {
t.Run(fmt.Sprint(code), func(t *testing.T) {
handler := httptransport.NewServer(
func(context.Context, interface{}) (interface{}, error) { return noBodyResponse{code}, nil },
func(context.Context, *http.Request) (interface{}, error) { return struct{}{}, nil },
httptransport.EncodeJSONResponse,
)

server := httptest.NewServer(handler)
defer server.Close()

resp, err := http.Get(server.URL)
if err != nil {
t.Fatal(err)
}
if want, have := code, resp.StatusCode; want != have {
t.Errorf("StatusCode: want %d, have %d", want, have)
}
buf, _ := ioutil.ReadAll(resp.Body)
if want, have := 0, len(buf); want != have {
t.Errorf("Body: want no content, have %d bytes", have)
}
})
}
}

Expand Down

0 comments on commit 0cdcd43

Please sign in to comment.