Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes most TODO #82

Merged
merged 5 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,11 +384,7 @@ func (o *operation) validate(mux *Mux, codecs codecMap) error {
}

if o.server.protocol.protocol() == ProtocolREST {
// REST always uses JSON.
// TODO: Allow non-JSON encodings with REST? Would require registering content-types with codecs.
// Would also require figuring out how to (un)marshal things other than messages when a body
// path indicates a non-message field (do-able with JSON, but maybe non-starter with proto?)
//
// REST always defaults to JSON.
// NB: This is fine to set even if a custom content-type is used via
// the use of google.api.HttpBody. The actual content-type and body
// data will be written via serverBodyPreparer implementation.
Expand All @@ -403,8 +399,8 @@ func (o *operation) validate(mux *Mux, codecs codecMap) error {
if _, supportsCompression := o.methodConf.compressorNames[reqMeta.compression]; supportsCompression {
o.server.reqCompression = o.client.reqCompression
}
// else: we'll just decompress and not recompress
// TODO: should we instead pick a supported compression scheme (if there is one)?
// If the server doesn't support the compression scheme, we'll just
// decompress and not recompress.
}

o.isValid = true // Successfully validated!
Expand Down Expand Up @@ -580,8 +576,6 @@ func (o *operation) resolveMethod(mux *Mux) error {
default:
methodConf := mux.methods[uriPath]
if methodConf == nil {
// TODO: if the service is known, but the method is not, we should send to the client
// a proper RPC error (encoded per protocol handler) with an Unimplemented code.
return errNotFound
}
o.restTarget = methodConf.httpRule
Expand Down
2 changes: 1 addition & 1 deletion protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ type serverEnvelopedProtocolHandler interface {
// The given codec represents the sub-format used to send
// the request to the server (which may be used to decode
// the error).
decodeEndFromMessage(*operation, io.Reader) (responseEnd, error)
decodeEndFromMessage(*operation, *bytes.Buffer) (responseEnd, error)
}

// requestLineBuilder is an optional interface implemented by
Expand Down
9 changes: 1 addition & 8 deletions protocol_connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,14 +569,7 @@ func (c connectStreamServerProtocol) encodeEnvelope(env envelope) envelopeBytes
return envBytes
}

func (c connectStreamServerProtocol) decodeEndFromMessage(op *operation, reader io.Reader) (responseEnd, error) {
// TODO: buffer size limit for headers/trailers; should use http.DefaultMaxHeaderBytes if not configured
buffer := op.bufferPool.Get()
defer op.bufferPool.Put(buffer)
_, err := buffer.ReadFrom(reader)
if err != nil {
return responseEnd{}, err
}
func (c connectStreamServerProtocol) decodeEndFromMessage(_ *operation, buffer *bytes.Buffer) (responseEnd, error) {
var streamEnd connectStreamEnd
if err := json.Unmarshal(buffer.Bytes(), &streamEnd); err != nil {
return responseEnd{}, err
Expand Down
13 changes: 3 additions & 10 deletions protocol_grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func (g grpcServerProtocol) encodeEnvelope(env envelope) envelopeBytes {
return envBytes
}

func (g grpcServerProtocol) decodeEndFromMessage(_ *operation, _ io.Reader) (responseEnd, error) {
func (g grpcServerProtocol) decodeEndFromMessage(_ *operation, _ *bytes.Buffer) (responseEnd, error) {
return responseEnd{}, errors.New("gRPC protocol does not allow embedding result/trailers in body")
}

Expand Down Expand Up @@ -184,7 +184,7 @@ func (g grpcWebClientProtocol) encodeEnd(op *operation, end *responseEnd, writer
buffer := op.bufferPool.Get()
defer op.bufferPool.Put(buffer)
_ = trailers.Write(buffer)
// TODO: compress?
// TODO: Send envelope compressed if possible.
env := envelope{trailer: true, length: uint32(buffer.Len())}
envBytes := g.encodeEnvelope(env)
_, _ = writer.Write(envBytes[:])
Expand Down Expand Up @@ -254,14 +254,7 @@ func (g grpcWebServerProtocol) encodeEnvelope(env envelope) envelopeBytes {
return grpcServerProtocol{}.encodeEnvelope(env)
}

func (g grpcWebServerProtocol) decodeEndFromMessage(op *operation, reader io.Reader) (responseEnd, error) {
// TODO: buffer size limit for headers/trailers; should use http.DefaultMaxHeaderBytes if not configured
buffer := op.bufferPool.Get()
defer op.bufferPool.Put(buffer)
_, err := buffer.ReadFrom(reader)
if err != nil {
return responseEnd{}, err
}
func (g grpcWebServerProtocol) decodeEndFromMessage(_ *operation, buffer *bytes.Buffer) (responseEnd, error) {
headerLines := bytes.Split(buffer.Bytes(), []byte{'\r', '\n'})
trailers := make(http.Header, len(headerLines))
for i, headerLine := range headerLines {
Expand Down
48 changes: 31 additions & 17 deletions protocol_rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ func (r restClientProtocol) acceptsStreamType(op *operation, streamType connect.
case connect.StreamTypeClient:
return restHTTPBodyRequest(op)
case connect.StreamTypeServer:
// TODO: support server streams even when body is not google.api.HttpBody
return restHTTPBodyResponse(op)
default:
return false
}
}

func (r restClientProtocol) endMustBeInHeaders() bool {
// TODO: when we support server streams over REST, this should return false when streaming
// TODO: when we support server streams over REST, this should return
// false when streaming
return true
}

Expand All @@ -82,24 +82,23 @@ func (r restClientProtocol) extractProtocolRequestHeaders(op *operation, headers
headers.Del("Content-Type")

if timeoutStr := headers.Get("X-Server-Timeout"); timeoutStr != "" {
timeout, err := strconv.ParseFloat(timeoutStr, 64)
timeout, err := restDecodeTimeout(timeoutStr)
if err != nil {
return requestMeta{}, err
}
reqMeta.timeout = time.Duration(timeout * float64(time.Second))
reqMeta.timeout = timeout
reqMeta.hasTimeout = true
}
return reqMeta, nil
}

func (r restClientProtocol) addProtocolResponseHeaders(meta responseMeta, headers http.Header) int {
isErr := meta.end != nil && meta.end.err != nil
// TODO: this formulation might only be valid when meta.codec is JSON; support other codecs.
// Headers are only set if they are not already set, specially to allow
// for google.api.HttpBody payloads.
// Only JSON is supported for now unless using google.api.HttpBody
// payloads which override the content-type.
if headers["Content-Type"] == nil {
headers["Content-Type"] = []string{"application/" + meta.codec}
}
// TODO: Content-Encoding to compress error, too?
if !isErr && meta.compression != "" {
headers["Content-Encoding"] = []string{meta.compression}
}
Expand All @@ -126,12 +125,9 @@ func (r restClientProtocol) encodeEnd(op *operation, end *responseEnd, writer io
stat := grpcStatusFromError(cerr)
bin, err := op.client.codec.MarshalAppend(nil, stat)
if err != nil {
// TODO: This is always uses JSON whereas above we use the given codec.
// If/when we support codecs for REST other than JSON, what should
// we do here?
bin = []byte(`{"code": 13, "message": ` + strconv.Quote("failed to marshal end error: "+err.Error()) + `}`)
// Hardcode the error to be a JSON-encoded gRPC status.
emcfarlane marked this conversation as resolved.
Show resolved Hide resolved
bin = []byte(`{"code":13,"message":"failed to marshal end error"}`)
}
// TODO: compress?
_, _ = writer.Write(bin)
return nil
}
Expand Down Expand Up @@ -272,9 +268,8 @@ func (r restServerProtocol) addProtocolRequestHeaders(meta requestMeta, headers
if len(meta.acceptCompression) != 0 {
headers["Accept-Encoding"] = []string{strings.Join(meta.acceptCompression, ", ")}
}
if meta.timeout != 0 {
// Encode timeout as a float in seconds.
value := strconv.FormatFloat(meta.timeout.Seconds(), 'E', -1, 64)
if meta.hasTimeout {
value := restEncodeTimeout(meta.timeout)
headers["X-Server-Timeout"] = []string{value}
}
}
Expand Down Expand Up @@ -393,14 +388,33 @@ func (r restServerProtocol) requestLine(op *operation, req proto.Message) (urlPa
urlPath = path
queryParams = query.Encode()
includeBody = op.restTarget.requestBodyFields != nil // can be len(0) if body is '*'
// TODO: Should this return an error if URL (path + query string) is greater than op.methodConf.maxGetURLSz?
return urlPath, queryParams, op.restTarget.method, includeBody, nil
}

func (r restServerProtocol) String() string {
return protocolNameREST
}

// Decode timeout as a float in seconds from X-Server-Timeout header.
func restDecodeTimeout(timeout string) (time.Duration, error) {
if timeout == "" {
return 0, nil
}
val, err := strconv.ParseFloat(timeout, 64)
if err != nil {
return 0, fmt.Errorf("invalid timeout %q: %w", timeout, err)
}
return time.Duration(val * float64(time.Second)), nil
}

// Encode timeout as a float in seconds for X-Server-Timeout header.
func restEncodeTimeout(timeout time.Duration) string {
if timeout == 0 {
return ""
}
return strconv.FormatFloat(timeout.Seconds(), 'f', -1, 64)
}

func restHTTPBodyRequest(op *operation) bool {
return restIsHTTPBody(op.methodConf.descriptor.Input(), op.restTarget.requestBodyFields)
}
Expand Down
1 change: 1 addition & 0 deletions vanguard_restxrpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ func TestMux_RESTxRPC(t *testing.T) {
for key, values := range input.meta {
req.Header[key] = values
}
req.Header["X-Server-Timeout"] = []string{"30"}
if isCompressed {
req.Header["Content-Encoding"] = []string{comp.Name()}
}
Expand Down
4 changes: 0 additions & 4 deletions vanguard_rpcxrpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,10 +379,6 @@ func TestMux_RPCxRPC(t *testing.T) {
},
},
},
// TODO: Add more tests -- more permutations to catch things like trailers-only responses in gRPC,
// empty client streams, empty server streams
// TODO: Exercise Connect GET for unary operations with Connect client
// TODO: Verify timeouts are propagated correctly
}
for _, opts := range testOpts {
opts := opts
Expand Down
56 changes: 55 additions & 1 deletion vanguard_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ import (
"google.golang.org/protobuf/types/known/emptypb"
)

const (
defaultTestTimeout = 30 * time.Second
)

func TestMux_BufferTooLargeFails(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -628,9 +632,12 @@ func TestMux_ConnectGetUsesPostIfRequestTooLarge(t *testing.T) {
connect.WithHTTPGetMaxURLSize(512, false),
connect.WithSendGzip(),
)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

req := connect.NewRequest(largeRequest)
req.Header().Set("Test", t.Name()) // must set this for interceptor to work
_, err := client.GetBook(context.Background(), req)
_, err := client.GetBook(ctx, req)
// No error means it made through above interceptor unscathed
// (so server handler got a POST).
require.NoError(t, err)
Expand Down Expand Up @@ -1611,6 +1618,9 @@ func TestRuleSelector(t *testing.T) {
}))

ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, defaultTestTimeout)
defer cancel()

req, err := http.NewRequestWithContext(ctx, http.MethodGet, "/v1/selector/shelves/123/books/456", http.NoBody)
require.NoError(t, err)
req.Header.Set("Message", "hello")
Expand Down Expand Up @@ -1792,6 +1802,9 @@ func (i *testInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
ctx context.Context,
req connect.AnyRequest,
) (_ connect.AnyResponse, resultError error) {
if err := assertTestTimeoutEncoded(ctx); err != nil {
return nil, err
}
val := req.Header().Get("test")
if val == "" {
return next(ctx, req)
Expand Down Expand Up @@ -1874,6 +1887,9 @@ func (i *testInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc
ctx context.Context,
conn connect.StreamingHandlerConn,
) (resultError error) {
if err := assertTestTimeoutEncoded(ctx); err != nil {
return err
}
val := conn.RequestHeader().Get("test")
if val == "" {
return next(ctx, conn)
Expand Down Expand Up @@ -2057,6 +2073,19 @@ func (i *testInterceptor) restUnaryHandler(
http.Error(rsp, "invalid test header", http.StatusInternalServerError)
return
}
timeoutStr := req.Header.Get("X-Server-Timeout")
timeout, err := restDecodeTimeout(timeoutStr)
if err != nil {
http.Error(rsp, "invalid timeout header", http.StatusInternalServerError)
return
}
ctx, cancel := context.WithTimeout(req.Context(), timeout)
defer cancel()
if err := assertTestTimeoutEncoded(ctx); err != nil {
http.Error(rsp, err.Error(), http.StatusInternalServerError)
return
}
req = req.WithContext(ctx)
if err := handler(stream, rsp, req); err != nil {
stream.T.Error(err)
http.Error(rsp, err.Error(), http.StatusInternalServerError)
Expand Down Expand Up @@ -2178,6 +2207,8 @@ func outputFromUnary[Req, Resp any](
headers http.Header,
reqs []proto.Message,
) (http.Header, []proto.Message, http.Header, error) {
ctx, cancel := context.WithTimeout(ctx, defaultTestTimeout)
defer cancel()
if len(reqs) != 1 {
return nil, nil, nil, fmt.Errorf("unary method takes exactly 1 request but got %d", len(reqs))
}
Expand All @@ -2200,6 +2231,8 @@ func outputFromServerStream[Req, Resp any](
headers http.Header,
reqs []proto.Message,
) (http.Header, []proto.Message, http.Header, error) {
ctx, cancel := context.WithTimeout(ctx, defaultTestTimeout)
defer cancel()
if len(reqs) != 1 {
return nil, nil, nil, fmt.Errorf("unary method takes exactly 1 request but got %d", len(reqs))
}
Expand All @@ -2226,6 +2259,8 @@ func outputFromClientStream[Req, Resp any](
headers http.Header,
reqs []proto.Message,
) (http.Header, []proto.Message, http.Header, error) {
ctx, cancel := context.WithTimeout(ctx, defaultTestTimeout)
defer cancel()
str := method(ctx)
for k, v := range headers {
str.RequestHeader()[k] = v
Expand Down Expand Up @@ -2255,6 +2290,8 @@ func outputFromBidiStream[Req, Resp any](
headers http.Header,
reqs []proto.Message,
) (http.Header, []proto.Message, http.Header, error) {
ctx, cancel := context.WithTimeout(ctx, defaultTestTimeout)
defer cancel()
str := method(ctx)
defer func() {
_ = str.CloseResponse()
Expand Down Expand Up @@ -2563,3 +2600,20 @@ func newConnectError(code connect.Code, msg string) *connect.Error {
err.Meta()
return err
}

// assert a 30 second timeout has been set.
func assertTestTimeoutEncoded(ctx context.Context) error {
now := time.Now()
deadline, ok := ctx.Deadline()
if !ok {
return errors.New("context should have deadline")
}
if deadline.After(now.Add(defaultTestTimeout)) {
return errors.New("context deadline should be 30 seconds")
}
// Allow a little bit of slop.
if deadline.Before(now.Add(defaultTestTimeout - 5*time.Second)) {
return errors.New("context deadline should be at least 20 seconds")
}
return nil
}