From 18f18284ca59e086db7745f7658ab4073d6843cd Mon Sep 17 00:00:00 2001 From: Edward McFarlane Date: Wed, 20 Sep 2023 21:09:19 +0100 Subject: [PATCH] Add fallback for parsing invalid rest errors --- protocol_http.go | 38 +++++++++++++++++++++-- protocol_http_test.go | 71 ++++++++++++++++++++++++++++++++++++++++++- protocol_rest.go | 4 +-- 3 files changed, 108 insertions(+), 5 deletions(-) diff --git a/protocol_http.go b/protocol_http.go index 1ce5ee5..7239bf5 100644 --- a/protocol_http.go +++ b/protocol_http.go @@ -25,10 +25,12 @@ import ( "connectrpc.com/connect" "google.golang.org/genproto/googleapis/api/annotations" + httpbody "google.golang.org/genproto/googleapis/api/httpbody" "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/known/anypb" ) func httpStatusCodeFromRPC(code connect.Code) int { @@ -57,6 +59,26 @@ func httpStatusCodeFromRPC(code connect.Code) int { return codes[code] } +func httpStatusCodeToRPC(code int) connect.Code { + switch code { + case http.StatusOK: + return 0 // OK + case http.StatusUnauthorized: + return connect.CodeUnauthenticated // Unauthenticated + case http.StatusForbidden: + return connect.CodePermissionDenied // PermissionDenied + case http.StatusNotFound: + return connect.CodeUnimplemented // Unimplemented + case http.StatusTooManyRequests, + http.StatusBadGateway, + http.StatusServiceUnavailable, + http.StatusGatewayTimeout: + return connect.CodeUnavailable // Unavailable + default: + return connect.CodeUnknown // Unknown + } +} + func httpWriteError(rsp http.ResponseWriter, err error) { codec := protojson.MarshalOptions{ EmitUnpopulated: true, @@ -78,7 +100,10 @@ func httpWriteError(rsp http.ResponseWriter, err error) { _, _ = rsp.Write(bin) } -func httpErrorFromResponse(body io.Reader) *connect.Error { +func httpErrorFromResponse(statusCode int, contentType string, body io.Reader) *connect.Error { + if statusCode == http.StatusOK { + return nil + } codec := protojson.UnmarshalOptions{} body = io.LimitReader(body, 1024) bin, err := io.ReadAll(body) @@ -87,7 +112,16 @@ func httpErrorFromResponse(body io.Reader) *connect.Error { } var stat status.Status if err := codec.Unmarshal(bin, &stat); err != nil { - return connect.NewError(connect.CodeInternal, err) + body, err := anypb.New(&httpbody.HttpBody{ + ContentType: contentType, + Data: bin, + }) + if err != nil { + return connect.NewError(connect.CodeInternal, err) + } + stat.Details = append(stat.Details, body) + stat.Code = int32(httpStatusCodeToRPC(statusCode)) + stat.Message = http.StatusText(statusCode) } connectErr := connect.NewWireError( diff --git a/protocol_http_test.go b/protocol_http_test.go index b2727d6..b8bea0e 100644 --- a/protocol_http_test.go +++ b/protocol_http_test.go @@ -27,7 +27,12 @@ import ( testv1 "connectrpc.com/vanguard/internal/gen/vanguard/test/v1" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/genproto/googleapis/api/httpbody" + "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/genproto/googleapis/rpc/status" + "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" ) func TestHTTPErrorWriter(t *testing.T) { @@ -46,10 +51,74 @@ func TestHTTPErrorWriter(t *testing.T) { assert.Equal(t, `{"code":16,"message":"test error: Hello, 世界","details":[]}`, out.String()) body := bytes.NewReader(rec.Body.Bytes()) - got := httpErrorFromResponse(body) + got := httpErrorFromResponse(http.StatusUnauthorized, "application/json", body) assert.Equal(t, cerr, got) } +func TestHTTPErrorFromResponse(t *testing.T) { + t.Parallel() + t.Run("empty", func(t *testing.T) { + t.Parallel() + var body bytes.Buffer + got := httpErrorFromResponse(http.StatusOK, "", &body) + assert.Nil(t, got) + }) + t.Run("jsonStatus", func(t *testing.T) { + t.Parallel() + errorInfo, err := anypb.New(&errdetails.ErrorInfo{ + Reason: "user is not authorized", + Domain: "vanguard.connectrpc.com", + Metadata: map[string]string{"key1": "value1"}, + }) + require.Nil(t, err) + stat := status.Status{ + Code: int32(connect.CodeUnauthenticated), + Message: "auth error", + Details: []*anypb.Any{errorInfo}, + } + out, err := protojson.Marshal(&stat) + require.Nil(t, err) + got := httpErrorFromResponse(http.StatusUnauthorized, "application/json", bytes.NewReader(out)) + assert.Equal(t, connect.CodeUnauthenticated, got.Code()) + assert.Equal(t, "auth error", got.Message()) + }) + t.Run("invalidStatus", func(t *testing.T) { + t.Parallel() + body := bytes.NewReader([]byte("unauthorized")) + got := httpErrorFromResponse(http.StatusUnauthorized, "application/json", body) + assert.Equal(t, connect.CodeUnauthenticated, got.Code()) + assert.Equal(t, "Unauthorized", got.Message()) + assert.Len(t, got.Details(), 1) + value, err := got.Details()[0].Value() + assert.NoError(t, err) + httpBody, ok := value.(*httpbody.HttpBody) + assert.True(t, ok) + assert.Equal(t, "application/json", httpBody.ContentType) + assert.Equal(t, []byte("unauthorized"), httpBody.Data) + }) + t.Run("invalidAny", func(t *testing.T) { + t.Parallel() + stat := status.Status{ + Code: int32(connect.CodeUnauthenticated), + Message: "auth error", + } + out, err := protojson.Marshal(&stat) + require.Nil(t, err) + out = append(out[:len(out)-1], []byte(`,"details":{"@type":"foo","value":"bar"}`)...) + got := httpErrorFromResponse(http.StatusUnauthorized, "application/json", bytes.NewReader(out)) + t.Log(got) + assert.Equal(t, connect.CodeUnauthenticated, got.Code()) + assert.Equal(t, "Unauthorized", got.Message()) + assert.Len(t, got.Details(), 1) + value, err := got.Details()[0].Value() + assert.NoError(t, err) + httpBody, ok := value.(*httpbody.HttpBody) + assert.True(t, ok) + assert.Equal(t, "application/json", httpBody.ContentType) + assert.Equal(t, out, httpBody.Data) + }) +} + func TestHTTPEncodePathValues(t *testing.T) { t.Parallel() diff --git a/protocol_rest.go b/protocol_rest.go index 79ba6f4..15c20e8 100644 --- a/protocol_rest.go +++ b/protocol_rest.go @@ -280,18 +280,18 @@ func (r restServerProtocol) addProtocolRequestHeaders(meta requestMeta, headers } func (r restServerProtocol) extractProtocolResponseHeaders(statusCode int, headers http.Header) (responseMeta, responseEndUnmarshaller, error) { + contentType := headers.Get("Content-Type") if statusCode/100 != 2 { return responseMeta{ end: &responseEnd{httpCode: statusCode}, }, func(_ Codec, src io.Reader, end *responseEnd) { - if err := httpErrorFromResponse(src); err != nil { + if err := httpErrorFromResponse(statusCode, contentType, src); err != nil { end.err = err end.httpCode = httpStatusCodeFromRPC(err.Code()) } }, nil } var meta responseMeta - contentType := headers.Get("Content-Type") switch { case contentType == "application/json": meta.codec = CodecJSON