Skip to content

Commit

Permalink
Response end work on buffers to avoid read
Browse files Browse the repository at this point in the history
  • Loading branch information
emcfarlane committed Sep 25, 2023
1 parent 18f1828 commit 2c93a93
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 26 deletions.
5 changes: 3 additions & 2 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package vanguard

import (
"bytes"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -219,9 +220,9 @@ type serverProtocolHandler interface {
}

// responseEndUnmarshaller populates the given responseEnd by unmarshalling
// information from the given reader. If unmarshalling needs to know the
// information from the given buffer. If unmarshalling needs to know the
// server's codec, it also provided as the first argument.
type responseEndUnmarshaller func(Codec, io.Reader, *responseEnd)
type responseEndUnmarshaller func(Codec, *bytes.Buffer, *responseEnd)

// clientProtocolEndMustBeInHeaders is an optional interface implemented
// by clientProtocolHandler instances to indicate if the end of an RPC
Expand Down
12 changes: 3 additions & 9 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ func (c connectUnaryServerProtocol) extractProtocolResponseHeaders(statusCode in
trailers := connectExtractUnaryTrailers(headers)

var endUnmarshaller responseEndUnmarshaller
if statusCode == http.StatusOK { //nolint:nestif
if statusCode == http.StatusOK {
respMeta.pendingTrailers = trailers
} else {
// Content-Type must be application/json for errors or else it's invalid
Expand All @@ -302,15 +302,9 @@ func (c connectUnaryServerProtocol) extractProtocolResponseHeaders(statusCode in
wasCompressed: respMeta.compression != "",
trailers: trailers,
}
endUnmarshaller = func(_ Codec, r io.Reader, end *responseEnd) {
// TODO: buffer size limit; use op.bufferPool
data, err := io.ReadAll(r)
if err != nil {
end.err = connect.NewError(connect.CodeInternal, err)
return
}
endUnmarshaller = func(_ Codec, buf *bytes.Buffer, end *responseEnd) {
var wireErr connectWireError
if err := json.Unmarshal(data, &wireErr); err != nil {
if err := json.Unmarshal(buf.Bytes(), &wireErr); err != nil {
end.err = connect.NewError(connect.CodeInternal, err)
return
}
Expand Down
13 changes: 4 additions & 9 deletions protocol_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
package vanguard

import (
"bytes"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
Expand Down Expand Up @@ -100,21 +100,16 @@ func httpWriteError(rsp http.ResponseWriter, err error) {
_, _ = rsp.Write(bin)
}

func httpErrorFromResponse(statusCode int, contentType string, body io.Reader) *connect.Error {
func httpErrorFromResponse(statusCode int, contentType string, src *bytes.Buffer) *connect.Error {
if statusCode == http.StatusOK {
return nil
}
codec := protojson.UnmarshalOptions{}
body = io.LimitReader(body, 1024)
bin, err := io.ReadAll(body)
if err != nil {
return connect.NewError(connect.CodeInternal, err)
}
var stat status.Status
if err := codec.Unmarshal(bin, &stat); err != nil {
if err := codec.Unmarshal(src.Bytes(), &stat); err != nil {
body, err := anypb.New(&httpbody.HttpBody{
ContentType: contentType,
Data: bin,
Data: src.Bytes(),
})
if err != nil {
return connect.NewError(connect.CodeInternal, err)
Expand Down
8 changes: 4 additions & 4 deletions protocol_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestHTTPErrorWriter(t *testing.T) {
assert.NoError(t, json.Compact(&out, rec.Body.Bytes()))
assert.Equal(t, `{"code":16,"message":"test error: Hello, 世界","details":[]}`, out.String())

body := bytes.NewReader(rec.Body.Bytes())
body := bytes.NewBuffer(rec.Body.Bytes())
got := httpErrorFromResponse(http.StatusUnauthorized, "application/json", body)
assert.Equal(t, cerr, got)
}
Expand Down Expand Up @@ -78,13 +78,13 @@ func TestHTTPErrorFromResponse(t *testing.T) {
}
out, err := protojson.Marshal(&stat)
require.Nil(t, err)
got := httpErrorFromResponse(http.StatusUnauthorized, "application/json", bytes.NewReader(out))
got := httpErrorFromResponse(http.StatusUnauthorized, "application/json", bytes.NewBuffer(out))
assert.Equal(t, connect.CodeUnauthenticated, got.Code())
assert.Equal(t, "auth error", got.Message())
})
t.Run("invalidStatus", func(t *testing.T) {
t.Parallel()
body := bytes.NewReader([]byte("unauthorized"))
body := bytes.NewBufferString("unauthorized")
got := httpErrorFromResponse(http.StatusUnauthorized, "application/json", body)
assert.Equal(t, connect.CodeUnauthenticated, got.Code())
assert.Equal(t, "Unauthorized", got.Message())
Expand All @@ -105,7 +105,7 @@ func TestHTTPErrorFromResponse(t *testing.T) {
out, err := protojson.Marshal(&stat)
require.Nil(t, err)
out = append(out[:len(out)-1], []byte(`,"details":{"@type":"foo","value":"bar"}`)...)
got := httpErrorFromResponse(http.StatusUnauthorized, "application/json", bytes.NewReader(out))
got := httpErrorFromResponse(http.StatusUnauthorized, "application/json", bytes.NewBuffer(out))
t.Log(got)
assert.Equal(t, connect.CodeUnauthenticated, got.Code())
assert.Equal(t, "Unauthorized", got.Message())
Expand Down
5 changes: 3 additions & 2 deletions protocol_rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package vanguard

import (
"bytes"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -284,8 +285,8 @@ func (r restServerProtocol) extractProtocolResponseHeaders(statusCode int, heade
if statusCode/100 != 2 {
return responseMeta{
end: &responseEnd{httpCode: statusCode},
}, func(_ Codec, src io.Reader, end *responseEnd) {
if err := httpErrorFromResponse(statusCode, contentType, src); err != nil {
}, func(_ Codec, buf *bytes.Buffer, end *responseEnd) {
if err := httpErrorFromResponse(statusCode, contentType, buf); err != nil {
end.err = err
end.httpCode = httpStatusCodeFromRPC(err.Code())
}
Expand Down

0 comments on commit 2c93a93

Please sign in to comment.