Skip to content

Commit

Permalink
Add fallback for parsing invalid rest errors
Browse files Browse the repository at this point in the history
  • Loading branch information
emcfarlane committed Sep 20, 2023
1 parent 9db9627 commit 18f1828
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 5 deletions.
38 changes: 36 additions & 2 deletions protocol_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ import (

"connectrpc.com/connect"
"google.golang.org/genproto/googleapis/api/annotations"
httpbody "google.golang.org/genproto/googleapis/api/httpbody"
"google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/anypb"
)

func httpStatusCodeFromRPC(code connect.Code) int {
Expand Down Expand Up @@ -57,6 +59,26 @@ func httpStatusCodeFromRPC(code connect.Code) int {
return codes[code]
}

func httpStatusCodeToRPC(code int) connect.Code {
switch code {
case http.StatusOK:
return 0 // OK
case http.StatusUnauthorized:
return connect.CodeUnauthenticated // Unauthenticated
case http.StatusForbidden:
return connect.CodePermissionDenied // PermissionDenied
case http.StatusNotFound:
return connect.CodeUnimplemented // Unimplemented
case http.StatusTooManyRequests,
http.StatusBadGateway,
http.StatusServiceUnavailable,
http.StatusGatewayTimeout:
return connect.CodeUnavailable // Unavailable
default:
return connect.CodeUnknown // Unknown
}
}

func httpWriteError(rsp http.ResponseWriter, err error) {
codec := protojson.MarshalOptions{
EmitUnpopulated: true,
Expand All @@ -78,7 +100,10 @@ func httpWriteError(rsp http.ResponseWriter, err error) {
_, _ = rsp.Write(bin)
}

func httpErrorFromResponse(body io.Reader) *connect.Error {
func httpErrorFromResponse(statusCode int, contentType string, body io.Reader) *connect.Error {
if statusCode == http.StatusOK {
return nil
}
codec := protojson.UnmarshalOptions{}
body = io.LimitReader(body, 1024)
bin, err := io.ReadAll(body)
Expand All @@ -87,7 +112,16 @@ func httpErrorFromResponse(body io.Reader) *connect.Error {
}
var stat status.Status
if err := codec.Unmarshal(bin, &stat); err != nil {
return connect.NewError(connect.CodeInternal, err)
body, err := anypb.New(&httpbody.HttpBody{
ContentType: contentType,
Data: bin,
})
if err != nil {
return connect.NewError(connect.CodeInternal, err)
}
stat.Details = append(stat.Details, body)
stat.Code = int32(httpStatusCodeToRPC(statusCode))
stat.Message = http.StatusText(statusCode)
}

connectErr := connect.NewWireError(
Expand Down
71 changes: 70 additions & 1 deletion protocol_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ import (
testv1 "connectrpc.com/vanguard/internal/gen/vanguard/test/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/genproto/googleapis/api/httpbody"
"google.golang.org/genproto/googleapis/rpc/errdetails"
"google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
)

func TestHTTPErrorWriter(t *testing.T) {
Expand All @@ -46,10 +51,74 @@ func TestHTTPErrorWriter(t *testing.T) {
assert.Equal(t, `{"code":16,"message":"test error: Hello, 世界","details":[]}`, out.String())

body := bytes.NewReader(rec.Body.Bytes())
got := httpErrorFromResponse(body)
got := httpErrorFromResponse(http.StatusUnauthorized, "application/json", body)
assert.Equal(t, cerr, got)
}

func TestHTTPErrorFromResponse(t *testing.T) {
t.Parallel()
t.Run("empty", func(t *testing.T) {
t.Parallel()
var body bytes.Buffer
got := httpErrorFromResponse(http.StatusOK, "", &body)
assert.Nil(t, got)
})
t.Run("jsonStatus", func(t *testing.T) {
t.Parallel()
errorInfo, err := anypb.New(&errdetails.ErrorInfo{
Reason: "user is not authorized",
Domain: "vanguard.connectrpc.com",
Metadata: map[string]string{"key1": "value1"},
})
require.Nil(t, err)
stat := status.Status{
Code: int32(connect.CodeUnauthenticated),
Message: "auth error",
Details: []*anypb.Any{errorInfo},
}
out, err := protojson.Marshal(&stat)
require.Nil(t, err)
got := httpErrorFromResponse(http.StatusUnauthorized, "application/json", bytes.NewReader(out))
assert.Equal(t, connect.CodeUnauthenticated, got.Code())
assert.Equal(t, "auth error", got.Message())
})
t.Run("invalidStatus", func(t *testing.T) {
t.Parallel()
body := bytes.NewReader([]byte("unauthorized"))
got := httpErrorFromResponse(http.StatusUnauthorized, "application/json", body)
assert.Equal(t, connect.CodeUnauthenticated, got.Code())
assert.Equal(t, "Unauthorized", got.Message())
assert.Len(t, got.Details(), 1)
value, err := got.Details()[0].Value()
assert.NoError(t, err)
httpBody, ok := value.(*httpbody.HttpBody)
assert.True(t, ok)
assert.Equal(t, "application/json", httpBody.ContentType)
assert.Equal(t, []byte("unauthorized"), httpBody.Data)
})
t.Run("invalidAny", func(t *testing.T) {
t.Parallel()
stat := status.Status{
Code: int32(connect.CodeUnauthenticated),
Message: "auth error",
}
out, err := protojson.Marshal(&stat)
require.Nil(t, err)
out = append(out[:len(out)-1], []byte(`,"details":{"@type":"foo","value":"bar"}`)...)
got := httpErrorFromResponse(http.StatusUnauthorized, "application/json", bytes.NewReader(out))
t.Log(got)
assert.Equal(t, connect.CodeUnauthenticated, got.Code())
assert.Equal(t, "Unauthorized", got.Message())
assert.Len(t, got.Details(), 1)
value, err := got.Details()[0].Value()
assert.NoError(t, err)
httpBody, ok := value.(*httpbody.HttpBody)
assert.True(t, ok)
assert.Equal(t, "application/json", httpBody.ContentType)
assert.Equal(t, out, httpBody.Data)
})
}

func TestHTTPEncodePathValues(t *testing.T) {
t.Parallel()

Expand Down
4 changes: 2 additions & 2 deletions protocol_rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,18 +280,18 @@ func (r restServerProtocol) addProtocolRequestHeaders(meta requestMeta, headers
}

func (r restServerProtocol) extractProtocolResponseHeaders(statusCode int, headers http.Header) (responseMeta, responseEndUnmarshaller, error) {
contentType := headers.Get("Content-Type")
if statusCode/100 != 2 {
return responseMeta{
end: &responseEnd{httpCode: statusCode},
}, func(_ Codec, src io.Reader, end *responseEnd) {
if err := httpErrorFromResponse(src); err != nil {
if err := httpErrorFromResponse(statusCode, contentType, src); err != nil {
end.err = err
end.httpCode = httpStatusCodeFromRPC(err.Code())
}
}, nil
}
var meta responseMeta
contentType := headers.Get("Content-Type")
switch {
case contentType == "application/json":
meta.codec = CodecJSON
Expand Down

0 comments on commit 18f1828

Please sign in to comment.