diff --git a/buffer_pool.go b/buffer_pool.go deleted file mode 100644 index d9374ee..0000000 --- a/buffer_pool.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2023 Buf Technologies, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package vanguard - -import ( - "bytes" - "sync" -) - -const ( - initialBufferSize = 512 - maxRecycleBufferSize = 8 * 1024 * 1024 // if >8MiB, don't hold onto a buffer -) - -type bufferPool struct { - sync.Pool -} - -func (b *bufferPool) Get() *bytes.Buffer { - if buffer, ok := b.Pool.Get().(*bytes.Buffer); ok { - buffer.Reset() - return buffer - } - return bytes.NewBuffer(make([]byte, 0, initialBufferSize)) -} - -func (b *bufferPool) Put(buffer *bytes.Buffer) { - if buffer.Cap() > maxRecycleBufferSize { - return - } - b.Pool.Put(buffer) -} - -func (b *bufferPool) Wrap(data []byte, orig *bytes.Buffer) *bytes.Buffer { - if cap(data) > orig.Cap() { - // Original buffer was too small, so we had to grow its slice to - // compute data. Replace the buffer with the larger, - // newly-allocated slice. - return bytes.NewBuffer(data) - } - // The buffer from the pool was large enough so no growing was necessary. - // That means this should be a no-op since the buffer, under the hood, will - // copy the given data to its internal slice, which should be the exact - // same slice. - orig.Reset() - orig.Write(data) - return orig -} diff --git a/buffers.go b/buffers.go new file mode 100644 index 0000000..dcbfb98 --- /dev/null +++ b/buffers.go @@ -0,0 +1,317 @@ +// Copyright 2023 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vanguard + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "sync" + + "connectrpc.com/connect" + "google.golang.org/protobuf/proto" +) + +const ( + initialBufferSize = bytes.MinRead + maxRecycleBufferSize = 8 * 1024 * 1024 // if >8MiB, don't hold onto a buffer + // chunkMessageSize for google.api.HttpBody messages will be chunked into + // multiple messages of this size. It should be large enough to avoid + // excessive overhead, but small enough to avoid holding onto large buffers. + chunkMessageSize = 4 * 1024 * 1024 // 4MiB +) + +type bufferPool struct { + sync.Pool +} + +func (b *bufferPool) Get() *bytes.Buffer { + if buffer, ok := b.Pool.Get().(*bytes.Buffer); ok { + buffer.Reset() + return buffer + } + return bytes.NewBuffer(make([]byte, 0, initialBufferSize)) +} + +func (b *bufferPool) Put(buffer *bytes.Buffer) { + if buffer.Cap() > maxRecycleBufferSize { + return + } + b.Pool.Put(buffer) +} + +type readMode int + +const ( + readModeSize = readMode(iota) + readModeEOF + readModeChunk +) + +type srcParams struct { + Size uint32 // Size or message + ReadMode readMode // Read size, unil EOF, or chunked + IsEOF bool // Last bytes of stream + IsTrailer bool // Trailer message, call DecodeTrailer + IsCompressed bool // Compressed message, call Decompress +} + +func (s srcParams) String() string { + return fmt.Sprintf("srcParams{Size: %d, IsEOF: %t, IsTrailer: %t, IsCompressed: %t}", s.Size, s.IsEOF, s.IsTrailer, s.IsCompressed) +} + +type dstParams struct { + Flags uint8 // Envelope flags + IsEnvelope bool // Set envelope prefix on messages + IsCompressed bool // Compress message, call Compress + WaitForTrailer bool // Wait for trailers, buffering messages +} + +func (d dstParams) String() string { + return fmt.Sprintf("dstParams{Flags: %d, IsEnvelope: %t, IsCompressed: %t, WaitForTrailer: %t}", d.Flags, d.IsEnvelope, d.IsCompressed, d.WaitForTrailer) +} + +type messageStage int + +const ( + stageEmpty = messageStage(iota) + stageRead // TODO: docs + stageBuffered // TODO: docs + stageEOF // TODO: docs +) + +type messageBuffer struct { + Buf *bytes.Buffer + Index int // Index of message in the stream + Src srcParams + Dst dstParams + + offset int + envOffset int + size int + stage messageStage +} + +func (m *messageBuffer) Read(data []byte) (n int, err error) { + if m.stage != stageBuffered { + return 0, errorf(connect.CodeInternal, "message not buffered") + } + if m.Dst.IsEnvelope && m.envOffset < 5 { + env := makeEnvelope(m.Dst.Flags, m.size) + envN := copy(data, env[m.envOffset:]) + data = data[envN:] + n += envN + m.envOffset += envN + if m.envOffset < 5 { + return n, nil + } + } + src := m.Buf.Bytes()[m.offset:m.size] + wroteN := copy(data, src) + m.offset += wroteN + n += wroteN + if n == 0 && len(data) > 0 { + err = io.EOF + } + return n, err +} + +func (m *messageBuffer) WriteTo(dst io.Writer) (n int64, err error) { + if m.stage != stageBuffered { + return 0, errorf(connect.CodeInternal, "message not buffered") + } + if m.Dst.IsEnvelope && m.envOffset < 5 { + env := makeEnvelope(m.Dst.Flags, m.size) + envN, err := dst.Write(env[m.envOffset:]) + n += int64(envN) + m.envOffset += envN + if err != nil { + return n, err + } + if m.envOffset < 5 { + return n, io.ErrShortWrite + } + } + src := m.Buf.Bytes()[m.offset:m.size] + wroteN, err := dst.Write(src) + m.offset += wroteN + n += int64(wroteN) + return n, err +} + +// Flush the first message from the buffer and reclaim size by shifting any +// excess data to the front of the buffer. +func (m *messageBuffer) Flush() { + // Shift any excess data to the front of the buffer. + excess := m.Buf.Bytes()[m.size:] + m.Buf.Reset() + _, _ = m.Buf.Write(excess) + + m.Src = srcParams{} + m.Dst = dstParams{} + m.Index++ + m.offset = 0 + m.envOffset = 0 + m.size = 0 + m.stage = stageEmpty +} + +func (m *messageBuffer) Convert(buffers *bufferPool, msg proto.Message, src, dst encoding) error { + srcCompressor := src.Compressor + if !m.Src.IsCompressed { + srcCompressor = nil + } + dstCompressor := dst.Compressor + if !m.Dst.IsCompressed { + dstCompressor = nil + } + if err := convertBuffer( + buffers, + m.Buf, + srcCompressor, + src.Codec, + msg, + dst.Codec, + dstCompressor, + ); err != nil { + return err + } + m.size = m.Buf.Len() + m.stage = stageBuffered + return nil +} + +// encode the message into the buffer, compressing and encoding as needed. +func encodeBuffer(buffers *bufferPool, buf *bytes.Buffer, msg proto.Message, codec Codec, comp compressor) error { //nolint:unused + // Force re-encoding. + // Force re-compression, if needed. + return convertBuffer(buffers, buf, nil, nil, msg, codec, comp) +} + +// decode the message from the buffer, decompressing and unmarshalling as needed. +func decodeBuffer(buffers *bufferPool, buf *bytes.Buffer, msg proto.Message, codec Codec, comp compressor) error { + // Force decompression, if needed. + // Force decoding. + return convertBuffer(buffers, buf, comp, codec, msg, nil, nil) +} + +// convert the message in the buffer to the new compression and encoding. +// The message will only be used if required to convert the encoding. +func convertBuffer( + buffers *bufferPool, + buf *bytes.Buffer, + srcCompressor compressor, + srcCodec Codec, + msg proto.Message, + dstCodec Codec, + dstCompressor compressor, +) error { + var tmp *bytes.Buffer + defer func() { + if tmp != nil { + buffers.Put(tmp) + } + }() + needsRecoding := getName(srcCodec) != getName(dstCodec) + needsRecompressing := getName(srcCompressor) != getName(dstCompressor) || needsRecoding + if srcCompressor != nil && needsRecompressing { + // Decompress + tmp = buffers.Get() + // Read from m, don't mutate m.buf + if err := srcCompressor.decompress(tmp, buf); err != nil { + return err + } + *buf, *tmp = *tmp, *buf // swap buffers + } + if srcCodec != nil && needsRecoding { + // Decode + if err := srcCodec.Unmarshal(buf.Bytes(), msg); err != nil { + return err + } + } + if dstCodec != nil && needsRecoding { + // Encode + buf.Reset() + if err := marshal(buf, msg, dstCodec); err != nil { + return err + } + } + if dstCompressor != nil && needsRecompressing { + // Compress + if tmp == nil { + tmp = buffers.Get() + } else { + tmp.Reset() + } + if err := dstCompressor.compress(tmp, buf); err != nil { + return err + } + *buf, *tmp = *tmp, *buf // swap buffers + } + return nil +} + +func getName(thing interface{ Name() string }) string { + if thing == nil { + return "" + } + return thing.Name() +} + +func readEnvelope(src io.Reader) (uint8, uint32, error) { + var env envelopeBytes + if _, err := io.ReadFull(src, env[:]); err != nil { + return 0, 0, errorf(connect.CodeInternal, "read envelope: %w", err) + } + flags := env[0] + size := binary.BigEndian.Uint32(env[1:]) + return flags, size, nil +} + +// read a bit from the src into the dst, growing the dst if needed. +// This is used to check for EOF when reading messages. +func read(dst *bytes.Buffer, src io.Reader) (int, error) { + dst.Grow(bytes.MinRead) + b := dst.Bytes()[dst.Len() : dst.Len()+bytes.MinRead] + n, err := src.Read(b) + _, _ = dst.Write(b[:n]) // noop + return n, err +} + +func marshal(dst *bytes.Buffer, msg proto.Message, codec Codec) error { + raw, err := codec.MarshalAppend(dst.Bytes(), msg) + if err != nil { + return err + } + if cap(raw) > dst.Cap() { + // Dst buffer was too small, so MarshalAppend grew the slice. + // Replace the buffer with the larger, newly-allocated slice. + *dst = *bytes.NewBuffer(raw) + } else { + // The buffer from the pool was large enough, MarshalAppend didn't allocate. + // Copy to the same byte slice is a nop. + dst.Write(raw[dst.Len():]) + } + return nil +} + +// makeEnvelope returns a byte array representing an encoded envelope. +func makeEnvelope(flags uint8, size int) [5]byte { + prefix := [5]byte{} + prefix[0] = flags + binary.BigEndian.PutUint32(prefix[1:5], uint32(size)) + return prefix +} diff --git a/buffers_test.go b/buffers_test.go new file mode 100644 index 0000000..f144a24 --- /dev/null +++ b/buffers_test.go @@ -0,0 +1,247 @@ +// Copyright 2023 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vanguard + +import ( + "bytes" + "io" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/reflect/protoregistry" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +func TestMessageConvert(t *testing.T) { + t.Parallel() + buffers := &bufferPool{} + codecJSON := DefaultJSONCodec(protoregistry.GlobalTypes) + codecProto := DefaultProtoCodec(protoregistry.GlobalTypes) + compGzip := newCompressionPool(CompressionGzip, DefaultGzipCompressor, DefaultGzipDecompressor) + + encode := func(t *testing.T, codec Codec, comp *compressionPool, msg proto.Message) string { + t.Helper() + data, err := codec.MarshalAppend(nil, msg) + require.NoError(t, err) + if comp == nil { + return string(data) + } + var buf bytes.Buffer + require.NoError(t, comp.compress(&buf, bytes.NewBuffer(data))) + return buf.String() + } + + testCases := []struct { + name string + src, dst string + srcCodec, dstCodec Codec + srcComp, dstComp *compressionPool + msg proto.Message + wantMsg proto.Message + wantMarshalCalls int + wantUnmarshalCalls int + wantCompressCalls int + wantDecompressCalls int + wantErr string + }{{ + name: "SameCodec", + src: `"hello"`, dst: `"hello"`, + srcCodec: codecJSON, dstCodec: codecJSON, + msg: &wrapperspb.StringValue{}, + wantMsg: &wrapperspb.StringValue{}, // Not decoded + }, { + name: "MustDecode", + src: `"hello"`, dst: `"hello"`, + srcCodec: codecJSON, dstCodec: nil, + msg: &wrapperspb.StringValue{}, + wantMsg: &wrapperspb.StringValue{Value: "hello"}, // decoded + wantUnmarshalCalls: 1, + }, { + name: "DiffCodec", + src: `"hello"`, dst: encode(t, codecProto, nil, &wrapperspb.StringValue{Value: "hello"}), + srcCodec: codecJSON, dstCodec: codecProto, + msg: &wrapperspb.StringValue{}, + wantMsg: &wrapperspb.StringValue{Value: "hello"}, + wantUnmarshalCalls: 1, + wantMarshalCalls: 1, + }, { + name: "Compress", + src: `"hello"`, dst: encode(t, codecProto, compGzip, &wrapperspb.StringValue{Value: "hello"}), + srcCodec: codecJSON, dstCodec: codecProto, + srcComp: nil, dstComp: compGzip, + msg: &wrapperspb.StringValue{}, + wantMsg: &wrapperspb.StringValue{Value: "hello"}, + wantUnmarshalCalls: 1, + wantMarshalCalls: 1, + wantCompressCalls: 1, + }, { + name: "SameCodecCompress", + src: `"hello"`, dst: encode(t, codecJSON, compGzip, &wrapperspb.StringValue{Value: "hello"}), + srcCodec: codecJSON, dstCodec: codecJSON, + srcComp: nil, dstComp: compGzip, + wantCompressCalls: 1, + }, { + name: "Decompress", + src: encode(t, codecProto, compGzip, &wrapperspb.StringValue{Value: "hello"}), dst: `"hello"`, + srcCodec: codecProto, dstCodec: codecJSON, + srcComp: compGzip, dstComp: nil, + msg: &wrapperspb.StringValue{}, + wantMsg: &wrapperspb.StringValue{Value: "hello"}, + wantUnmarshalCalls: 1, + wantMarshalCalls: 1, + wantDecompressCalls: 1, + }, { + name: "SameCodecDecompress", + src: encode(t, codecJSON, compGzip, &wrapperspb.StringValue{Value: "hello"}), + dst: `"hello"`, + srcCodec: codecJSON, dstCodec: codecJSON, + srcComp: compGzip, dstComp: nil, + wantDecompressCalls: 1, + }, { + name: "MustDecodeDecompress", + src: encode(t, codecJSON, compGzip, &wrapperspb.StringValue{Value: "hello"}), + dst: encode(t, codecJSON, nil, &wrapperspb.StringValue{Value: "hello"}), + srcCodec: codecJSON, dstCodec: nil, + srcComp: compGzip, dstComp: nil, + msg: &wrapperspb.StringValue{}, + wantMsg: &wrapperspb.StringValue{Value: "hello"}, + wantUnmarshalCalls: 1, + wantDecompressCalls: 1, + }, { + name: "ForceRecode", + src: `""`, dst: `"from msg"`, + srcCodec: nil, dstCodec: codecJSON, + srcComp: nil, dstComp: nil, + msg: &wrapperspb.StringValue{Value: "from msg"}, + wantMarshalCalls: 1, + }, { + name: "ForceRecodeRecompress", + src: `""`, dst: encode(t, codecJSON, compGzip, &wrapperspb.StringValue{Value: "from msg"}), + srcCodec: nil, dstCodec: codecJSON, + srcComp: nil, dstComp: compGzip, + msg: &wrapperspb.StringValue{Value: "from msg"}, + wantMarshalCalls: 1, + wantCompressCalls: 1, + }, { + name: "RecompressAndRecode", + src: encode(t, codecJSON, compGzip, &wrapperspb.StringValue{Value: "hello"}), + dst: encode(t, codecProto, compGzip, &wrapperspb.StringValue{Value: "hello"}), + srcCodec: codecJSON, dstCodec: codecProto, + srcComp: compGzip, dstComp: compGzip, + msg: &wrapperspb.StringValue{}, + wantMsg: &wrapperspb.StringValue{Value: "hello"}, + wantUnmarshalCalls: 1, + wantDecompressCalls: 1, + wantMarshalCalls: 1, + wantCompressCalls: 1, + }} + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + var ( + marshalCalls, unmarshalCalls int + compressCalls, decompressCalls int + ) + var srcComp compressor + if testCase.srcComp != nil { + srcComp = countCompressor{ + compressionPool: testCase.srcComp, + compressorCalls: &compressCalls, + decompressorCalls: &decompressCalls, + } + } + var srcCodec Codec + if testCase.srcCodec != nil { + srcCodec = countCodec{ + Codec: testCase.srcCodec, + marshalCalls: &marshalCalls, + unmarshalCalls: &unmarshalCalls, + } + } + var dstComp compressor + if testCase.dstComp != nil { + dstComp = countCompressor{ + compressionPool: testCase.dstComp, + compressorCalls: &compressCalls, + decompressorCalls: &decompressCalls, + } + } + var dstCodec Codec + if testCase.dstCodec != nil { + dstCodec = countCodec{ + Codec: testCase.dstCodec, + marshalCalls: &marshalCalls, + unmarshalCalls: &unmarshalCalls, + } + } + + got := testCase.msg + buf := bytes.NewBufferString(testCase.src) + if err := convertBuffer( + buffers, + buf, + srcComp, + srcCodec, + got, + dstCodec, + dstComp, + ); err != nil { + assert.EqualError(t, err, testCase.wantErr) + return + } + if testCase.wantMsg != nil { + assert.Empty(t, cmp.Diff(testCase.wantMsg, got, protocmp.Transform())) + } + assert.Equal(t, testCase.dst, buf.String()) + assert.Equal(t, testCase.wantMarshalCalls, marshalCalls, "marshalCalls") + assert.Equal(t, testCase.wantUnmarshalCalls, unmarshalCalls, "unmarshalCalls") + assert.Equal(t, testCase.wantCompressCalls, compressCalls, "compressCalls") + assert.Equal(t, testCase.wantDecompressCalls, decompressCalls, "decompressCalls") + }) + } +} + +type countCodec struct { + Codec + marshalCalls, unmarshalCalls *int +} + +func (c countCodec) MarshalAppend(b []byte, msg proto.Message) ([]byte, error) { + *c.marshalCalls++ + return c.Codec.MarshalAppend(b, msg) +} +func (c countCodec) Unmarshal(b []byte, msg proto.Message) error { + *c.unmarshalCalls++ + return c.Codec.Unmarshal(b, msg) +} + +type countCompressor struct { + *compressionPool + compressorCalls, decompressorCalls *int +} + +func (c countCompressor) compress(dst io.Writer, src *bytes.Buffer) error { + *c.compressorCalls++ + return c.compressionPool.compress(dst, src) +} +func (c countCompressor) decompress(dst *bytes.Buffer, src io.Reader) error { + *c.decompressorCalls++ + return c.compressionPool.decompress(dst, src) +} diff --git a/compression.go b/compression.go index 9eb5c34..13c6620 100644 --- a/compression.go +++ b/compression.go @@ -56,6 +56,12 @@ func (m compressionMap) intersection(names []string) []string { return intersection } +type compressor interface { + Name() string + compress(dst io.Writer, src *bytes.Buffer) error + decompress(dst *bytes.Buffer, src io.Reader) error +} + type compressionPool struct { name string decompressors sync.Pool @@ -85,14 +91,7 @@ func (p *compressionPool) Name() string { return p.name } -func (p *compressionPool) compress(dst, src *bytes.Buffer) error { - if p == nil { - _, err := io.Copy(dst, src) - return err - } - if src.Len() == 0 { - return nil - } +func (p *compressionPool) compress(dst io.Writer, src *bytes.Buffer) error { comp, _ := p.compressors.Get().(connect.Compressor) defer p.compressors.Put(comp) @@ -104,14 +103,7 @@ func (p *compressionPool) compress(dst, src *bytes.Buffer) error { return comp.Close() } -func (p *compressionPool) decompress(dst, src *bytes.Buffer) error { - if p == nil { - _, err := io.Copy(dst, src) - return err - } - if src.Len() == 0 { - return nil - } +func (p *compressionPool) decompress(dst *bytes.Buffer, src io.Reader) error { decomp, _ := p.decompressors.Get().(connect.Decompressor) defer p.decompressors.Put(decomp) diff --git a/errors.go b/errors.go index d4dbfce..5e0b69c 100644 --- a/errors.go +++ b/errors.go @@ -118,3 +118,9 @@ func malformedRequestError(err error) error { // Adds 400 Bad Request / InvalidArgument status codes to error return connect.NewError(connect.CodeInvalidArgument, err) } + +func errorf(code connect.Code, msg string, args ...any) error { + err := connect.NewError(code, fmt.Errorf(msg, args...)) + err.Meta() + return err +} diff --git a/handler.go b/handler.go index 5ac63f3..77a5595 100644 --- a/handler.go +++ b/handler.go @@ -1,5 +1,5 @@ // Copyright 2023 Buf Technologies, Inc. -// +//J // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -16,938 +16,638 @@ package vanguard import ( "bytes" - "context" "errors" "fmt" "io" "net/http" - "net/url" - "strconv" "strings" - "time" "connectrpc.com/connect" "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/reflect/protoreflect" ) -// ServeHTTP implements http.Handler. -func (m *Mux) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - op := m.newOperation(writer, request) - err := op.validate(m, m.codecs) - - useUnknownHandler := m.UnknownHandler != nil && errors.Is(err, errNotFound) - var callback func(context.Context, Operation) (Hooks, error) - if op.methodConf != nil { - callback = op.methodConf.hooksCallback - } else { - callback = m.HooksCallback - } - if callback != nil { - var hookErr error - if op.hooks, hookErr = callback(op.request.Context(), op); hookErr != nil { - useUnknownHandler = false - err = hookErr +// ServeHTTP dispatches the request to the service converting protocols, +// encoding and compression as needed. +func (m *Mux) ServeHTTP(response http.ResponseWriter, request *http.Request) { + if err := m.serveHTTP(response, request); err != nil { + if herr := asHTTPError(err); herr != nil { + herr.Encode(response) + } else { + http.Error(response, err.Error(), http.StatusInternalServerError) } - } - if useUnknownHandler { - request.Header = op.originalHeaders // restore headers, just in case initialization removed keys - m.UnknownHandler.ServeHTTP(writer, request) return } +} +func (m *Mux) serveHTTP(response http.ResponseWriter, request *http.Request) error { + // Identify the method being invoked. + method, err := m.resolveMethod(request) if err != nil { - op.reportError(err) - return + if errors.Is(err, errNotFound) && m.UnknownHandler != nil { + m.UnknownHandler.ServeHTTP(response, request) + return nil + } + return err } - if op.hooks.isEmpty() && - op.client.protocol.protocol() == op.server.protocol.protocol() && - op.client.codec.Name() == op.server.codec.Name() && - op.client.reqCompression.Name() == op.server.reqCompression.Name() { - // No transformation needed. But we do need to restore the original headers first - // since extracting request metadata may have removed keys. - request.Header = op.originalHeaders - op.methodConf.handler.ServeHTTP(writer, request) - return + flusher, ok := response.(http.Flusher) + if !ok { + return errors.New("http.ResponseWriter must implement http.Flusher") } - op.handle() -} - -func (m *Mux) newOperation(writer http.ResponseWriter, request *http.Request) *operation { - ctx, cancel := context.WithCancel(request.Context()) - request = request.WithContext(ctx) op := &operation{ - writer: writer, - request: request, - cancel: cancel, - bufferPool: &m.bufferPool, - compressors: m.compressors, + bufferPool: &m.bufferPool, + request: request, + response: response, + flusher: flusher, + method: method, + requestMeta: requestMeta{ + Body: request.Body, + Header: httpHeader(request.Header), + URL: request.URL, + Method: request.Method, + ProtoMajor: request.ProtoMajor, + ProtoMinor: request.ProtoMinor, + }, + responseMeta: responseMeta{ + Header: httpHeader(response.Header()), + StatusCode: http.StatusOK, + }, + requestBuffer: messageBuffer{ + Buf: m.bufferPool.Get(), + }, + responseBuffer: messageBuffer{ + Buf: m.bufferPool.Get(), + }, + } + + if err := op.handle(); err != nil { + if httperr := (*httpError)(nil); errors.As(err, &httperr) { + return httperr + } + + // Protocol error, encode the error message. + buf := op.requestBuffer.Buf + buf.Reset() + op.method.client.EncodeError(buf, &op.responseMeta, err) + if !op.responseMeta.WroteStatus { + op.responseMeta.WroteStatus = true + op.response.WriteHeader(op.responseMeta.StatusCode) + } + // Write the error message if it was buffered. + if buf.Len() > 0 { + _, _ = op.response.Write(buf.Bytes()) + op.flusher.Flush() + } + } + + // Free buffers. + m.bufferPool.Put(op.requestBuffer.Buf) + m.bufferPool.Put(op.responseBuffer.Buf) + if op.responseEnd != nil { + m.bufferPool.Put(op.responseEnd) } - op.requestLine.fromRequest(request) - return op -} - -type clientProtocolDetails struct { - protocol clientProtocolHandler - codec Codec - reqCompression *compressionPool - respCompression *compressionPool + return nil } -var _ PeerInfo = (*clientProtocolDetails)(nil) - -func (c *clientProtocolDetails) Protocol() Protocol { - if c.protocol == nil { - return ProtocolUnknown +// operation represents a single HTTP operation, which maps to an incoming HTTP request. +// It tracks properties needed to implement protocol transformation. +type operation struct { + bufferPool *bufferPool + request *http.Request + response http.ResponseWriter + flusher http.Flusher + + method method + requestMeta requestMeta + responseMeta responseMeta + requestBuffer messageBuffer + responseBuffer messageBuffer + responseEnd *bytes.Buffer // Buffer waiting on trailers. + requestMessage proto.Message + responseMessage proto.Message + + codecName string + compressorName string + hasResponse bool + wroteHeader bool +} + +func (o *operation) handle() error { + if err := o.decodeRequestHeader(); err != nil { + return err } - return c.protocol.protocol() -} - -func (c *clientProtocolDetails) Codec() string { - if c.codec == nil { - return "" + if err := o.encodeRequestHeader(); err != nil { + return err } - return c.codec.Name() -} -func (c *clientProtocolDetails) RequestCompression() string { - if c.reqCompression == nil { - return "" + // Init message types. + requestType, err := o.method.config.resolver.FindMessageByName( + o.method.config.descriptor.Input().FullName(), + ) + if err != nil { + return err } - return c.reqCompression.Name() -} + o.requestMessage = requestType.New().Interface() + responseType, err := o.method.config.resolver.FindMessageByName( + o.method.config.descriptor.Output().FullName(), + ) + if err != nil { + return err + } + o.responseMessage = responseType.New().Interface() -func (c *clientProtocolDetails) ResponseCompression() string { - if c.respCompression == nil { - return "" + // If request parameters are partially encoded in the body, read them. + if o.requestMeta.RequiresBody { + // Trigger a read to buffer the first request message. + if _, err := o.read(nil); err != nil { + // io.EOF on the first read is okay, it just means the request + // body was empty. + if !errors.Is(err, io.EOF) { + return err + } + } } - return c.respCompression.Name() -} -func (c *clientProtocolDetails) doNotImplement() {} + // Build the request. + o.request.Body = requestReader{ + operation: o, + } + o.request.GetBody = nil // TODO: support GetBody + o.request.ProtoMajor = o.requestMeta.ProtoMajor + o.request.ProtoMinor = o.requestMeta.ProtoMinor + o.request.Method = o.requestMeta.Method + o.request.RequestURI = o.request.URL.RequestURI() // override -type serverProtocolDetails struct { - protocol serverProtocolHandler - codec Codec - reqCompression *compressionPool - respCompression *compressionPool -} + // Build the response writer. + response := responseWriter{ + operation: o, + } -var _ PeerInfo = (*serverProtocolDetails)(nil) + // Serve the request. + o.method.config.handler.ServeHTTP(response, o.request) -func (s *serverProtocolDetails) Protocol() Protocol { - if s.protocol == nil { - return ProtocolUnknown + // Flush the response EOF. + if _, err := o.write(nil, true); err != nil { + return err } - return s.protocol.protocol() -} -func (s *serverProtocolDetails) Codec() string { - if s.codec == nil { - return "" + // Encode trailers. + trailer := o.responseBuffer.Buf + if !o.responseBuffer.Src.IsTrailer { + trailer.Reset() } - return s.codec.Name() -} - -func (s *serverProtocolDetails) RequestCompression() string { - if s.reqCompression == nil { - return "" + if err := o.method.server.DecodeResponseTrailer(trailer, &o.responseMeta); err != nil { + return err + } + trailer.Reset() + if err := o.method.client.EncodeResponseTrailer(trailer, &o.responseMeta); err != nil { + return err } - return s.reqCompression.Name() -} -func (s *serverProtocolDetails) ResponseCompression() string { - if s.respCompression == nil { - return "" + // Ensure the status is written. + o.writeStatus() + + // Encode message waiting on trailers. + if o.responseEnd != nil { + _, _ = o.responseEnd.WriteTo(o.response) + o.flusher.Flush() + } + // Encode trailers if needed. + if trailer.Len() > 0 { + _, _ = trailer.WriteTo(o.response) + o.flusher.Flush() } - return s.respCompression.Name() + return nil } -func (s *serverProtocolDetails) doNotImplement() {} - -func classifyRequest(req *http.Request) (clientProtocolHandler, url.Values) { - contentTypes := req.Header["Content-Type"] +func (o *operation) read(data []byte) (readN int, err error) { + msgBuf := &o.requestBuffer + meta := &o.requestMeta + for { + switch msgBuf.stage { + case stageEmpty: + // Read the first partial bytes for EOF detection. + if msgBuf.Buf.Len() == 0 { + if _, err := read(msgBuf.Buf, o.requestMeta.Body); err != nil { + if !errors.Is(err, io.EOF) { + return 0, err + } + msgBuf.Src.IsEOF = true + if msgBuf.Index > 0 { + return 0, io.EOF + } + } + } - if len(contentTypes) == 0 { //nolint:nestif - // Empty bodies should still have content types. So this should only - // happen for requests with NO body at all. That's only allowed for - // REST calls and Connect GET calls. - connectVersion := req.Header["Connect-Protocol-Version"] - // If this header is present, the intent is clear. But Connect GET - // requests should actually encode this via query string (see below). - if len(connectVersion) == 1 && connectVersion[0] == "1" { - if req.Method == http.MethodGet { - return connectUnaryGetClientProtocol{}, nil + if err := o.method.client.PrepareRequestMessage(msgBuf, meta); err != nil { + if !errors.Is(err, io.ErrShortBuffer) { + return 0, err + } + if _, err := read(msgBuf.Buf, o.requestMeta.Body); err != nil { + if !errors.Is(err, io.EOF) { + return 0, err + } + if msgBuf.Src.IsEOF { + return 0, io.EOF + } + msgBuf.Src.IsEOF = true + } + continue } - return nil, nil - } - values := req.URL.Query() - if values.Get("connect") == "v1" { - if req.Method != http.MethodGet { - return nil, nil + if msgBuf.Src.IsTrailer { + return 0, fmt.Errorf("unexpected trailer in request") } - return connectUnaryGetClientProtocol{}, values - } - return restClientProtocol{}, values - } - - if len(contentTypes) > 1 { - return nil, nil // Ick. Don't allow this. - } - contentType := contentTypes[0] - var values url.Values - switch { - case strings.HasPrefix(contentType, "application/connect+"): - return connectStreamClientProtocol{}, nil - case contentType == "application/grpc" || strings.HasPrefix(contentType, "application/grpc+"): - return grpcClientProtocol{}, nil - case contentType == "application/grpc-web" || strings.HasPrefix(contentType, "application/grpc-web+"): - return grpcWebClientProtocol{}, nil - case strings.HasPrefix(contentType, "application/"): - connectVersion := req.Header["Connect-Protocol-Version"] - if len(connectVersion) == 1 && connectVersion[0] == "1" { - if req.Method == http.MethodGet { - return connectUnaryGetClientProtocol{}, nil + if msgBuf.Index > 0 && o.method.config.streamType == connect.StreamTypeUnary { + return 0, fmt.Errorf("unexpected message in request") } - return connectUnaryPostClientProtocol{}, nil - } - values = req.URL.Query() - if values.Get("connect") == "v1" { - if req.Method != http.MethodGet { - return nil, nil + if err := o.method.server.PrepareRequestMessage(msgBuf, meta); err != nil { + return 0, err + } + // TODO: optimize streaming case to avoid buffering. + msgBuf.stage = stageRead + + case stageRead: + size := o.resolveSize(msgBuf.Src) + remN := size - int64(msgBuf.Buf.Len()) + + var excessBuf *bytes.Buffer + if remN < 0 { + // Excess bytes in the buffer, so we need to split the buffer. + buf := o.bufferPool.Get() + _, _ = buf.ReadFrom(io.LimitReader(msgBuf.Buf, size)) + excessBuf, msgBuf.Buf = msgBuf.Buf, buf // swap + defer o.bufferPool.Put(excessBuf) + } else { + if _, err := msgBuf.Buf.ReadFrom( + io.LimitReader(o.requestMeta.Body, remN), + ); err != nil { + return 0, err + } + } + if msgBuf.Src.ReadMode == readModeSize && + int64(msgBuf.Buf.Len()) < int64(msgBuf.Src.Size) { + return 0, io.ErrUnexpectedEOF } - return connectUnaryGetClientProtocol{}, values - } - // REST usually uses application/json, but use of google.api.HttpBody means it could - // also use *any* content-type. - fallthrough - default: - return restClientProtocol{}, values - } -} - -// operation represents a single HTTP operation, which maps to an incoming HTTP request. -// It tracks properties needed to implement protocol transformation. -type operation struct { - writer http.ResponseWriter - request *http.Request - cancel context.CancelFunc - bufferPool *bufferPool - compressors compressionMap - - queryVars url.Values - originalHeaders http.Header - reqContentType string // original content-type in incoming request headers - rspContentType string // original content-type in outgoing response headers - contentLen int64 // original content-length in incoming request headers or -1 - requestLine requestLine // properties of the original incoming request line - reqMeta requestMeta - deadline time.Time - methodConf *methodConfig - - client clientProtocolDetails - server serverProtocolDetails - - // only used when clientProtocolDetails.protocol == ProtocolREST - restTarget *routeTarget - restVars []routeTargetVarMatch - - hooks Hooks - isValid bool - - // these fields memoize the results of type assertions and some method calls - clientEnveloper envelopedProtocolHandler - clientPreparer clientBodyPreparer - clientReqNeedsPrep bool - clientRespNeedsPrep bool - serverEnveloper serverEnvelopedProtocolHandler - serverPreparer serverBodyPreparer - serverReqNeedsPrep bool - serverRespNeedsPrep bool -} -var _ Operation = (*operation)(nil) + if err := msgBuf.Convert( + o.bufferPool, + o.requestMessage, + meta.Client, + meta.Server, + ); err != nil { + return 0, err + } -func (o *operation) IsValid() bool { - return o.isValid -} + if excessBuf != nil { + // Append excess, if any. + msgBuf.Buf.Write(excessBuf.Bytes()) + } -func (o *operation) HTTPRequestLine() (method, path, queryString, httpVersion string) { - return o.requestLine.method, o.requestLine.path, o.requestLine.queryString, o.requestLine.httpVersion -} + case stageBuffered: + readN, err = msgBuf.Read(data) + if errors.Is(err, io.EOF) { + msgBuf.Flush() + continue + } + return readN, err -func (o *operation) Method() protoreflect.MethodDescriptor { - if o.methodConf == nil { - return nil + default: + return 0, errors.New("invalid message stage") + } } - return o.methodConf.descriptor -} - -func (o *operation) Deadline() (time.Time, bool) { - return o.deadline, o.reqMeta.hasTimeout } -func (o *operation) ClientInfo() PeerInfo { - return &o.client +func (o *operation) decodeRequestHeader() error { + if err := o.method.client.DecodeRequestHeader(&o.requestMeta); err != nil { + return err + } + o.codecName = o.requestMeta.CodecName + o.compressorName = o.requestMeta.CompressionName + o.requestMeta.CodecName = o.method.config.ResolveServerCodecName( + o.requestMeta.CodecName) + o.requestMeta.CompressionName = o.method.config.ResolveServerCompressorName( + o.requestMeta.CompressionName) + return nil } - -func (o *operation) HandlerInfo() PeerInfo { - return &o.server +func (o *operation) encodeRequestHeader() error { + o.responseMeta.CodecName = o.requestMeta.CodecName + o.responseMeta.CompressionName = o.requestMeta.CompressionName + return o.method.server.EncodeRequestHeader(&o.requestMeta) } - -func (o *operation) doNotImplement() {} - -func (o *operation) validate(mux *Mux, codecs codecMap) error { - // Identify the protocol. - clientProtoHandler, queryVars := classifyRequest(o.request) - if clientProtoHandler == nil { - return newHTTPError(http.StatusUnsupportedMediaType, "could not classify protocol") - } - o.client.protocol = clientProtoHandler - if queryVars != nil { - // memoize this, so we don't have to parse query string again later - o.queryVars = queryVars - } - o.originalHeaders = o.request.Header.Clone() - o.reqContentType = o.originalHeaders.Get("Content-Type") - o.contentLen = o.request.ContentLength - o.request.ContentLength = -1 // transforming it will likely change it - - // Identify the method being invoked. - err := o.resolveMethod(mux) - if err != nil { +func (o *operation) decodeResponseHeader() error { + if err := o.method.server.DecodeRequestHeader(&o.responseMeta); err != nil { return err } - if !o.client.protocol.acceptsStreamType(o, o.methodConf.streamType) { - return newHTTPError(http.StatusUnsupportedMediaType, "stream type %s not supported with %s protocol", o.methodConf.streamType, o.client.protocol) - } - if o.methodConf.streamType == connect.StreamTypeBidi && o.request.ProtoMajor < 2 { - return newHTTPError(http.StatusHTTPVersionNotSupported, "bidi streams require HTTP/2") - } - if clientProtoHandler.protocol() == ProtocolGRPC && o.request.ProtoMajor != 2 { - return newHTTPError(http.StatusHTTPVersionNotSupported, "gRPC requires HTTP/2") - } - - // Identify the request encoding and compression. - reqMeta, err := clientProtoHandler.extractProtocolRequestHeaders(o, o.request.Header) - if err != nil { - return newHTTPError(http.StatusBadRequest, err.Error()) - } - // Remove other headers that might mess up the next leg - if enc := o.request.Header.Get("Content-Encoding"); enc != "" && enc != CompressionIdentity { - // If the protocol didn't remove the "Content-Encoding" header in above step, - // that's because it models encoding in a different way. In that case, encoding - // of the whole response with this header is not valid. - return newHTTPError(http.StatusUnsupportedMediaType, "content-encoding %q not allowed for this protocol", enc) - } - o.request.Header.Del("Content-Encoding") - o.request.Header.Del("Accept-Encoding") - o.request.Header.Del("Content-Length") - - o.reqMeta = reqMeta - if reqMeta.hasTimeout { - o.deadline = time.Now().Add(reqMeta.timeout) - } - if reqMeta.compression == CompressionIdentity { - reqMeta.compression = "" // normalize to empty string - } - if reqMeta.compression != "" { - var ok bool - o.client.reqCompression, ok = o.compressors[reqMeta.compression] - if !ok { - return newHTTPError(http.StatusUnsupportedMediaType, "%q compression not supported", reqMeta.compression) - } - } - o.client.codec = codecs.get(reqMeta.codec, o.methodConf.resolver) - if o.client.codec == nil { - return newHTTPError(http.StatusUnsupportedMediaType, "%q sub-format not supported", reqMeta.codec) + return nil +} +func (o *operation) encodeResponseHeader() error { + if o.codecName != "" { + o.responseMeta.CodecName = o.codecName } - - // Now we can determine the destination protocol details - if _, supportsProtocol := o.methodConf.protocols[clientProtoHandler.protocol()]; supportsProtocol { - o.server.protocol = clientProtoHandler.protocol().serverHandler(o) - } else { - for protocol := protocolMin; protocol <= protocolMax; protocol++ { - if _, supportsProtocol := o.methodConf.protocols[protocol]; supportsProtocol { - o.server.protocol = protocol.serverHandler(o) - break - } - } + if o.compressorName != "" { + o.responseMeta.CompressionName = o.compressorName } + return o.method.client.EncodeRequestHeader(&o.responseMeta) +} - // Now that we've ruled out the use of bidi streaming above, it's safe to simulate HTTP/2 - // for the benefit of gRPC handlers, which require HTTP/2. - if o.server.protocol.protocol() == ProtocolGRPC { - o.request.Proto, o.request.ProtoMajor, o.request.ProtoMinor = "HTTP/2", 2, 0 +func (o *operation) writeHeader() error { + if o.wroteHeader { + return nil } - - if o.server.protocol.protocol() == ProtocolREST { - // REST always uses JSON. - // TODO: Allow non-JSON encodings with REST? Would require registering content-types with codecs. - // Would also require figuring out how to (un)marshal things other than messages when a body - // path indicates a non-message field (do-able with JSON, but maybe non-starter with proto?) - // - // NB: This is fine to set even if a custom content-type is used via - // the use of google.api.HttpBody. The actual content-type and body - // data will be written via serverBodyPreparer implementation. - o.server.codec = codecs.get(CodecJSON, o.methodConf.resolver) - } else if _, supportsCodec := o.methodConf.codecNames[reqMeta.codec]; supportsCodec { - o.server.codec = o.client.codec - } else { - o.server.codec = codecs.get(o.methodConf.preferredCodec, o.methodConf.resolver) + if err := o.decodeResponseHeader(); err != nil { + return err } - - if reqMeta.compression != "" { - if _, supportsCompression := o.methodConf.compressorNames[reqMeta.compression]; supportsCompression { - o.server.reqCompression = o.client.reqCompression - } - // else: we'll just decompress and not recompress - // TODO: should we instead pick a supported compression scheme (if there is one)? + if err := o.encodeResponseHeader(); err != nil { + return err } - - o.isValid = true // Successfully validated! + o.wroteHeader = true return nil } -func (o *operation) queryValues() url.Values { - if o.queryVars == nil && o.request.URL.RawQuery != "" { - o.queryVars = o.request.URL.Query() +func (o *operation) writeStatus() { + _ = o.writeHeader() // ignore error, should be written already + if o.responseMeta.WroteStatus { + return } - return o.queryVars + o.responseMeta.WroteStatus = true + o.response.WriteHeader(o.responseMeta.StatusCode) } -func (o *operation) handle() { //nolint:gocyclo - if o.hooks.OnClientRequestHeaders != nil { - if err := o.hooks.OnClientRequestHeaders(o.request.Context(), o, o.request.Header); err != nil { - o.reportError(err) - return - } +func (o *operation) resolveSize(param srcParams) int64 { + switch param.ReadMode { + case readModeSize: + return int64(param.Size) + case readModeEOF: + return int64(o.method.config.maxMsgBufferBytes) + case readModeChunk: + return chunkMessageSize + default: + return -1 } +} - o.clientEnveloper, _ = o.client.protocol.(envelopedProtocolHandler) - o.clientPreparer, _ = o.client.protocol.(clientBodyPreparer) - if o.clientPreparer != nil { - o.clientReqNeedsPrep = o.clientPreparer.requestNeedsPrep(o) - } - o.serverEnveloper, _ = o.server.protocol.(serverEnvelopedProtocolHandler) - o.serverPreparer, _ = o.server.protocol.(serverBodyPreparer) - if o.serverPreparer != nil { - o.serverReqNeedsPrep = o.serverPreparer.requestNeedsPrep(o) - } +func (o *operation) write(data []byte, isEOF bool) (wroteN int, err error) { + msgBuf := &o.responseBuffer + meta := &o.responseMeta - serverRequestBuilder, _ := o.server.protocol.(requestLineBuilder) - var requireMessageForRequestLine bool - if serverRequestBuilder != nil { - requireMessageForRequestLine = serverRequestBuilder.requiresMessageToProvideRequestLine(o) + if err := o.writeHeader(); err != nil { + return 0, err } - sameRequestCompression := o.client.reqCompression.Name() == o.server.reqCompression.Name() - sameCodec := o.client.codec.Name() == o.server.codec.Name() - // even if body encoding uses same content type, we can't treat them as the same - // (which means re-using encoded data) if either side needs to prep the data first - sameRequestCodec := sameCodec && !o.clientReqNeedsPrep && !o.serverReqNeedsPrep - mustDecodeRequest := !sameRequestCodec || requireMessageForRequestLine || o.hooks.OnClientRequestMessage != nil + // Buffer the first partial bytes of the response message. + wroteN, _ = msgBuf.Buf.Write(data) + msgBuf.Src.IsEOF = isEOF - reqMsg := message{ - sameCompression: sameRequestCompression, - sameCodec: sameRequestCodec, - } + for { + switch msgBuf.stage { + case stageEmpty: + if isEOF { + if msgBuf.Buf.Len() > 0 { + return wroteN, io.ErrUnexpectedEOF + } + return wroteN, nil + } + if err := o.method.server.PrepareResponseMessage(msgBuf, meta); err != nil { + if errors.Is(err, io.ErrShortBuffer) { + return wroteN, nil // okay to return partial write + } + return 0, err + } + if err := o.method.client.PrepareResponseMessage(msgBuf, meta); err != nil { + return 0, err + } + // TODO: optimize streaming case to avoid buffering. + msgBuf.stage = stageRead - if mustDecodeRequest { - // Need the message type to decode - messageType, err := o.methodConf.resolver.FindMessageByName(o.methodConf.descriptor.Input().FullName()) - if err != nil { - o.reportError(err) - return - } - reqMsg.msg = messageType.New().Interface() - } + case stageRead: + size := o.resolveSize(msgBuf.Src) + remN := size - int64(msgBuf.Buf.Len()) - if (o.hooks.OnClientRequestMessage != nil && o.methodConf.streamType == connect.StreamTypeUnary) || - requireMessageForRequestLine { - // Go ahead and process first request message - switch err := o.readRequestMessage(nil, o.request.Body, &reqMsg); { - case errors.Is(err, io.EOF): - // okay for the first message: means empty message data - reqMsg.markReady() - case err != nil: - o.reportError(err) - return - } - if err := reqMsg.advanceToStage(o, stageDecoded); err != nil { - o.reportError(err) - return - } - if o.hooks.OnClientRequestMessage != nil { - compressed := reqMsg.wasCompressed && o.client.reqCompression != nil - err := o.hooks.OnClientRequestMessage(o.request.Context(), o, reqMsg.msg, compressed, reqMsg.size) - if err != nil { - o.reportError(err) - return + if (msgBuf.Src.ReadMode == readModeEOF || + msgBuf.Src.ReadMode == readModeChunk) && isEOF { + if remN > 0 { + remN = 0 + } } - } - } - var skipBody bool - if serverRequestBuilder != nil { - var hasBody bool - var err error - o.request.URL.Path, o.request.URL.RawQuery, o.request.Method, hasBody, err = - serverRequestBuilder.requestLine(o, reqMsg.msg) - if err != nil { - o.reportError(err) - return - } - skipBody = !hasBody - // Recompute if the server needs to prep the request, now that we've modified - // properties of op.request. - if o.serverPreparer != nil { - o.serverReqNeedsPrep = o.serverPreparer.requestNeedsPrep(o) - } - } else { - // if no request line builder, use simple request layout - o.request.URL.Path = o.methodConf.methodPath - o.request.URL.RawQuery = "" - o.request.Method = http.MethodPost - } - o.request.URL.ForceQuery = false - serverReqMeta := o.reqMeta - serverReqMeta.codec = o.server.codec.Name() - serverReqMeta.compression = o.server.reqCompression.Name() - serverReqMeta.acceptCompression = o.compressors.intersection(o.reqMeta.acceptCompression) - o.server.protocol.addProtocolRequestHeaders(serverReqMeta, o.request.Header) + var excessBuf *bytes.Buffer + if remN < 0 { + // Excess bytes in the buffer, so we need to split the buffer. + buf := o.bufferPool.Get() + _, _ = buf.ReadFrom(io.LimitReader(msgBuf.Buf, size)) + excessBuf, msgBuf.Buf = msgBuf.Buf, buf // swap + defer o.bufferPool.Put(excessBuf) + } else if remN > 0 { + return wroteN, nil // okay to return partial write + } - // Now we can define the transformed response writer (which delays - // much of its logic until it sees the response headers). - flusher := asFlusher(o.writer) - if flusher == nil { - o.reportError(errors.New("http.ResponseWriter must implement http.Flusher")) - return - } - rw := &responseWriter{op: o, delegate: o.writer, flusher: flusher} - defer rw.close() - o.writer = rw + if msgBuf.Src.IsTrailer { + // For compressored trailers handle decompression. + var compressor compressor + if msgBuf.Src.IsCompressed { + compressor = meta.Server.Compressor + } + // Got trailer, done. + msgBuf.stage = stageEOF + // Decompress the trailer if needed. + return wroteN, decodeBuffer( + o.bufferPool, + msgBuf.Buf, + nil, // Empty message + nil, // No codec + compressor, + ) + } - // And finally we can define the transformed request bodies. - switch { - case skipBody: - // drain any contents of body so downstream handler sees empty - o.drainBody(o.request.Body) - case sameRequestCompression && sameRequestCodec && !mustDecodeRequest: - // we do not need to decompress or decode; just transforming envelopes - o.request.Body = &envelopingReader{rw: rw, r: o.request.Body} - default: - tw := &transformingReader{rw: rw, msg: &reqMsg, r: o.request.Body} - o.request.Body = tw - if reqMsg.stage != stageEmpty { - if err := tw.prepareMessage(); err != nil { - tw.err = err + if err := msgBuf.Convert( + o.bufferPool, + o.responseMessage, + meta.Server, + meta.Client, + ); err != nil { + return 0, err + } + + // Append excess, if any. + if excessBuf != nil { + msgBuf.Buf.Write(excessBuf.Bytes()) + } + + case stageBuffered: + // Wait for trailers stores the message buffer to be written + // after the trailers are written. + if msgBuf.Dst.WaitForTrailer { + if o.responseEnd == nil { + o.responseEnd = o.bufferPool.Get() + } + if _, err := msgBuf.WriteTo(o.responseEnd); err != nil { + return wroteN, err + } + msgBuf.Flush() + continue + } + // Otherwise, write the message buffer. + o.writeStatus() + _, err := msgBuf.WriteTo(o.response) + if err != nil { + return wroteN, err + } + o.flusher.Flush() + msgBuf.Flush() + // Loop back to process the next message if excess bytes were + // in the buffer. + + case stageEOF: + if !isEOF { + return wroteN, io.ErrShortWrite } + return wroteN, nil + + default: + return 0, errors.New("invalid message stage") } } +} - o.methodConf.handler.ServeHTTP(o.writer, o.request) +type method struct { + config *methodConfig + client clientProtocolHandler + server serverProtocolHandler } -func (o *operation) resolveMethod(mux *Mux) error { - uriPath := o.request.URL.Path - switch o.client.protocol.protocol() { +func (m *Mux) resolveMethod(request *http.Request) (method, error) { + // Identify the protocol. + clientProtocol := classifyRequest(request) + if clientProtocol == ProtocolUnknown { + return method{}, newHTTPError(http.StatusUnsupportedMediaType, "could not classify protocol") + } + + var ( + config *methodConfig + restTarget *routeTarget + restVars []routeTargetVarMatch + ) + uriPath := request.URL.Path + switch clientProtocol { case ProtocolREST: - var methods routeMethods - o.restTarget, o.restVars, methods = mux.restRoutes.match(uriPath, o.request.Method) - if o.restTarget != nil { - o.methodConf = o.restTarget.config - return nil - } - if len(methods) == 0 { - return errNotFound - } - var sb strings.Builder - for method := range methods { - if sb.Len() > 0 { - sb.WriteByte(',') + var restMethods routeMethods + restTarget, restVars, restMethods = m.restRoutes.match(uriPath, request.Method) + if restTarget == nil { + if len(restMethods) == 0 { + return method{}, errNotFound } - sb.WriteString(method) - } - return &httpError{ - code: http.StatusMethodNotAllowed, - header: http.Header{ - "Allow": []string{sb.String()}, - }, - } - default: - methodConf := mux.methods[uriPath] - if methodConf == nil { - // TODO: if the service is known, but the method is not, we should send to the client - // a proper RPC error (encoded per protocol handler) with an Unimplemented code. - return errNotFound - } - o.restTarget = methodConf.httpRule - if o.request.Method != http.MethodPost { - mayAllowGet, ok := o.client.protocol.(clientProtocolAllowsGet) - allowsGet := ok && mayAllowGet.allowsGetRequests(methodConf) - if !allowsGet { - return &httpError{ - code: http.StatusMethodNotAllowed, - header: http.Header{ - "Allow": []string{http.MethodPost}, - }, + var sb strings.Builder + for method := range restMethods { + if sb.Len() > 0 { + sb.WriteByte(',') } + sb.WriteString(method) } - if allowsGet && o.request.Method != http.MethodGet { - return &httpError{ - code: http.StatusMethodNotAllowed, - header: http.Header{ - "Allow": []string{http.MethodGet + "," + http.MethodPost}, - }, - } + return method{}, &httpError{ + code: http.StatusMethodNotAllowed, + header: http.Header{ + "Allow": []string{sb.String()}, + }, } } - o.methodConf = methodConf - return nil - } -} - -// reportError handles an error that occurs while setting up the operation. It should not be used -// once the underlying server handler has been invoked. For those errors, responseWriter.reportError -// must be used instead. -func (o *operation) reportError(err error) { - defer o.cancel() - - if !o.isValid { - // We don't have enough operation details to render an RPC error, - // so just send a simple HTTP error. - asHTTPError(err).Encode(o.writer) - return + config = restTarget.config + default: + config = m.methods[uriPath] + if config == nil { + return method{}, errNotFound + } + restTarget = config.httpRule } - rw, ok := o.writer.(*responseWriter) - if ok { - rw.reportError(err) - return - } - // No responseWriter created yet, so we duplicate some of its behavior to write an error. - if o.hooks.OnOperationFail != nil { - if hookErr := o.hooks.OnOperationFail(o.request.Context(), o, nil, err); hookErr != nil { - // report the error returned by the hook - err = hookErr + var client clientProtocolHandler + switch clientProtocol { + case ProtocolConnect: + if config.streamType == connect.StreamTypeUnary { + client = connectUnaryClientProtocol{ + config: config, + } + } else { + client = connectStreamClientProtocol{ + config: config, + } + } + case ProtocolGRPC: + client = grpcClientProtocol{ + config: config, + } + case ProtocolGRPCWeb: + client = grpcWebClientProtocol{ + config: config, + } + case ProtocolREST: + client = restClientProtocol{ + target: restTarget, + vars: restVars, + } + default: + return method{}, &httpError{ + code: http.StatusInternalServerError, + err: errors.ErrUnsupported, } } - httpErr := asHTTPError(err) - httpErr.EncodeHeaders(o.writer.Header()) - connErr := asConnectError(err) - end := &responseEnd{err: connErr, httpCode: httpErr.code} - code := o.client.protocol.addProtocolResponseHeaders(responseMeta{end: end}, o.writer.Header()) - o.writer.WriteHeader(code) - trailers := o.client.protocol.encodeEnd(o, end, o.writer, true) - httpMergeTrailers(o.writer.Header(), trailers) -} -func (o *operation) readRequestMessage(rw *responseWriter, reader io.Reader, msg *message) error { - msgLen := -1 - compressed := o.client.reqCompression != nil - if o.clientEnveloper != nil { - var envBuf envelopeBytes - _, err := io.ReadFull(reader, envBuf[:]) - if err != nil { - return err - } - msgLen, compressed, err = o.processRequestEnvelope(envBuf) - if err != nil { - if rw != nil { - rw.reportError(err) + // Identity the protocol to convert to. + serverProtocol := clientProtocol + if _, isSupported := config.protocols[serverProtocol]; !isSupported { + for protocol := protocolMin; protocol <= protocolMax; protocol++ { + if _, isSupported := config.protocols[protocol]; isSupported { + serverProtocol = protocol + break } - return err } } - - buffer := msg.reset(o.bufferPool, true, compressed) - var err error - if msgLen == -1 { //nolint:nestif - limit, grow, makeError, limitErr := o.determineReadLimit() - if limitErr != nil { - if rw != nil { - rw.reportError(limitErr) + var server serverProtocolHandler + switch serverProtocol { + case ProtocolConnect: + if config.streamType == connect.StreamTypeUnary { + server = &connectUnaryServerProtocol{ + config: config, + } + } else { + server = connectStreamServerProtocol{ + config: config, } - return limitErr } - if grow { - buffer.Grow(int(limit)) + case ProtocolGRPC: + server = grpcServerProtocol{ + config: config, } - _, err = io.Copy(buffer, &hardLimitReader{r: reader, rw: rw, limit: limit, makeError: makeError}) - if err == nil && buffer.Len() == 0 { - err = io.EOF + case ProtocolGRPCWeb: + server = grpcWebServerProtocol{ + config: config, } - } else { - buffer.Grow(msgLen) - _, err = io.CopyN(buffer, reader, int64(msgLen)) - if errors.Is(err, io.EOF) { - // EOF is a sentinel that means normal end of stream; replace it so callers know an error occurred - err = io.ErrUnexpectedEOF + case ProtocolREST: + server = &restServerProtocol{ + target: restTarget, + } + default: + return method{}, &httpError{ + code: http.StatusInternalServerError, + err: errors.ErrUnsupported, } } - if err != nil { - return err - } - msg.markReady() - return nil + return method{ + config: config, + client: client, + server: server, + }, nil } -func (o *operation) processRequestEnvelope(envBuf envelopeBytes) (msgLen int, compressed bool, err error) { - env, err := o.clientEnveloper.decodeEnvelope(envBuf) - if err != nil { - return 0, false, malformedRequestError(err) - } - if env.trailer { - return 0, false, malformedRequestError(fmt.Errorf("client stream cannot include status/trailer message")) - } - if limit := o.methodConf.maxMsgBufferBytes; env.length > limit { - return 0, false, bufferLimitError(int64(limit)) - } - return int(env.length), env.compressed, nil +type requestReader struct { + operation *operation } -func (o *operation) determineReadLimit() (limit int64, grow bool, makeError func(int64) error, err error) { - limit = int64(o.methodConf.maxMsgBufferBytes) - if o.contentLen == -1 { - return limit, false, bufferLimitError, nil - } - if o.contentLen > limit { - // content-length header tells us that entity is too large - err := bufferLimitError(limit) - return 0, false, nil, err - } - return o.contentLen, true, contentLengthError, nil +func (r requestReader) Read(data []byte) (int, error) { + return r.operation.read(data) } - -func (o *operation) drainBody(body io.ReadCloser) { - if wt, ok := body.(io.WriterTo); ok { - _, _ = wt.WriteTo(io.Discard) - return - } - buf := o.bufferPool.Get() - defer o.bufferPool.Put(buf) - b := buf.Bytes()[0:buf.Cap()] - _, _ = io.CopyBuffer(io.Discard, body, b) -} - -// envelopingReader will translate between envelope styles as data is read. -// It does not do any decompressing or deserializing of data. -type envelopingReader struct { - rw *responseWriter - r io.ReadCloser - - err error - current io.Reader - mustReleaseCurrent bool - env envelopeBytes - envRemain int -} - -func (r *envelopingReader) Read(data []byte) (n int, err error) { - if r.err != nil { - return 0, r.err - } - if r.current != nil { - bytesRead, err := r.current.Read(data) - isEOF := errors.Is(err, io.EOF) - if bytesRead > 0 && (err == nil || isEOF) { - return bytesRead, nil - } - if err != nil && !isEOF { - r.err = err - return bytesRead, err - } - // otherwise EOF, fall through - } - - if err := r.prepareNext(); err != nil { - r.err = err - return 0, err - } - - if len(data) < r.envRemain { - copy(data, r.env[envelopeLen-r.envRemain:]) - r.envRemain -= len(data) - return len(data), nil - } - var offset int - if r.envRemain > 0 { - copy(data, r.env[envelopeLen-r.envRemain:]) - offset = r.envRemain - r.envRemain = 0 - } - if len(data) > offset { - n, err = r.current.Read(data[offset:]) - } - return offset + n, err -} - -func (r *envelopingReader) Close() error { - if r.mustReleaseCurrent { - buf, ok := r.current.(*bytes.Buffer) - if ok { - r.rw.op.bufferPool.Put(buf) - } - r.current = nil - r.mustReleaseCurrent = false - } - r.err = errors.New("body is closed") - return r.r.Close() -} - -func (r *envelopingReader) prepareNext() error { - var env envelope - switch { - case r.rw.op.clientEnveloper == nil && r.rw.op.serverEnveloper == nil: - // no envelopes to transform, just pass the body through w/ no change - r.current = r.r - r.envRemain = 0 - return nil - case r.rw.op.clientEnveloper == nil: - env.compressed = r.rw.op.client.reqCompression != nil - if r.rw.op.contentLen != -1 { - r.current = &hardLimitReader{r: r.r, rw: r.rw, limit: r.rw.op.contentLen, makeError: contentLengthError} - env.length = uint32(r.rw.op.contentLen) - } else { - // Oof. We have to buffer entire request in order to measure it. - limit := int64(r.rw.op.methodConf.maxMsgBufferBytes) - buf := r.rw.op.bufferPool.Get() - _, err := io.Copy(buf, &hardLimitReader{r: r.r, rw: r.rw, limit: limit}) - if err != nil { - r.rw.op.bufferPool.Put(buf) - r.err = err - return err - } - r.current = buf - r.mustReleaseCurrent = true - env.length = uint32(buf.Len()) - } - default: // clientEnveloper != nil - var envBytes envelopeBytes - _, err := io.ReadFull(r.r, envBytes[:]) - if err != nil { - return err - } - env, err = r.rw.op.clientEnveloper.decodeEnvelope(envBytes) - if err != nil { - err = malformedRequestError(err) - r.rw.reportError(err) - return err - } - r.current = io.LimitReader(r.r, int64(env.length)) - } - - if r.rw.op.serverEnveloper == nil { - r.envRemain = 0 - } else { - r.envRemain = envelopeLen - r.env = r.rw.op.serverEnveloper.encodeEnvelope(env) - } - return nil -} - -// transformingReader transforms the data from the original request -// into a new protocol form as the data is read. It must decompress -// and deserialize each message and then re-serialize (and optionally -// recompress) each message. Since the original incoming protocol may -// have different envelope conventions than the outgoing protocol, it -// also rewrites envelopes. -type transformingReader struct { - rw *responseWriter - msg *message - r io.ReadCloser - - consumedFirst bool - err error - buffer *bytes.Buffer - env envelopeBytes - envRemain int -} - -func (r *transformingReader) Read(data []byte) (n int, err error) { - if r.err != nil { - return 0, r.err - } - - for { - if len(data) < r.envRemain { - copy(data, r.env[envelopeLen-r.envRemain:]) - r.envRemain -= len(data) - return len(data), nil - } - var offset int - if r.envRemain > 0 { - copy(data, r.env[envelopeLen-r.envRemain:]) - offset = r.envRemain - r.envRemain = 0 - } - var err error - if len(data) > offset && r.buffer != nil { - n, err = r.buffer.Read(data[offset:]) - } - if offset+n > 0 { - return offset + n, err - } - - // If we get here, there was nothing in tr.buffer to read, so - // we need to prepare the next message and try again. - - if err := r.rw.op.readRequestMessage(r.rw, r.r, r.msg); err != nil { - // If this is the first request message, the error is EOF, and there's a body - // preparer, we'll allow it and let the preparer produce a message from zero - // request bytes. - if !r.consumedFirst && errors.Is(err, io.EOF) && r.rw.op.clientReqNeedsPrep { - r.msg.markReady() - } else { - r.err = err - return 0, err - } - } - if r.rw.op.hooks.OnClientRequestMessage != nil { - if err := r.msg.advanceToStage(r.rw.op, stageDecoded); err != nil { - r.err = err - return 0, err - } - compressed := r.msg.wasCompressed && r.rw.op.client.reqCompression != nil - if err := r.rw.op.hooks.OnClientRequestMessage(r.rw.op.request.Context(), r.rw.op, r.msg.msg, compressed, r.msg.size); err != nil { - r.rw.reportError(err) - return 0, context.Canceled - } - } - if err := r.prepareMessage(); err != nil { - r.err = err - return 0, err - } - } -} - -func (r *transformingReader) Close() error { - r.err = errors.New("body is closed") - r.msg.release(r.rw.op.bufferPool) - return r.r.Close() -} - -func (r *transformingReader) prepareMessage() error { - r.consumedFirst = true - if err := r.msg.advanceToStage(r.rw.op, stageSend); err != nil { - return err - } - r.buffer = r.msg.sendBuffer() - if r.rw.op.serverEnveloper == nil { - r.envRemain = 0 - return nil - } - // Need to prefix the buffer with an envelope - env := envelope{ - compressed: r.msg.wasCompressed && r.rw.op.server.reqCompression != nil, - length: uint32(r.buffer.Len()), - } - r.env = r.rw.op.serverEnveloper.encodeEnvelope(env) - r.envRemain = envelopeLen - return nil +func (r requestReader) Close() error { + return r.operation.requestMeta.Body.Close() } // responseWriter wraps the original writer and performs the protocol @@ -958,1212 +658,97 @@ func (r *transformingReader) prepareMessage() error { // When the headers are written, the actual transformation that is // needed is determined and a writer decorator created. type responseWriter struct { - op *operation - delegate http.ResponseWriter - flusher http.Flusher - code int - // has WriteHeader or first call to Write occurred? - headersWritten bool - contentLen int - // have headers actually been flushed to delegate? - headersFlushed bool - // have we already written the end of the stream (error/trailers/etc)? - endWritten bool - respMeta *responseMeta - err error - // wraps op.writer; initialized after headers are written - w io.WriteCloser - // may be used in place of op.writer for protocols that must see - // trailers before writing the first bytes of data (like Connect - // and REST unary). - buf *bytes.Buffer + operation *operation } -func (w *responseWriter) Header() http.Header { - return w.delegate.Header() +func (w responseWriter) Header() http.Header { + return w.operation.response.Header() } -func (w *responseWriter) Write(data []byte) (int, error) { - if !w.headersWritten { +func (w responseWriter) Write(data []byte) (int, error) { + if !w.operation.hasResponse { w.WriteHeader(http.StatusOK) } - if w.err != nil { - return 0, w.err - } - return w.w.Write(data) + return w.operation.write(data, false) } -func (w *responseWriter) WriteHeader(statusCode int) { - if w.headersWritten { +func (w responseWriter) WriteHeader(statusCode int) { + if w.operation.hasResponse { return } - w.headersWritten = true - w.code = statusCode - - if w.endWritten { - // Nothing to do: we already sent RPC error to client. - return - } - - var err error - w.contentLen, err = httpExtractContentLength(w.Header()) - if err != nil { - w.reportError(err) - return - } - w.op.rspContentType = w.Header().Get("Content-Type") - respMeta, processBody, err := w.op.server.protocol.extractProtocolResponseHeaders(statusCode, w.Header()) - if err != nil { - w.reportError(err) - return - } - // snapshot trailer keys - trailerKeys := parseMultiHeader(w.Header().Values("Trailer")) - if len(trailerKeys) > 0 { - respMeta.pendingTrailerKeys = make(headerKeys, len(trailerKeys)) - for _, k := range trailerKeys { - respMeta.pendingTrailerKeys.add(k) - } - w.Header().Del("Trailer") - } - - // Remove other headers that might mess up the next leg - w.Header().Del("Content-Encoding") - w.Header().Del("Accept-Encoding") - - w.respMeta = &respMeta - if respMeta.compression == CompressionIdentity { - respMeta.compression = "" // normalize to empty string - } - if respMeta.compression != "" { - respCompression, ok := w.op.compressors[respMeta.compression] - if !ok { - w.reportError(fmt.Errorf("response indicates unsupported compression encoding %q", respMeta.compression)) - return - } - w.op.client.respCompression = respCompression - w.op.server.respCompression = respCompression - } - if respMeta.codec != "" && respMeta.codec != w.op.server.codec.Name() && - !restHTTPBodyResponse(w.op) { - // unexpected content-type for reply - w.reportError(fmt.Errorf("response uses incorrect codec: expecting %q but instead got %q", w.op.server.codec.Name(), respMeta.codec)) - return - } - - if w.op.hooks.OnServerResponseHeaders != nil { - if err := w.op.hooks.OnServerResponseHeaders(w.op.request.Context(), w.op, statusCode, w.Header()); err != nil { - w.reportError(err) - return - } - } - - if respMeta.end != nil { - // RPC failed immediately. - if processBody != nil { - // We have to wait until we receive the body in order to process the error. - w.w = &errorWriter{ - rw: w, - respMeta: w.respMeta, - processBody: processBody, - buffer: w.op.bufferPool.Get(), - } - return - } - // We can send back error response immediately. - w.flushHeaders() - w.w = noResponseBodyWriter{} - return - } - - if w.op.clientPreparer != nil { - w.op.clientRespNeedsPrep = w.op.clientPreparer.responseNeedsPrep(w.op) - } - if w.op.serverPreparer != nil { - w.op.serverRespNeedsPrep = w.op.serverPreparer.responseNeedsPrep(w.op) - } - - sameCodec := w.op.client.codec.Name() == w.op.server.codec.Name() - // even if body encoding uses same content type, we can't treat them as the same - // (which means re-using encoded data) if either side needs to prep the data first - sameResponseCodec := sameCodec && !w.op.clientRespNeedsPrep && !w.op.serverRespNeedsPrep - mustDecodeResponse := !sameResponseCodec || w.op.hooks.OnServerResponseMessage != nil - - respMsg := message{sameCompression: true, sameCodec: sameResponseCodec} - - if mustDecodeResponse { - // We will have to decode and re-encode, so we need the message type. - messageType, err := w.op.methodConf.resolver.FindMessageByName(w.op.methodConf.descriptor.Output().FullName()) - if err != nil { - w.reportError(err) - return - } - respMsg.msg = messageType.New().Interface() - } - - var endMustBeInHeaders bool - if mustBe, ok := w.op.client.protocol.(clientProtocolEndMustBeInHeaders); ok { - endMustBeInHeaders = mustBe.endMustBeInHeaders() - } - var delegate io.Writer - if endMustBeInHeaders { - // We must await the end before we can write headers, which means we have to - // buffer the entire response. - w.buf = w.op.bufferPool.Get() - delegate = &limitWriter{buf: w.buf, limit: w.op.methodConf.maxMsgBufferBytes, rw: w} - } else { - // We can go ahead and flush headers now. - w.flushHeaders() - delegate = w.delegate - } - - // Now we can define the transformed response body. - if sameResponseCodec && !mustDecodeResponse { - // we do not need to decompress or decode - w.w = &envelopingWriter{rw: w, w: delegate} - } else { - w.w = &transformingWriter{rw: w, msg: &respMsg, w: delegate} - } + w.operation.hasResponse = true + w.operation.responseMeta.StatusCode = statusCode } // Unwrap provides access to the underlying response writer. This plays nicely // with ResponseController functionality introduced in Go 1.21 without actually // depending on Go 1.21. -func (w *responseWriter) Unwrap() http.ResponseWriter { - return w.delegate +func (w responseWriter) Unwrap() http.ResponseWriter { + return w.operation.response } -func (w *responseWriter) Flush() { +func (w responseWriter) Flush() { // We expose this method so server can call it and won't panic // or blow-up when doing type conversion. But it's a no-op // since we automatically flush at message boundaries when // transforming the response body. } -func (w *responseWriter) flushMessage() { - if w.buf != nil { - // we are buffering until we see trailers, so we don't - // want to actually flush the underlying response writer yet - return - } - w.flusher.Flush() -} - -func (w *responseWriter) reportError(err error) { - var end responseEnd - if errors.As(err, &end.err) { - end.httpCode = httpStatusCodeFromRPC(end.err.Code()) - } else { - // TODO: maybe this should be CodeUnknown instead? - end.err = connect.NewError(connect.CodeInternal, err) - end.httpCode = http.StatusBadGateway - } - w.reportEnd(&end) -} - -func (w *responseWriter) reportEnd(end *responseEnd) { - if w.endWritten { - // It's possible this could be called in the event of a cascading error, - // where various receivers all call reportEnd. We will only respect the - // first such call and ignore the others. - return - } - if w.respMeta != nil && len(w.respMeta.pendingTrailers) > 0 && len(end.trailers) == 0 { - // add any pending trailers to the end - end.trailers = w.respMeta.pendingTrailers - } - switch { - case w.headersFlushed: - // write error to body or trailers - w.writeEnd(end, false) - case w.respMeta != nil: - w.respMeta.end = end - w.flushHeaders() - default: - w.respMeta = &responseMeta{end: end} - w.flushHeaders() - } - w.flusher.Flush() - // response is done - w.op.cancel() - w.err = context.Canceled -} - -func (w *responseWriter) flushHeaders() { - if w.headersFlushed { - return // already flushed - } - cliRespMeta := *w.respMeta - cliRespMeta.codec = w.op.client.codec.Name() - cliRespMeta.compression = w.op.client.respCompression.Name() - cliRespMeta.acceptCompression = w.op.compressors.intersection(w.respMeta.acceptCompression) - statusCode := w.op.client.protocol.addProtocolResponseHeaders(cliRespMeta, w.Header()) - hasErr := w.respMeta.end != nil && w.respMeta.end.err != nil - // We only buffer full response for unary operations, so if we have an error, - // we ignore anything already written to the buffer. - if w.buf != nil && !hasErr { - w.Header().Set("Content-Length", strconv.Itoa(w.buf.Len())) - } - // TODO: At this point, if the server was gRPC but the client is not, we may have "Trailer" - // headers reserving the use of various metadata keys in trailers. It would be - // cleaner if they were culled and only remained present for sneding to gRPC clients. - w.delegate.WriteHeader(statusCode) - if w.buf != nil { - if !hasErr { - _, _ = w.buf.WriteTo(w.delegate) - } - w.op.bufferPool.Put(w.buf) - w.buf = nil - } - if w.respMeta.end != nil { - // response is done - w.writeEnd(w.respMeta.end, true) - w.err = context.Canceled - } - - w.headersFlushed = true -} - -func (w *responseWriter) close() { - if !w.headersWritten { - // treat as empty successful response - w.WriteHeader(http.StatusOK) - } - if w.w != nil { - _ = w.w.Close() - } - if w.endWritten { - return // all done - } - if w.respMeta.end != nil { - // got end in headers - w.reportEnd(w.respMeta.end) - return - } - // try to get end from trailers - trailer := httpExtractTrailers(w.Header(), w.respMeta.pendingTrailerKeys) - end, err := w.op.server.protocol.extractEndFromTrailers(w.op, trailer) - if err != nil { - w.reportError(err) - return - } - w.reportEnd(&end) -} - -func (w *responseWriter) writeEnd(end *responseEnd, wasInHeaders bool) { - if end.err == nil && w.op.hooks.OnOperationFinish != nil { - w.op.hooks.OnOperationFinish(w.op.request.Context(), w.op, end.trailers) - } else if end.err != nil && w.op.hooks.OnOperationFail != nil { - if hookErr := w.op.hooks.OnOperationFail(w.op.request.Context(), w.op, end.trailers, end.err); hookErr != nil { - // report the error returned by the hook - end.err = asConnectError(hookErr) - } - } - trailers := w.op.client.protocol.encodeEnd(w.op, end, w.delegate, wasInHeaders) - httpMergeTrailers(w.Header(), trailers) - w.endWritten = true -} - -// envelopingWriter will translate between envelope styles as data is -// written. It does not do any decompressing or deserializing of data. -type envelopingWriter struct { - rw *responseWriter - w io.Writer - - initialized bool - err error - writingEnvelope bool - env envelopeBytes - remainingBytes int - current io.Writer - mustReleaseCurrent bool - currentIsTrailer bool - trailerIsCompressed bool -} - -func (w *envelopingWriter) Write(data []byte) (int, error) { - w.maybeInit() - if w.err != nil { - return 0, w.err - } - if w.remainingBytes == -1 { - n, err := w.current.Write(data) - if err != nil { - w.err = err - } - return n, err - } - - var written int - for { - if w.err != nil { - return written, w.err - } - if len(data) < w.remainingBytes { - // not enough data to trigger next action; ingest data and return - n, err := w.writeBytes(data) - w.remainingBytes -= n - written += n - if err != nil { - w.err = err - } - return written, err - } - // ingest remaining needed and trigger next action - n, err := w.writeBytes(data[:w.remainingBytes]) - written += n - data = data[w.remainingBytes:] - w.remainingBytes -= n - if err != nil { - w.err = err - return written, err - } - if w.writingEnvelope { - if err := w.handleEnvelopeWritten(); err != nil { - return written, err - } - continue - } - - if w.currentIsTrailer { - err := w.handleTrailer() - if err != nil { - return written, err - } - } else { - // flush after each message and reset for next envelope - w.rw.flushMessage() - w.writingEnvelope = true - w.remainingBytes = envelopeLen - } - } -} - -func (w *envelopingWriter) writeBytes(data []byte) (int, error) { - if w.writingEnvelope { - copy(w.env[envelopeLen-w.remainingBytes:], data) - return len(data), nil - } - return w.current.Write(data) -} - -func (w *envelopingWriter) handleEnvelopeWritten() error { - w.writingEnvelope = false - env, err := w.rw.op.serverEnveloper.decodeEnvelope(w.env) - if err != nil { - err = malformedRequestError(err) - w.rw.reportError(err) - return err - } - if env.trailer { - // buffer final message, so we can transform it to a responseEnd - if limit := w.rw.op.methodConf.maxMsgBufferBytes; env.length > limit { - err := bufferLimitError(int64(limit)) - w.rw.reportError(err) - return err - } - buf := w.rw.op.bufferPool.Get() - buf.Grow(int(env.length)) - w.current = buf - w.mustReleaseCurrent = true - w.currentIsTrailer = true - w.trailerIsCompressed = env.compressed - w.remainingBytes = int(env.length) - return nil - } - if w.rw.op.clientEnveloper != nil { - envBytes := w.rw.op.clientEnveloper.encodeEnvelope(env) - _, err := w.w.Write(envBytes[:]) - if err != nil { - w.err = err - return err - } - } - w.current = w.w - w.remainingBytes = int(env.length) - return nil -} - -func (w *envelopingWriter) Close() error { - var buf *bytes.Buffer - if w.mustReleaseCurrent { - var ok bool - buf, ok = w.current.(*bytes.Buffer) - if !ok { - lw, ok := w.current.(*limitWriter) - if ok { - buf = lw.buf - } - } - if buf == nil { - return fmt.Errorf("current sink must be *limitWriter or *bytes.Buffer but instead is %T", w.current) - } - defer w.rw.op.bufferPool.Put(buf) - } - if w.remainingBytes == -1 && w.mustReleaseCurrent && w.err == nil { - // We were buffering in order to measure size and create envelope, - // so do that now. - env := envelope{compressed: w.rw.op.client.respCompression != nil, length: uint32(buf.Len())} - envBytes := w.rw.op.clientEnveloper.encodeEnvelope(env) - _, err := w.w.Write(envBytes[:]) - if err != nil { - w.err = err - return err - } - _, err = buf.WriteTo(w.w) - if err != nil { - w.err = err - return err - } - } - var normalEOF bool - if w.writingEnvelope && w.remainingBytes == envelopeLen { - // We were looking for envelope of next message, but no next message in the stream - normalEOF = true - } - if w.remainingBytes > 0 && !normalEOF { - // Unfinished body! - if w.writingEnvelope { - w.rw.reportError(fmt.Errorf("handler only wrote %d out of %d bytes of message envelope", envelopeLen-w.remainingBytes, envelopeLen)) - } else { - w.rw.reportError(fmt.Errorf("handler failed to write final %d bytes of message", w.remainingBytes)) - } - } - w.remainingBytes = 0 - w.current = nil - w.err = errors.New("body is closed") - return nil -} - -func (w *envelopingWriter) maybeInit() { - if w.initialized { - return - } - w.initialized = true - if w.rw.op.serverEnveloper != nil { - w.writingEnvelope = true - w.remainingBytes = envelopeLen - return - } - if w.rw.op.clientEnveloper == nil { - // just pass everything through - w.remainingBytes = -1 - w.current = w.w - return - } - if w.rw.contentLen == -1 { - // Oof, we have to buffer everything to measure the request size - // to construct an envelope. - w.remainingBytes = -1 - buf := w.rw.op.bufferPool.Get() - w.current = &limitWriter{buf: buf, limit: w.rw.op.methodConf.maxMsgBufferBytes, rw: w.rw} - w.mustReleaseCurrent = true - return - } - // synthesize envelope - var env envelope - env.compressed = w.rw.op.client.respCompression != nil - env.length = uint32(w.rw.contentLen) - envBytes := w.rw.op.clientEnveloper.encodeEnvelope(envelope{}) - _, err := w.w.Write(envBytes[:]) - if err != nil { - w.err = err - return - } - w.current = w.w - w.remainingBytes = envelopeLen -} - -func (w *envelopingWriter) handleTrailer() error { - data, ok := w.current.(*bytes.Buffer) - if !ok { - // should not be possible - return fmt.Errorf("trailer must be *limitWriter but instead is %T", w.current) - } - defer w.rw.op.bufferPool.Put(data) - w.mustReleaseCurrent = false - if w.trailerIsCompressed { - uncompressed := w.rw.op.bufferPool.Get() - defer w.rw.op.bufferPool.Put(uncompressed) - if err := w.rw.op.server.respCompression.decompress(uncompressed, data); err != nil { - return err - } - data = uncompressed - } - end, err := w.rw.op.serverEnveloper.decodeEndFromMessage(w.rw.op, data) - if err != nil { - w.rw.reportError(err) - return err - } - end.wasCompressed = w.trailerIsCompressed - w.rw.reportEnd(&end) - w.err = errors.New("final data already written") - return nil -} - -// transformingWriter transforms the data from the original response -// into a new protocol form as the data is written. It must decompress -// and deserialize each message and then re-serialize (and optionally -// recompress) each message. Since the original incoming protocol may -// have different envelope conventions than the outgoing protocol, it -// also rewrites envelopes. -type transformingWriter struct { - rw *responseWriter - msg *message - w io.Writer - - err error - buffer *bytes.Buffer - expectingBytes int - writingEnvelope bool - latestEnvelope envelope -} - -func (w *transformingWriter) Write(data []byte) (int, error) { - if w.err != nil { - return 0, w.err - } - if w.buffer == nil { - w.reset() - } - if w.expectingBytes == -1 { - if limit := int64(w.rw.op.methodConf.maxMsgBufferBytes); int64(len(data))+int64(w.buffer.Len()) > limit { - err := bufferLimitError(limit) - w.rw.reportError(err) - return 0, err - } - return w.buffer.Write(data) - } - - var written int - // For enveloped protocols, it's possible that data contains - // multiple messages, so we need to process in a loop. - for { - if w.err != nil { - return written, w.err - } - remainingBytes := w.expectingBytes - w.buffer.Len() - if len(data) < remainingBytes { - // not enough data to trigger next action; ingest data and return - w.buffer.Write(data) - written += len(data) - break - } - // ingest remaining needed and trigger next action - w.buffer.Write(data[:remainingBytes]) - written += remainingBytes - data = data[remainingBytes:] - if w.writingEnvelope { - var envBytes envelopeBytes - _, _ = w.buffer.Read(envBytes[:]) - var err error - w.latestEnvelope, err = w.rw.op.serverEnveloper.decodeEnvelope(envBytes) - if err != nil { - err = malformedRequestError(err) - w.rw.reportError(err) - return written, err - } - if limit := w.rw.op.methodConf.maxMsgBufferBytes; w.latestEnvelope.length > limit { - err = bufferLimitError(int64(limit)) - w.rw.reportError(err) - return written, err - } - w.buffer = w.msg.reset(w.rw.op.bufferPool, false, w.latestEnvelope.compressed) - w.buffer.Grow(int(w.latestEnvelope.length)) - w.expectingBytes = int(w.latestEnvelope.length) - w.writingEnvelope = false - } else { - if err := w.flushMessage(); err != nil { - w.rw.reportError(err) - return written, err - } - w.expectingBytes = envelopeLen - w.writingEnvelope = true - } - } - return written, nil -} - -func (w *transformingWriter) Close() error { - if w.expectingBytes == -1 { - if err := w.flushMessage(); err != nil { - w.rw.reportError(err) - } - } else if w.buffer != nil && w.buffer.Len() > 0 { - // Unfinished body! - if w.writingEnvelope { - w.rw.reportError(fmt.Errorf("handler only wrote %d out of %d bytes of message envelope", w.buffer.Len(), envelopeLen)) - } else { - w.rw.reportError(fmt.Errorf("handler only wrote %d out of %d bytes of message", w.buffer.Len(), w.expectingBytes)) - } - } - w.expectingBytes = 0 - w.msg.release(w.rw.op.bufferPool) - w.buffer = nil - w.err = errors.New("body is closed") - return nil -} - -func (w *transformingWriter) flushMessage() error { - if w.latestEnvelope.trailer { - data := w.buffer - if w.latestEnvelope.compressed { - data = w.rw.op.bufferPool.Get() - defer w.rw.op.bufferPool.Put(data) - if err := w.rw.op.server.respCompression.decompress(data, w.buffer); err != nil { - return err +func classifyRequest(req *http.Request) Protocol { + contentTypes := req.Header["Content-Type"] + if len(contentTypes) == 0 { //nolint:nestif + // Empty bodies should still have content types. So this should only + // happen for requests with NO body at all. That's only allowed for + // REST calls and Connect GET calls. + connectVersion := req.Header["Connect-Protocol-Version"] + // If this header is present, the intent is clear. But Connect GET + // requests should actually encode this via query string (see below). + if len(connectVersion) == 1 && connectVersion[0] == "1" { + if req.Method == http.MethodGet { + return ProtocolConnect } + return ProtocolUnknown } - end, err := w.rw.op.serverEnveloper.decodeEndFromMessage(w.rw.op, data) - if err != nil { - w.rw.reportError(err) - return err - } - end.wasCompressed = w.latestEnvelope.compressed - w.rw.reportEnd(&end) - w.err = errors.New("final data already written") - return nil - } - - // We've finished reading the message, so we can manually set the stage - w.msg.markReady() - if w.rw.op.hooks.OnServerResponseMessage != nil { - if err := w.msg.advanceToStage(w.rw.op, stageDecoded); err != nil { - return err - } - compressed := w.msg.wasCompressed && w.rw.op.server.respCompression != nil - if err := w.rw.op.hooks.OnServerResponseMessage(w.rw.op.request.Context(), w.rw.op, w.msg.msg, compressed, w.msg.size); err != nil { - return err - } - } - if err := w.msg.advanceToStage(w.rw.op, stageSend); err != nil { - return err - } - buffer := w.msg.sendBuffer() - if enveloper := w.rw.op.clientEnveloper; enveloper != nil { - env := envelope{ - compressed: w.msg.wasCompressed && w.rw.op.client.respCompression != nil, - length: uint32(buffer.Len()), - } - envBytes := enveloper.encodeEnvelope(env) - if _, err := w.w.Write(envBytes[:]); err != nil { - w.err = err - return err - } - } - if _, err := buffer.WriteTo(w.w); err != nil { - w.err = err - return err - } - // flush after each message - w.rw.flushMessage() - - w.reset() - return nil -} - -func (w *transformingWriter) reset() { - if w.rw.op.serverEnveloper != nil { - w.buffer = w.msg.reset(w.rw.op.bufferPool, false, false) - w.expectingBytes = envelopeLen - w.writingEnvelope = true - } else { - isCompressed := w.rw.respMeta.compression != "" - w.buffer = w.msg.reset(w.rw.op.bufferPool, false, isCompressed) - w.expectingBytes = -1 - } -} - -type errorWriter struct { - rw *responseWriter - respMeta *responseMeta - processBody responseEndUnmarshaller - buffer *bytes.Buffer -} - -func (e *errorWriter) Write(data []byte) (int, error) { - if e.buffer == nil { - return 0, errors.New("writer already closed") - } - if limit := int64(e.rw.op.methodConf.maxMsgBufferBytes); int64(len(data))+int64(e.buffer.Len()) > limit { - err := bufferLimitError(limit) - e.rw.reportError(err) - return 0, err - } - return e.buffer.Write(data) -} - -func (e *errorWriter) Close() error { - if e.respMeta.end == nil { - e.respMeta.end = &responseEnd{} - } - bufferPool := e.rw.op.bufferPool - defer bufferPool.Put(e.buffer) - body := e.buffer - if compressPool := e.rw.op.server.respCompression; compressPool != nil { - uncompressed := bufferPool.Get() - defer bufferPool.Put(uncompressed) - if err := compressPool.decompress(uncompressed, body); err != nil { - // can't really just return an error; we have to encode the - // error into the RPC response, so we populate respMeta.end - if e.respMeta.end.httpCode == 0 || e.respMeta.end.httpCode == http.StatusOK { - e.respMeta.end.httpCode = http.StatusInternalServerError + values := req.URL.Query() + if values.Get("connect") == "v1" { + if req.Method != http.MethodGet { + return ProtocolUnknown } - e.respMeta.end.err = connect.NewError(connect.CodeInternal, fmt.Errorf("failed to decompress body: %w", err)) - body = nil - } else { - body = uncompressed - } - } - if body != nil { - e.processBody(e.rw.op.server.codec, body, e.respMeta.end) - } - e.rw.flushHeaders() - e.buffer = nil - return nil -} - -type noResponseBodyWriter struct { -} - -func (c noResponseBodyWriter) Write([]byte) (int, error) { - return 0, errors.New("final data already written") -} - -func (c noResponseBodyWriter) Close() error { - return nil -} - -type limitWriter struct { - buf *bytes.Buffer - limit uint32 - rw *responseWriter -} - -func (l *limitWriter) Write(data []byte) (n int, err error) { - if uint32(l.buf.Len()+len(data)) > l.limit { - err := bufferLimitError(int64(l.limit)) - l.rw.reportError(err) - return 0, err - } - return l.buf.Write(data) -} - -type hardLimitReader struct { - r io.Reader - limit int64 - read int64 - rw *responseWriter - makeError func(int64) error -} - -func (h *hardLimitReader) Read(data []byte) (n int, err error) { - remaining := h.limit - h.read - if remaining < 0 { - return 0, h.error() - } - if int64(len(data)) > remaining { - // allow reading one byte over the limit, so we can distinguish between - // reading exactly the limit vs. reading too much. - data = data[:remaining+1] - } - n, err = h.r.Read(data) - h.read += int64(n) - if h.read > h.limit && (err == nil || errors.Is(err, io.EOF)) { - err := h.error() - if h.rw != nil { - h.rw.reportError(err) + return ProtocolConnect } - return n, err - } - return n, err -} - -func (h *hardLimitReader) error() error { - if h.makeError == nil { - return bufferLimitError(h.limit) - } - return h.makeError(h.limit) -} - -type messageStage int - -const ( - stageEmpty = messageStage(iota) - // This is the stage of a message after the raw data has been read from the client - // or written by the server handler. - // - // At this point either compressed or data fields of the message will be populated - // (depending on whether message data was compressed or not). - stageRead - // This is the stage of a message after the data has been decompressed and decoded. - // - // The msg field of the message is usable at this point. The compressed and data - // fields of the message will remain populated if their values can be re-used. - stageDecoded - // This is the stage of a message after it has been re-encoded and re-compressed - // and is ready to send (to be read by server handler or to be written to client). - // - // Either compressed or data fields of the message will be populated (depending on - // whether message data was compressed or not). - stageSend -) - -func (s messageStage) String() string { - switch s { - case stageEmpty: - return "empty" - case stageRead: - return "read" - case stageDecoded: - return "decoded" - case stageSend: - return "send" - default: - return "unknown" - } -} - -// message represents a single message in an RPC stream. It can be re-used in a stream, -// so we only allocate one and then re-use it for subsequent messages (if stream has -// more than one). -type message struct { - // true if this is a request message read from the client; false if - // this is a response message written by the server. - isRequest bool - - // flags indicating if compressed and data should be preserved after use. - sameCompression, sameCodec bool - // wasCompressed is true if the data was originally compressed; this can - // be false in a stream when the stream envelope's compressed bit is unset. - wasCompressed bool - // original size of the message on the wire, in bytes - size int - - stage messageStage - - // compressed is the compressed bytes; may be nil if the contents have - // already been decompressed into the data field. - compressed *bytes.Buffer - // data is the serialized but uncompressed bytes; may be nil if the - // contents have not yet been decompressed or have been de-serialized - // into the msg field. - data *bytes.Buffer - // msg is the plain message; not valid unless stage is stageDecoded - msg proto.Message -} - -// sendBuffer returns the buffer to use to read message data to be sent. -func (m *message) sendBuffer() *bytes.Buffer { - if m.stage != stageSend { - return nil - } - if m.wasCompressed { - return m.compressed - } - return m.data -} - -// release releases all buffers associated with message to the given pool. -func (m *message) release(pool *bufferPool) { - if m.compressed != nil { - pool.Put(m.compressed) - } - if m.data != nil && m.data != m.compressed { - pool.Put(m.data) - } - m.data, m.compressed, m.msg = nil, nil, nil -} - -// reset arranges for message to be re-used by making sure it has -// a compressed buffer that is ready to accept bytes and no data -// buffer. -func (m *message) reset(pool *bufferPool, isRequest, isCompressed bool) *bytes.Buffer { - m.stage = stageEmpty - m.size = -1 - m.isRequest = isRequest - m.wasCompressed = isCompressed - // we only need one buffer to start, so put - // a non-nil buffer into buffer1 and if we - // have a second non-nil buffer, release it - buffer1, buffer2 := m.compressed, m.data - if buffer1 == nil && buffer2 != nil { - buffer1, buffer2 = buffer2, buffer1 - } - if buffer2 != nil && buffer2 != buffer1 { - pool.Put(buffer2) - } - if buffer1 == nil { - buffer1 = pool.Get() - } else { - buffer1.Reset() - } - if isCompressed { - m.compressed, m.data = buffer1, nil - } else { - m.data, m.compressed = buffer1, nil - } - return buffer1 -} - -func (m *message) markReady() { - m.stage = stageRead - if m.wasCompressed { - m.size = m.compressed.Len() - } else { - m.size = m.data.Len() - } -} - -func (m *message) advanceToStage(op *operation, newStage messageStage) error { - if m.stage == stageEmpty { - return errors.New("message has not yet been read") - } - if m.stage > newStage { - return fmt.Errorf("cannot advance message stage backwards: stage %v > target %v", m.stage, newStage) - } - - if newStage == m.stage { - return nil // no-op + return ProtocolREST } - if newStage == stageSend && m.sameCodec && - (!m.wasCompressed || (m.wasCompressed && m.sameCompression)) { - // We can re-use existing buffer; no more action to take. - m.stage = newStage - return nil // no more action to take + if len(contentTypes) > 1 { + return ProtocolUnknown // Don't allow this. } - + contentType := contentTypes[0] switch { - case m.stage == stageRead && newStage == stageSend: - if !m.sameCodec { - // If the codec is different we have to fully decode the message and - // then fully re-encode. - if err := m.advanceToStage(op, stageDecoded); err != nil { - return err - } - return m.advanceToStage(op, newStage) - } - - // We must de-compress and re-compress the data. - if err := m.decompress(op, false); err != nil { - return err - } - if err := m.compress(op); err != nil { - return err - } - m.stage = newStage - return nil - - case m.stage == stageRead && newStage == stageDecoded: - if m.wasCompressed { - if err := m.decompress(op, m.sameCompression && m.sameCodec); err != nil { - return err - } - } - if err := m.decode(op, m.sameCodec); err != nil { - return err - } - m.stage = newStage - return nil - - case m.stage == stageDecoded && newStage == stageSend: - if !m.sameCodec { - // re-encode - if err := m.encode(op); err != nil { - return err + case strings.HasPrefix(contentType, "application/connect+"): + return ProtocolConnect + case contentType == "application/grpc", strings.HasPrefix(contentType, "application/grpc+"): + return ProtocolGRPC + case contentType == "application/grpc-web", strings.HasPrefix(contentType, "application/grpc-web+"): + return ProtocolGRPCWeb + case strings.HasPrefix(contentType, "application/"): + connectVersion := req.Header["Connect-Protocol-Version"] + if len(connectVersion) == 1 && connectVersion[0] == "1" { + if req.Method == http.MethodGet { + return ProtocolConnect } + return ProtocolConnect } - if m.wasCompressed { - // re-compress - if err := m.compress(op); err != nil { - return err + values := req.URL.Query() + if values.Get("connect") == "v1" { + if req.Method != http.MethodGet { + return ProtocolUnknown } + return ProtocolConnect } - m.stage = newStage - return nil - - default: - return fmt.Errorf("unknown stage transition: stage %v to target %v", m.stage, newStage) - } -} - -// decompress will decompress data in m.compressed into m.data, -// acquiring a new buffer from op's bufferPool if necessary. -// If saveBuffer is true, m.compressed will be unmodified on -// return; otherwise, the buffer will be released to op's -// bufferPool and the field set to nil. -// -// This method should not be called directly as the message's -// buffers could get out of sync with its stage. It should -// only be called from m.advanceToStage. -func (m *message) decompress(op *operation, saveBuffer bool) error { - var pool *compressionPool - if m.isRequest { - pool = op.client.reqCompression - } else { - pool = op.client.respCompression - } - if pool == nil { - // identity compression, so nothing to do - m.data = m.compressed - if !saveBuffer { - m.compressed = nil - } - return nil - } - - var src *bytes.Buffer - if saveBuffer { - // we allocate a new buffer, but not the underlying byte slice - // (it's cheaper than re-compressing later) - src = bytes.NewBuffer(m.compressed.Bytes()) - } else { - src = m.compressed - } - m.data = op.bufferPool.Get() - if err := pool.decompress(m.data, src); err != nil { - return err - } - if !saveBuffer { - op.bufferPool.Put(m.compressed) - m.compressed = nil - } - return nil -} - -// compress will compress data in m.data into m.compressed, -// acquiring a new buffer from op's bufferPool if necessary. -// -// This method should not be called directly as the message's -// buffers could get out of sync with its stage. It should -// only be called from m.advanceToStage. -func (m *message) compress(op *operation) error { - var pool *compressionPool - if m.isRequest { - pool = op.server.reqCompression - } else { - pool = op.server.respCompression - } - if pool == nil { - // identity compression, so nothing to do - m.compressed = m.data - m.data = nil - return nil - } - - m.compressed = op.bufferPool.Get() - if err := pool.compress(m.compressed, m.data); err != nil { - return err - } - op.bufferPool.Put(m.data) - m.data = nil - return nil -} - -// decode will unmarshal data in m.data into m.msg. If -// saveBuffer is true, m.data will be unmodified on return; -// otherwise, the buffer will be released to op's bufferPool -// and the field set to nil. -// -// This method should not be called directly as the message's -// buffers could get out of sync with its stage. It should -// only be called from m.advanceToStage. -func (m *message) decode(op *operation, saveBuffer bool) error { - switch { - case m.isRequest && op.clientReqNeedsPrep: - return op.clientPreparer.prepareUnmarshalledRequest(op, m.data.Bytes(), m.msg) - case !m.isRequest && op.serverRespNeedsPrep: - return op.serverPreparer.prepareUnmarshalledResponse(op, m.data.Bytes(), m.msg) - } - - var codec Codec - if m.isRequest { - codec = op.client.codec - } else { - codec = op.server.codec - } - - if err := codec.Unmarshal(m.data.Bytes(), m.msg); err != nil { - return err - } - if !saveBuffer { - op.bufferPool.Put(m.data) - m.data = nil - } - return nil -} - -// encode will marshal data in m.msg into m.data. -// -// This method should not be called directly as the message's -// buffers could get out of sync with its stage. It should -// only be called from m.advanceToStage. -func (m *message) encode(op *operation) error { - buf := op.bufferPool.Get() - var data []byte - var err error - - switch { - case m.isRequest && op.serverReqNeedsPrep: - data, err = op.serverPreparer.prepareMarshalledRequest(op, buf.Bytes(), m.msg, op.request.Header) - case !m.isRequest && op.clientRespNeedsPrep: - data, err = op.clientPreparer.prepareMarshalledResponse(op, buf.Bytes(), m.msg, op.writer.Header()) + // REST usually uses application/json, but use of google.api.HttpBody means it could + // also use *any* content-type. + fallthrough default: - var codec Codec - if m.isRequest { - codec = op.server.codec - } else { - codec = op.client.codec - } - data, err = codec.MarshalAppend(buf.Bytes(), m.msg) - } - - if err != nil { - op.bufferPool.Put(buf) - m.data = nil - return err - } - m.data = op.bufferPool.Wrap(data, buf) - return nil -} - -type errorFlusher interface { - FlushError() error -} - -type flusherNoError struct { - f errorFlusher -} - -func (f flusherNoError) Flush() { - _ = f.f.FlushError() -} - -func asFlusher(respWriter http.ResponseWriter) http.Flusher { - // This is similar to how http.ResponseController.Flush works. But - // we can't use that since it isn't available prior to Go 1.21. - for { - switch typedWriter := respWriter.(type) { - case http.Flusher: - return typedWriter - case errorFlusher: - return flusherNoError{f: typedWriter} - case interface{ Unwrap() http.ResponseWriter }: - respWriter = typedWriter.Unwrap() - default: - return nil - } + return ProtocolREST } } - -type requestLine struct { - method, path, queryString, httpVersion string -} - -func (l *requestLine) fromRequest(req *http.Request) { - l.method = req.Method - l.path = req.URL.Path - l.queryString = req.URL.RawQuery - l.httpVersion = req.Proto -} diff --git a/handler_bench_test.go b/handler_bench_test.go index b38caed..f0269b1 100644 --- a/handler_bench_test.go +++ b/handler_bench_test.go @@ -230,9 +230,8 @@ func BenchmarkServeHTTP(b *testing.B) { req.ProtoMajor = 2 req.ProtoMinor = 0 req.Header.Set("Content-Type", "application/grpc+proto") - req.Header.Set("Grpc-Encoding", "gzip") req.Header.Set("Grpc-Timeout", "1S") - req.Header.Set("Grpc-Accept-Encoding", "gzip") + req.Header.Set("Grpc-Accept-Encoding", "identity") b.StartTimer() b.ReportAllocs() diff --git a/handler_test.go b/handler_test.go index 5a1ae97..dc23512 100644 --- a/handler_test.go +++ b/handler_test.go @@ -15,27 +15,12 @@ package vanguard import ( - "context" - "fmt" - "io" "net/http" "net/http/httptest" "testing" - "connectrpc.com/connect" - testv1 "connectrpc.com/vanguard/internal/gen/vanguard/test/v1" "connectrpc.com/vanguard/internal/gen/vanguard/test/v1/testv1connect" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "google.golang.org/genproto/googleapis/api/httpbody" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/emptypb" - "google.golang.org/protobuf/types/known/wrapperspb" -) - -const ( - testDataString = "abc def ghi" - testCompressedDataString = "nop qrs tuv" // rot13 of above ) func TestHandler_Errors(t *testing.T) { @@ -408,7 +393,7 @@ func TestHandler_Errors(t *testing.T) { } } -func TestHandler_PassThrough(t *testing.T) { +/*func TestHandler_PassThrough(t *testing.T) { t.Parallel() // These cases don't do any transformation and just pass through to the // underlying handler. @@ -734,382 +719,4 @@ func TestHandler_PassThrough(t *testing.T) { } }) } -} - -func TestMessage_AdvanceStage(t *testing.T) { - t.Parallel() - // Tests the state machine for message. - - type testEnviron struct { - abcCodec, xyzCodec *fakeCodec - abcCompression, xyzCompression, otherCompression *fakeCompression - op *operation - } - newTestEnviron := func(isRequest bool) *testEnviron { - abcCodec := &fakeCodec{name: "abc"} - xyzCodec := &fakeCodec{name: "xyz"} - abcCompression := &fakeCompression{name: "abc"} - xyzCompression := &fakeCompression{name: "xyz"} - otherCompression := &fakeCompression{name: "other"} - var clientCodec, serverCodec Codec - var clientReqComp, serverReqComp, respComp *compressionPool - if isRequest { - clientCodec = abcCodec - serverCodec = xyzCodec - clientReqComp = abcCompression.newPool() - serverReqComp = xyzCompression.newPool() - respComp = otherCompression.newPool() - } else { - clientCodec = xyzCodec - serverCodec = abcCodec - clientReqComp = xyzCompression.newPool() - serverReqComp = xyzCompression.newPool() - respComp = abcCompression.newPool() - } - op := &operation{ - bufferPool: &bufferPool{}, - client: clientProtocolDetails{ - codec: clientCodec, - reqCompression: clientReqComp, - respCompression: respComp, - }, - server: serverProtocolDetails{ - codec: serverCodec, - reqCompression: serverReqComp, - respCompression: respComp, - }, - } - return &testEnviron{ - abcCodec: abcCodec, - xyzCodec: xyzCodec, - abcCompression: abcCompression, - xyzCompression: xyzCompression, - otherCompression: otherCompression, - op: op, - } - } - resetEnv := func(env *testEnviron) { - env.abcCodec.marshalCalls = 0 - env.abcCodec.unmarshalCalls = 0 - env.xyzCodec.marshalCalls = 0 - env.xyzCodec.unmarshalCalls = 0 - env.abcCompression.compressorCalls = 0 - env.abcCompression.decompressorCalls = 0 - env.xyzCompression.compressorCalls = 0 - env.xyzCompression.decompressorCalls = 0 - env.otherCompression.compressorCalls = 0 - env.otherCompression.decompressorCalls = 0 - } - type expectedCounts struct { - abcMarshalCalls int - abcUnmarshalCalls int - xyzMarshalCalls int - xyzUnmarshalCalls int - abcCompressCalls int - abcDecompressCalls int - xyzCompressCalls int - xyzDecompressCalls int - } - checkCounts := func(t *testing.T, isRequest bool, env *testEnviron, counts expectedCounts) { - t.Helper() - if !isRequest { - // for responses, compression for both client and server is the same (abc) - counts.abcCompressCalls += counts.xyzCompressCalls - counts.abcDecompressCalls += counts.xyzDecompressCalls - counts.xyzCompressCalls = 0 - counts.xyzDecompressCalls = 0 - } - assert.Equal(t, counts.abcMarshalCalls, env.abcCodec.marshalCalls) - assert.Equal(t, counts.abcUnmarshalCalls, env.abcCodec.unmarshalCalls) - assert.Equal(t, counts.xyzMarshalCalls, env.xyzCodec.marshalCalls) - assert.Equal(t, counts.xyzUnmarshalCalls, env.xyzCodec.unmarshalCalls) - assert.Equal(t, counts.abcCompressCalls, env.abcCompression.compressorCalls) - assert.Equal(t, counts.abcDecompressCalls, env.abcCompression.decompressorCalls) - assert.Equal(t, counts.xyzCompressCalls, env.xyzCompression.compressorCalls) - assert.Equal(t, counts.xyzDecompressCalls, env.xyzCompression.decompressorCalls) - assert.Zero(t, env.otherCompression.compressorCalls) - assert.Zero(t, env.otherCompression.decompressorCalls) - } - - testCases := []struct { - name string - createMessage func() *message - decodedToSend expectedCounts - decodedToSendIfCompressed *expectedCounts - readToSend expectedCounts - readToSendIfCompressed *expectedCounts - }{ - { - name: "same codec, same compression", - createMessage: func() *message { return &message{sameCodec: true, sameCompression: true} }, - // no calls necessary since client payload can be re-used - decodedToSend: expectedCounts{}, - readToSend: expectedCounts{}, - }, - { - name: "same codec, different compression", - createMessage: func() *message { return &message{sameCodec: true} }, - // no calls necessary for uncompressed since payload can be re-used, - // but we have to decompress/recompress for compressed payloads - decodedToSend: expectedCounts{}, - decodedToSendIfCompressed: &expectedCounts{ - xyzCompressCalls: 1, - }, - readToSend: expectedCounts{}, - readToSendIfCompressed: &expectedCounts{ - abcDecompressCalls: 1, - xyzCompressCalls: 1, - }, - }, - { - name: "different codec", - createMessage: func() *message { return &message{} }, - // we must re-encode and re-compress - decodedToSend: expectedCounts{ - xyzMarshalCalls: 1, - }, - decodedToSendIfCompressed: &expectedCounts{ - xyzMarshalCalls: 1, - xyzCompressCalls: 1, - }, - readToSend: expectedCounts{ - abcUnmarshalCalls: 1, - xyzMarshalCalls: 1, - }, - readToSendIfCompressed: &expectedCounts{ - abcDecompressCalls: 1, - abcUnmarshalCalls: 1, - xyzMarshalCalls: 1, - xyzCompressCalls: 1, - }, - }, - } - - for _, compressed := range []bool{true, false} { - compressed := compressed - t.Run(fmt.Sprintf("compressed:%v", compressed), func(t *testing.T) { - t.Parallel() - for _, isRequest := range []bool{true, false} { - isRequest := isRequest - t.Run(fmt.Sprintf("request:%v", isRequest), func(t *testing.T) { - t.Parallel() - for _, testCase := range testCases { - testCase := testCase - t.Run(testCase.name, func(t *testing.T) { - t.Parallel() - - originalData := testDataString - if compressed { - originalData = testCompressedDataString - } - - env := newTestEnviron(isRequest) - msg := testCase.createMessage() - msg.msg = &wrapperspb.StringValue{} - buffer := msg.reset(env.op.bufferPool, isRequest, compressed) - checkStageEmpty(t, msg, compressed) - - buffer.WriteString(originalData) - msg.stage = stageRead - checkStageRead(t, msg, compressed) - - err := msg.advanceToStage(env.op, stageDecoded) - require.NoError(t, err) - // read -> decoded must always decode (and possibly first decompress) - counts := expectedCounts{ - abcUnmarshalCalls: 1, - } - if compressed { - counts.abcDecompressCalls = 1 - } - checkCounts(t, isRequest, env, counts) - checkStageDecoded(t, msg) - - resetEnv(env) - err = msg.advanceToStage(env.op, stageSend) - require.NoError(t, err) - counts = testCase.decodedToSend - if compressed && testCase.decodedToSendIfCompressed != nil { - counts = *testCase.decodedToSendIfCompressed - } - checkCounts(t, isRequest, env, counts) - checkStageSend(t, msg, compressed) - - // Re-create message and this time go directly from read to send - msg = testCase.createMessage() - msg.msg = &wrapperspb.StringValue{} - buffer = msg.reset(env.op.bufferPool, isRequest, compressed) - buffer.WriteString(originalData) - msg.stage = stageRead - - resetEnv(env) - err = msg.advanceToStage(env.op, stageSend) - require.NoError(t, err) - counts = testCase.readToSend - if compressed && testCase.readToSendIfCompressed != nil { - counts = *testCase.readToSendIfCompressed - } - checkCounts(t, isRequest, env, counts) - checkStageSend(t, msg, compressed) - }) - } - }) - } - }) - } -} - -func checkStageEmpty(t *testing.T, msg *message, compressed bool) { - t.Helper() - require.Equal(t, stageEmpty, msg.stage) - if compressed { - require.NotNil(t, msg.compressed) - require.Zero(t, msg.compressed.Len()) - require.Nil(t, msg.data) - } else { - require.Nil(t, msg.compressed) - require.NotNil(t, msg.data) - require.Zero(t, msg.data.Len()) - } - // Should not be possible to advance from empty. - require.Error(t, msg.advanceToStage(nil, stageRead)) - require.Error(t, msg.advanceToStage(nil, stageDecoded)) - require.Error(t, msg.advanceToStage(nil, stageSend)) -} - -func checkStageRead(t *testing.T, msg *message, compressed bool) { - t.Helper() - require.Equal(t, stageRead, msg.stage) - if compressed { - require.NotNil(t, msg.compressed) - require.Equal(t, testCompressedDataString, msg.compressed.String()) - require.Nil(t, msg.data) - } else { - require.Nil(t, msg.compressed) - require.NotNil(t, msg.data) - require.Equal(t, testDataString, msg.data.String()) - } - // Should not be possible to go backwards. - require.Error(t, msg.advanceToStage(nil, stageEmpty)) -} - -func checkStageDecoded(t *testing.T, msg *message) { - t.Helper() - require.Equal(t, stageDecoded, msg.stage) - require.Equal(t, testDataString, msg.msg.(*wrapperspb.StringValue).Value) - // Should not be possible to go backwards. - require.Error(t, msg.advanceToStage(nil, stageRead)) - require.Error(t, msg.advanceToStage(nil, stageEmpty)) -} - -func checkStageSend(t *testing.T, msg *message, compressed bool) { - t.Helper() - if compressed { - require.NotNil(t, msg.compressed) - require.Equal(t, testCompressedDataString, msg.compressed.String()) - // can't assert anything about m.data: if we didn't have to do - // anything to get to send (same codec, same compression), we - // won't have done anything to it; but if we had to re-encode - // and re-compress, it would get released and set to nil - } else { - require.Nil(t, msg.compressed) - require.NotNil(t, msg.data) - require.Equal(t, testDataString, msg.data.String()) - } - require.Equal(t, stageSend, msg.stage) - // Should not be possible to go backwards. - require.Error(t, msg.advanceToStage(nil, stageDecoded)) - require.Error(t, msg.advanceToStage(nil, stageRead)) - require.Error(t, msg.advanceToStage(nil, stageEmpty)) -} - -type fakeCodec struct { - name string - marshalCalls, unmarshalCalls int -} - -func (f *fakeCodec) Name() string { - return f.name -} - -func (f *fakeCodec) MarshalAppend(b []byte, msg proto.Message) ([]byte, error) { - f.marshalCalls++ - val := msg.(*wrapperspb.StringValue).Value - return append(b, ([]byte)(val)...), nil -} - -func (f *fakeCodec) Unmarshal(b []byte, msg proto.Message) error { - f.unmarshalCalls++ - msg.(*wrapperspb.StringValue).Value = string(b) - return nil -} - -type fakeCompression struct { - name string - compressorCalls, decompressorCalls int - reader io.Reader - writer io.Writer -} - -func (f *fakeCompression) newPool() *compressionPool { - return newCompressionPool( - f.name, - func() connect.Compressor { - return (*fakeCompressor)(f) - }, - func() connect.Decompressor { - return (*fakeDecompressor)(f) - }, - ) -} - -type fakeCompressor fakeCompression - -func (f *fakeCompressor) Write(p []byte) (n int, err error) { - rot13(p) - return f.writer.Write(p) -} - -func (f *fakeCompressor) Close() error { - return nil -} - -func (f *fakeCompressor) Reset(writer io.Writer) { - (*fakeCompression)(f).compressorCalls++ - f.writer = writer -} - -type fakeDecompressor fakeCompression - -func (f *fakeDecompressor) Read(p []byte) (n int, err error) { - n, err = f.reader.Read(p) - rot13(p[:n]) - return n, err -} - -func (f *fakeDecompressor) Close() error { - return nil -} - -func (f *fakeDecompressor) Reset(reader io.Reader) error { - (*fakeCompression)(f).decompressorCalls++ - f.reader = reader - return nil -} - -func rot13(data []byte) { - for index, char := range data { - if char >= 'A' && char <= 'Z' { - char += 13 - if char > 'Z' { - char -= 26 - } - } else if char >= 'a' && char <= 'z' { - char += 13 - if char > 'z' { - char -= 26 - } - } - data[index] = char - } -} +}*/ diff --git a/protocol.go b/protocol.go index 9ff93cd..8f02cc6 100644 --- a/protocol.go +++ b/protocol.go @@ -18,13 +18,9 @@ import ( "bytes" "fmt" "io" - "net/http" - "net/textproto" + "net/url" "strings" "time" - - "connectrpc.com/connect" - "google.golang.org/protobuf/proto" ) const envelopeLen = 5 @@ -92,298 +88,95 @@ func (p Protocol) String() string { } } -func (p Protocol) serverHandler(op *operation) serverProtocolHandler { - switch p { - case ProtocolConnect: - if op.methodConf.streamType == connect.StreamTypeUnary { - return connectUnaryServerProtocol{} - } - return connectStreamServerProtocol{} - case ProtocolGRPC: - return grpcServerProtocol{} - case ProtocolGRPCWeb: - return grpcWebServerProtocol{} - case ProtocolREST: - return restServerProtocol{} - default: - return nil - } -} - // clientProtocolHandler handles the protocol used by the client. // This allows the middleware to understand the incoming request // and to send valid responses to the client. type clientProtocolHandler interface { - protocol() Protocol - acceptsStreamType(*operation, connect.StreamType) bool - - // Extracts relevant request metadata from the given headers to - // determine the codec (aka sub-format), compression (aka encoding), - // timeout, etc. The relevant headers are interpreted into the - // returned requestMeta and also *removed* from the given headers. - extractProtocolRequestHeaders(*operation, http.Header) (requestMeta, error) - - // TODO: The following two methods were meant to be agnostic as to whether - // the protocol is a streaming protocol or a unary one. The operations - // are split because a streaming protocol cannot change headers or the - // status code from encodeEnd because headers and status have already - // been written. This requires unary implementations to do extra - // handling of errors in addProtocolResponseHeaders, awkwardly separated - // from the handling in encodeEnd. Worse, if an unexpected error happens - // in encodeEnd, it is too late to change status code or headers. We - // could possibly combine these for unary-only protocols to make the - // implementation simpler. If we do, we'd need a way to swap protocol - // handlers -- so that a REST handler can swap itself out for a - // unary- or streaming-specific implementation once the method is known - // (for streaming upload/download endpoints or in future general support - // for server streaming endpoints). - - // Encodes the given responseMeta as headers into the given target - // headers. If provided, allowedCompression should be used instead - // of meta.allowedCompression when adding "accept-encoding" headers. - // - // The return value is the status code that should be sent to the - // client. If the status code written was anything other than - // 200 OK, the given meta will include a responseEnd that has that - // original code. - // - // Note that this method's responsibility is to decide the status - // code and set headers. When meta.end is non-nil, encodeEnd will - // also be called, which is where a response body and trailers - // can be written. - addProtocolResponseHeaders(meta responseMeta, target http.Header) int - // Encodes the given final disposition of the RPC to the given - // writer. It can also return any trailers to add to the response. - // Some protocols may ignore the writer; some will return no - // trailers. - // - // The given codec represents the sub-format that the client used - // (which could be used, for example, to encode the error). - // - // The wasInHeaders flag indicates that end was signalled in the - // response headers. For some protocols, like gRPC and gRPC-Web, - // this is the difference between a trailers-only response and a - // normal response (where the end is signalled in the response - // body or trailers, not headers). When this is true, the end was - // also already provided to addProtocolResponseHeaders. - encodeEnd(op *operation, end *responseEnd, writer io.Writer, wasInHeaders bool) http.Header - - // String returns a human-readable name/description of protocol. - String() string -} - -// clientProtocolSupportsGet is an optional interface implemented by -// clientProtocolHandler instances that can support the GET HTTP method. -type clientProtocolAllowsGet interface { - allowsGetRequests(*methodConfig) bool + // Protocol returns the protocol used by the client. + Protocol() Protocol + // DecodeRequestHeader extracts request headers into the given requestMeta. + DecodeRequestHeader(*requestMeta) error + // PrepareRequestMessage prepares the given messageBuffer to receive the + // message data from the reader. + PrepareRequestMessage(*messageBuffer, *requestMeta) error + // EncodeRequestHeader given responseMeta as headers into the given target. + EncodeRequestHeader(*responseMeta) error + // PrepareResponseMessage prepares the given messageBuffer to receive the + // message data from the writer. + PrepareResponseMessage(*messageBuffer, *responseMeta) error + // EncodeResponseTrailer encodes the given trailers into the given buffer and headers. + EncodeResponseTrailer(*bytes.Buffer, *responseMeta) error + // Encode the error into the given buffer. + EncodeError(*bytes.Buffer, *responseMeta, error) } // serverProtocolHandler handles the protocol used by the server. // This allows the middleware to send a valid request to the server // and understand the responses it sends. type serverProtocolHandler interface { - protocol() Protocol - - // Encodes the given requestMeta has headers into the given target - // headers. If non-nil, allowedCompression should be used instead - // of meta.allowedCompression when adding "accept-encoding" headers. - addProtocolRequestHeaders(meta requestMeta, target http.Header) - // Returns the response metadata from the headers. - // - // If the response meta's end field is set (i.e. headers indicate RPC - // is over), but the protocol needs to read the response body to - // populate it, it should return a non-nil function as the second - // returned value. This generally only occurs when the RPC fails and - // the body includes error information. If the body includes response - // message data, handlers should NOT set a non-nil end. - // - // If the headers include trailers (such as in the Connect unary - // protocol), but the RPC isn't quite over because the message data - // must still be read from the response body, the handler should - // instead populate the pendingTrailers field of meta. Note that - // this field is ignored if the end field is non-nil. So if the - // end is set to non-nil, the handler should store trailers there. - // - // This function will receive the server's codec (optionally used - // to encode other messages and could be used to decode the error - // body), the body, and a pointer to the responseEnd which should - // be populated with the details. If the response body was compressed, - // it will be decompressed before it is provided to the given function. - extractProtocolResponseHeaders(statusCode int, headers http.Header) (responseMeta, responseEndUnmarshaller, error) - // Called at end of RPC if responseEnd has not been returned by - // extractProtocolResponseHeaders or from an enveloped message - // in the response body whose trailer bit is set. - extractEndFromTrailers(*operation, http.Header) (responseEnd, error) - - // String returns a human-readable name/description of protocol. - String() string -} - -// responseEndUnmarshaller populates the given responseEnd by unmarshalling -// information from the given buffer. If unmarshalling needs to know the -// server's codec, it also provided as the first argument. -type responseEndUnmarshaller func(Codec, *bytes.Buffer, *responseEnd) - -// clientProtocolEndMustBeInHeaders is an optional interface implemented -// by clientProtocolHandler instances to indicate if the end of an RPC -// must be indicated in response headers (not trailers or in the body). -// If a protocol handler does not implement this, it is assumed to be -// false. -type clientProtocolEndMustBeInHeaders interface { - endMustBeInHeaders() bool -} - -// envelopedProtocolHandler is an optional interface implemented -// by clientProtocolHandler and serverProtocolHandler instances -// whose protocol uses an envelope around messages. -type envelopedProtocolHandler interface { - decodeEnvelope(envelopeBytes) (envelope, error) - encodeEnvelope(envelope) envelopeBytes -} - -// serverEnvelopedProtocolHandler is an optional interface implemented -// by serverProtocolHandler instances whose protocol uses an envelope -// around messages. -type serverEnvelopedProtocolHandler interface { - envelopedProtocolHandler - // If a stream includes an envelope with the trailer bit - // set, this is called to parse the message contents. The - // given reader will be decompressed (even if the envelope - // had its compressed bit set). - // - // The given codec represents the sub-format used to send - // the request to the server (which may be used to decode - // the error). - decodeEndFromMessage(*operation, io.Reader) (responseEnd, error) -} - -// requestLineBuilder is an optional interface implemented by -// serverProtocolHandler instances whose HTTP request line -// needs to be computed in a custom manner. By default (for -// protocols that do not implement this), the request line is -// "POST //". -// -// This is necessary for REST and Connect GET requests, which -// can encode parts of the request data into the URI path -// or query string parameters. -type requestLineBuilder interface { - // Returns true if the request message must be known in order - // to compute the request line. - requiresMessageToProvideRequestLine(*operation) bool - // Computes the components of the request line and also - // indicates if the request will include a body or not. The - // body can be omitted for requests where *all* request - // information is supplied in the request line. - requestLine(op *operation, req proto.Message) (urlPath, queryParams, method string, includeBody bool, err error) -} - -// clientBodyPreparer is an optional interface implemented by -// clientProtocolHandler instances whose request messages may -// need to be assembled from sources other than just decoding -// the request or response body. -type clientBodyPreparer interface { - // Returns true if the request message needs to be prepared. - // If it can simply be read and decoded from the request body - // then it does not need to be prepared. But if the message - // data must be merged with parts of the request path or - // query param (etc), it must return true. - requestNeedsPrep(*operation) bool - // Combines the given request body data with other info to - // produce a request message. The given bytes represent the - // uncompressed request body. The given message should be - // populated if/when the method returns nil. - prepareUnmarshalledRequest(op *operation, src []byte, target proto.Message) error - // Returns true if the response message needs to be prepared. - // If it can simply be encoded into the response body then it - // does not need to be prepared. But if the message data must - // be wrapped or some parts discarded, the method must return - // true. - responseNeedsPrep(*operation) bool - // Produces the request body for the given message. The data - // should be appended to the given slice (which will be empty - // but have capacity to accept data) to reduce allocations. - // The given headers may be updated, like if a message has - // content that must go into headers (such as recording a - // custom content-type for uses of google.api.HttpBody). - prepareMarshalledResponse(op *operation, base []byte, src proto.Message, headers http.Header) ([]byte, error) -} - -// serverBodyPreparer is an optional interface implemented by -// serverProtocolHandler instances whose request messages may -// need to be assembled from sources other than just decoding -// the request or response body. -type serverBodyPreparer interface { - // These methods are reversed from clientBodyPreparer: for the - // server side, we have a request message and must produce a - // body; and we have a response body and must extract from that - // a message. - requestNeedsPrep(*operation) bool - prepareMarshalledRequest(op *operation, base []byte, src proto.Message, headers http.Header) ([]byte, error) - responseNeedsPrep(*operation) bool - prepareUnmarshalledResponse(op *operation, src []byte, target proto.Message) error + // Protocol returns the protocol used by the server. + Protocol() Protocol + // EncodeRequestHeader encodes the given requestMeta as headers into the given target. + EncodeRequestHeader(*requestMeta) error + // PrepareRequestMessage prepares the given messageBuffer to receive the + // message data from the reader. + PrepareRequestMessage(*messageBuffer, *requestMeta) error + // DecodeResponseHeader extracts response headers into the given responseMeta. + DecodeRequestHeader(*responseMeta) error + // PrepareResponseMessage prepares the given messageMeta to receive the + // message data from the writer. + PrepareResponseMessage(*messageBuffer, *responseMeta) error + // DecodeResponseTrailer decodes the given trailers into the given buffer and headers. + DecodeResponseTrailer(*bytes.Buffer, *responseMeta) error } // envelopeBytes is an array of bytes representing an encoded envelope. type envelopeBytes [envelopeLen]byte -// envelope is an exploded representation of the 5-byte preamble that appears -// on the wire for enveloped protocols. This form is protocol-agnostic. -type envelope struct { - trailer bool - compressed bool - length uint32 +// encoding represents the encoding used for a request or response. +type encoding struct { + Codec Codec + Compressor compressor } // requestMeta represents the metadata found in request headers that are // protocol-specific. type requestMeta struct { - timeout time.Duration - hasTimeout bool - codec string - compression string - acceptCompression []string + Body io.ReadCloser + Header httpHeader + URL *url.URL + Method string + ProtoMajor int + ProtoMinor int + RequiresBody bool // true if the header depends on the request body + + // Following fields are derived from the request headers. + Timeout time.Duration + CodecName string + CompressionName string + AcceptCompression []string + + // State for the request encoding. + Client encoding + Server encoding } // responseMeta represents the metadata found in response headers that are // protocol-specific. type responseMeta struct { - end *responseEnd - codec string - compression string - acceptCompression []string - pendingTrailers http.Header - pendingTrailerKeys headerKeys -} - -// responseEnd is a protocol-agnostic representation of the disposition -// of an RPC. -type responseEnd struct { - err *connect.Error - trailers http.Header + Header httpHeader + StatusCode int + HasTrailer bool + WroteStatus bool // true if the header has been written - // httpCode is only populated when the responseEnd source contained - // such a code. This happens when the responseEnd comes from the - // response headers, which include the status line. It can also - // occur for REST streaming responses, where the final message may - // include both gRPC and HTTP codes. - httpCode int + // Following fields are derived from the response headers. + CodecName string + CompressionName string + AcceptCompression []string - // For enveloping protocols where the end is in a special stream - // payload, this will be true if that special payload was compressed. - // This can be used by a protocol handler that also encodes the end - // in a stream payload to decide whether to compress the final frame. - wasCompressed bool -} - -type headerKeys map[string]struct{} - -func (k headerKeys) add(key string) { - k[textproto.CanonicalMIMEHeaderKey(key)] = struct{}{} -} - -func (k headerKeys) contains(key string) bool { - _, contains := k[textproto.CanonicalMIMEHeaderKey(key)] - return contains + // State for the response encoding. + Client encoding + Server encoding } // parseMultiHeader parses headers that allow multiple values. It @@ -416,3 +209,14 @@ func parseMultiHeader(vals []string) []string { } return result } + +// type Stream interface { +// RecvHeader() http.Header +// RecvMessage(proto.Message) error +// SendHeader(http.Header) error +// SendMessage(proto.Message) error +// SendTrailer(http.Header) error +// } +// +// type StreamHandler func(Stream) error +// type StreamInterceptor func(info protoreflect.MethodDescriptor, stream Stream, handler StreamHandler) error diff --git a/protocol_connect.go b/protocol_connect.go index 6448b34..13c3ed8 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -37,79 +37,62 @@ import ( ) const ( - protocolNameConnectUnary = protocolNameConnect + " unary" - protocolNameConnectUnaryGet = protocolNameConnectUnary + " (GET)" - protocolNameConnectUnaryPost = protocolNameConnectUnary + " (POST)" - protocolNameConnectStream = protocolNameConnect + " stream" - // TODO: Extract more constants for header names and values. contentTypeJSON = "application/json" + + connectFlagEnvelopeEndStream = 0b00000010 ) -// connectUnaryGetClientProtocol implements the Connect protocol for -// processing unary RPCs received from the client that use GET as the +// connectUnaryClientProtocol implements the Connect protocol for +// processing unary RPCs received from the client that use POST as the // HTTP method. -type connectUnaryGetClientProtocol struct{} +type connectUnaryClientProtocol struct { + config *methodConfig +} -var _ clientProtocolHandler = connectUnaryGetClientProtocol{} -var _ clientProtocolAllowsGet = connectUnaryGetClientProtocol{} -var _ clientProtocolEndMustBeInHeaders = connectUnaryGetClientProtocol{} -var _ clientBodyPreparer = connectUnaryGetClientProtocol{} +var _ clientProtocolHandler = connectUnaryClientProtocol{} -func (c connectUnaryGetClientProtocol) protocol() Protocol { +func (c connectUnaryClientProtocol) Protocol() Protocol { return ProtocolConnect } -func (c connectUnaryGetClientProtocol) acceptsStreamType(_ *operation, streamType connect.StreamType) bool { +func (c connectUnaryClientProtocol) acceptsStreamType(_ *operation, streamType connect.StreamType) bool { return streamType == connect.StreamTypeUnary } -func (c connectUnaryGetClientProtocol) allowsGetRequests(conf *methodConfig) bool { +func (c connectUnaryClientProtocol) allowsGetRequests(conf *methodConfig) bool { methodOpts, ok := conf.descriptor.Options().(*descriptorpb.MethodOptions) return ok && methodOpts.GetIdempotencyLevel() == descriptorpb.MethodOptions_NO_SIDE_EFFECTS } -func (c connectUnaryGetClientProtocol) endMustBeInHeaders() bool { - return true -} - -func (c connectUnaryGetClientProtocol) extractProtocolRequestHeaders(op *operation, headers http.Header) (requestMeta, error) { - var reqMeta requestMeta - if err := connectExtractTimeout(headers, &reqMeta); err != nil { - return reqMeta, err +func (c connectUnaryClientProtocol) decodeGetQuery(meta *requestMeta) error { + if !c.allowsGetRequests(c.config) { + return protocolError("GET requests not allowed for this method") } - query := op.queryValues() - reqMeta.codec = query.Get("encoding") - reqMeta.compression = query.Get("compression") - reqMeta.acceptCompression = parseMultiHeader(headers.Values("Accept-Encoding")) - headers.Del("Accept-Encoding") - headers.Del("Content-Type") - headers.Del("Connect-Protocol-Version") - return reqMeta, nil -} - -func (c connectUnaryGetClientProtocol) addProtocolResponseHeaders(meta responseMeta, headers http.Header) int { - // Response format is the same as unary POST; only requests differ - return connectUnaryPostClientProtocol{}.addProtocolResponseHeaders(meta, headers) -} - -func (c connectUnaryGetClientProtocol) encodeEnd(op *operation, end *responseEnd, writer io.Writer, wasInHeaders bool) http.Header { - // Response format is the same as unary POST; only requests differ - return connectUnaryPostClientProtocol{}.encodeEnd(op, end, writer, wasInHeaders) -} - -func (c connectUnaryGetClientProtocol) requestNeedsPrep(_ *operation) bool { - return true -} + query := meta.URL.Query() + meta.URL.RawQuery = "" + meta.CodecName = query.Get("encoding") + meta.CompressionName = query.Get("compression") + meta.AcceptCompression = parseMultiHeader(meta.Header.Values("Accept-Encoding")) + meta.Header.Del("Accept-Encoding") + meta.Header.Del("Content-Type") + meta.Header.Del("Connect-Protocol-Version") -func (c connectUnaryGetClientProtocol) prepareUnmarshalledRequest(op *operation, src []byte, target proto.Message) error { - if len(src) > 0 { - return fmt.Errorf("connect unary protocol using GET HTTP method should have no body; instead got %d bytes", len(src)) + codec, err := c.config.GetClientCodec(meta.CodecName) + if err != nil { + return err + } + meta.Client.Codec = codec + if meta.CompressionName != "" { + compressor, err := c.config.GetClientCompressor(meta.CompressionName) + if err != nil { + return err + } + meta.Client.Compressor = compressor } - // TODO: ideally we could *replace* the request body with the bytes in the query string and then - // otherwise use message and its re-encoding/re-compressing capability. - vals := op.queryValues() - base64Str := vals.Get("base64") + + // Replace the request body with the bytes in the query string. + base64Str := query.Get("base64") var base64enc bool switch base64Str { case "", "0": @@ -117,9 +100,9 @@ func (c connectUnaryGetClientProtocol) prepareUnmarshalledRequest(op *operation, case "1": base64enc = true default: - return fmt.Errorf("query string parameter base64 should be absent or have value 0 or 1; instead got %q", base64Str) + return protocolError("query string parameter base64 should be absent or have value 0 or 1; instead got %q", base64Str) } - msgStr := vals.Get("message") + msgStr := query.Get("message") var msgData []byte if base64enc && msgStr != "" { var err error @@ -134,285 +117,341 @@ func (c connectUnaryGetClientProtocol) prepareUnmarshalledRequest(op *operation, } else { msgData = ([]byte)(msgStr) } - if op.client.reqCompression != nil { - dst := op.bufferPool.Get() - defer op.bufferPool.Put(dst) - if err := op.client.reqCompression.decompress(dst, bytes.NewBuffer(msgData)); err != nil { - return err - } - msgData = dst.Bytes() - } - return op.client.codec.Unmarshal(msgData, target) -} -func (c connectUnaryGetClientProtocol) responseNeedsPrep(_ *operation) bool { - return false -} - -func (c connectUnaryGetClientProtocol) prepareMarshalledResponse(_ *operation, _ []byte, _ proto.Message, _ http.Header) ([]byte, error) { - return nil, errors.New("response does not need preparation") -} - -func (c connectUnaryGetClientProtocol) String() string { - return protocolNameConnectUnaryGet + // Require blocking on request body to encode URL parameters. + meta.RequiresBody = true + meta.Body = io.NopCloser(bytes.NewBuffer(msgData)) + return nil } -// connectUnaryPostClientProtocol implements the Connect protocol for -// processing unary RPCs received from the client that use POST as the -// HTTP method. -type connectUnaryPostClientProtocol struct{} - -var _ clientProtocolHandler = connectUnaryPostClientProtocol{} -var _ clientProtocolEndMustBeInHeaders = connectUnaryPostClientProtocol{} +func (c connectUnaryClientProtocol) DecodeRequestHeader(meta *requestMeta) error { + if c.config.streamType != connect.StreamTypeUnary { + return protocolError("expected unary stream type") + } + if meta.Method == http.MethodGet { + return c.decodeGetQuery(meta) + } -func (c connectUnaryPostClientProtocol) protocol() Protocol { - return ProtocolConnect -} + timeout, err := connectExtractTimeout(meta.Header) + if err != nil { + return err + } + meta.Timeout = timeout + meta.CodecName = strings.TrimPrefix(meta.Header.Get("Content-Type"), "application/") + if meta.CodecName == CodecJSON+"; charset=utf-8" { + // TODO: should we support other text formats that may need charset check? + meta.CodecName = CodecJSON + } + meta.Header.Del("Content-Type") + meta.CompressionName = meta.Header.Get("Content-Encoding") + meta.Header.Del("Content-Encoding") + meta.AcceptCompression = parseMultiHeader(meta.Header.Values("Accept-Encoding")) + meta.Header.Del("Accept-Encoding") + meta.Header.Del("Connect-Protocol-Version") -func (c connectUnaryPostClientProtocol) acceptsStreamType(_ *operation, streamType connect.StreamType) bool { - return streamType == connect.StreamTypeUnary + // Resolve Codecs + codec, err := c.config.GetClientCodec(meta.CodecName) + if err != nil { + return err + } + meta.Client.Codec = codec + if meta.CompressionName != "" { + compressor, err := c.config.GetClientCompressor(meta.CompressionName) + if err != nil { + return err + } + meta.Client.Compressor = compressor + } + return nil } -func (c connectUnaryPostClientProtocol) endMustBeInHeaders() bool { - return true +func (c connectUnaryClientProtocol) PrepareRequestMessage(msg *messageBuffer, meta *requestMeta) error { + msg.Src.IsCompressed = meta.Client.Compressor != nil + msg.Src.ReadMode = readModeEOF + return nil } -func (c connectUnaryPostClientProtocol) extractProtocolRequestHeaders(_ *operation, headers http.Header) (requestMeta, error) { - var reqMeta requestMeta - if err := connectExtractTimeout(headers, &reqMeta); err != nil { - return reqMeta, err - } - reqMeta.codec = strings.TrimPrefix(headers.Get("Content-Type"), "application/") - if reqMeta.codec == CodecJSON+"; charset=utf-8" { - // TODO: should we support other text formats that may need charset check? - reqMeta.codec = CodecJSON - } - headers.Del("Content-Type") - reqMeta.compression = headers.Get("Content-Encoding") - headers.Del("Content-Encoding") - reqMeta.acceptCompression = parseMultiHeader(headers.Values("Accept-Encoding")) - headers.Del("Accept-Encoding") - headers.Del("Connect-Protocol-Version") - return reqMeta, nil -} - -func (c connectUnaryPostClientProtocol) addProtocolResponseHeaders(meta responseMeta, headers http.Header) int { - status := http.StatusOK - if meta.end != nil && meta.end.err != nil { - status = httpStatusCodeFromRPC(meta.end.err.Code()) - headers.Set("Content-Type", contentTypeJSON) // error bodies are always in JSON - // TODO: Content-Encoding to compress error? - } else { - headers.Set("Content-Type", "application/"+meta.codec) - if meta.compression != "" { - headers.Set("Content-Encoding", meta.compression) - } +func (c connectUnaryClientProtocol) EncodeRequestHeader(meta *responseMeta) error { + // Resolve Codecs + codec, err := c.config.GetClientCodec(meta.CodecName) + if err != nil { + return err } - if meta.end != nil { - for k, v := range meta.end.trailers { - headers["Trailer-"+k] = v + meta.Client.Codec = codec + meta.Header.Set("Content-Type", "application/"+meta.CodecName) + if meta.CompressionName != "" { + compressor, err := c.config.GetClientCompressor(meta.CompressionName) + if err != nil { + return err } + meta.Client.Compressor = compressor + meta.Header.Set("Content-Encoding", meta.CompressionName) + meta.Header.Set("Accept-Encoding", meta.CompressionName) } - if len(meta.acceptCompression) > 0 { - headers.Set("Accept-Encoding", strings.Join(meta.acceptCompression, ", ")) - } - return status + return nil } -func (c connectUnaryPostClientProtocol) encodeEnd(op *operation, end *responseEnd, writer io.Writer, wasInHeaders bool) http.Header { - if end.err != nil && !wasInHeaders { - // TODO: Uh oh. We already flushed headers and started writing body. What can we do? - // Should this log? If we are using http/2, is there some way we could send - // a "goaway" frame to the client, to indicate abnormal end of stream? - return nil - } - if end.err == nil { - return nil - } - wireErr := connectErrorToWireError(end.err, op.methodConf.resolver) - data, err := json.Marshal(wireErr) - if err != nil { - data = ([]byte)(`{"code": "internal", "message": ` + strconv.Quote("failed to marshal end error: "+err.Error()) + `}`) +func (c connectUnaryClientProtocol) PrepareResponseMessage(msg *messageBuffer, meta *responseMeta) error { + msg.Dst.IsCompressed = meta.Client.Compressor != nil + msg.Dst.WaitForTrailer = true + return nil +} + +func (c connectUnaryClientProtocol) EncodeResponseTrailer(_ *bytes.Buffer, meta *responseMeta) error { + trailer := httpExtractTrailers(meta.Header) + for key, values := range trailer { + meta.Header.Del(key) + meta.Header.Set("Trailer-"+key, strings.Join(values, ", ")) } - _, _ = writer.Write(data) return nil } +func (c connectUnaryClientProtocol) EncodeError(buf *bytes.Buffer, meta *responseMeta, err error) { + // Encode the error as uncompressed JSON. + meta.Header.Del("Content-Encoding") + meta.Header.Set("Content-Type", contentTypeJSON) -func (c connectUnaryPostClientProtocol) String() string { - return protocolNameConnectUnaryPost + cerr := asConnectError(err) + wireErr := connectErrorToWireError(cerr, c.config.resolver) + if err := json.NewEncoder(buf).Encode(wireErr); err != nil { + buf.WriteString(`{"code": "internal", "message": ` + strconv.Quote("failed to marshal end error: "+err.Error()) + `}`) + } + meta.StatusCode = httpStatusCodeFromRPC(cerr.Code()) } // connectUnaryServerProtocol implements the Connect protocol for // sending unary RPCs to the server handler. -type connectUnaryServerProtocol struct{} +type connectUnaryServerProtocol struct { + config *methodConfig + + // State + statusCode int +} -// NB: the latter two interfaces must be implemented to handle GET requests. -var _ serverProtocolHandler = connectUnaryServerProtocol{} -var _ requestLineBuilder = connectUnaryServerProtocol{} -var _ serverBodyPreparer = connectUnaryServerProtocol{} +var _ serverProtocolHandler = (*connectUnaryServerProtocol)(nil) -func (c connectUnaryServerProtocol) protocol() Protocol { +func (c *connectUnaryServerProtocol) Protocol() Protocol { return ProtocolConnect } -func (c connectUnaryServerProtocol) addProtocolRequestHeaders(meta requestMeta, headers http.Header) { - headers.Set("Content-Type", "application/"+meta.codec) - if meta.compression != "" { - headers.Set("Content-Encoding", meta.compression) +func (c *connectUnaryServerProtocol) encodeGetQuery(meta *requestMeta) error { + meta.Method = http.MethodGet + meta.Header.Del("Content-Type") + + codec, ok := meta.Server.Codec.(StableCodec) + if !ok { + return protocolError("cannot use GET with unstable codec") + } + meta.Server.Codec = connectGetRequestCodec{ + config: c.config, + codec: codec, + compressor: meta.Server.Compressor, + meta: meta, + } + // Handle compression in the codec. + meta.Server.Compressor = nil + return nil +} + +func (c *connectUnaryServerProtocol) EncodeRequestHeader(meta *requestMeta) error { + meta.RequiresBody = true // Require body for URL parameters. + + // Resolve codecs + codec, err := c.config.GetServerCodec(meta.CodecName) + if err != nil { + return err } - if len(meta.acceptCompression) > 0 { - headers.Set("Accept-Encoding", strings.Join(meta.acceptCompression, ", ")) + meta.Server.Codec = codec + if meta.CompressionName != "" { + compression, err := c.config.GetServerCompressor(meta.CompressionName) + if err != nil { + return err + } + meta.Server.Compressor = compression } - headers.Set("Connect-Protocol-Version", "1") - if meta.hasTimeout { - timeoutStr := connectEncodeTimeout(meta.timeout) + if meta.Timeout > 0 { + timeoutStr := connectEncodeTimeout(meta.Timeout) if timeoutStr != "" { - headers.Set("Connect-Timeout-Ms", timeoutStr) + meta.Header.Set("Connect-Timeout-Ms", timeoutStr) } } + meta.Header.Set("Accept", "application/"+meta.CodecName) + + // Encode as GET request if possible. + meta.URL.Path = c.config.methodPath + if c.useGet(meta) { + return c.encodeGetQuery(meta) + } + meta.Method = http.MethodPost + meta.URL.RawQuery = "" + + meta.Header.Set("Content-Type", "application/"+meta.CodecName) + if meta.CompressionName != "" { + meta.Header.Set("Content-Encoding", meta.CompressionName) + } + if len(meta.AcceptCompression) > 0 { + meta.Header.Set("Accept-Encoding", strings.Join(meta.AcceptCompression, ", ")) + } + meta.Header.Set("Connect-Protocol-Version", "1") + + return nil } -func (c connectUnaryServerProtocol) extractProtocolResponseHeaders(statusCode int, headers http.Header) (responseMeta, responseEndUnmarshaller, error) { - var respMeta responseMeta - contentType := headers.Get("Content-Type") +func (c *connectUnaryServerProtocol) PrepareRequestMessage(msg *messageBuffer, meta *requestMeta) error { + msg.Dst.IsCompressed = meta.Server.Compressor != nil + return nil +} + +func (c *connectUnaryServerProtocol) DecodeRequestHeader(meta *responseMeta) error { + contentType := meta.Header.Get("Content-Type") + meta.Header.Del("Content-Type") + meta.CompressionName = meta.Header.Get("Content-Encoding") + meta.Header.Del("Content-Encoding") + meta.AcceptCompression = parseMultiHeader(meta.Header.Values("Accept-Encoding")) + meta.Header.Del("Accept-Encoding") + c.statusCode = meta.StatusCode // Save for DecodeTrailer. + switch { case strings.HasPrefix(contentType, "application/"): - respMeta.codec = strings.TrimPrefix(contentType, "application/") + meta.CodecName = strings.TrimPrefix(contentType, "application/") default: - respMeta.codec = contentType + "?" - } - headers.Del("Content-Type") - respMeta.compression = headers.Get("Content-Encoding") - headers.Del("Content-Encoding") - respMeta.acceptCompression = parseMultiHeader(headers.Values("Accept-Encoding")) - headers.Del("Accept-Encoding") - trailers := connectExtractUnaryTrailers(headers) - - var endUnmarshaller responseEndUnmarshaller - if statusCode == http.StatusOK { - respMeta.pendingTrailers = trailers - } else { - // Content-Type must be application/json for errors or else it's invalid - if contentType != contentTypeJSON { - respMeta.codec = contentType + "?" - } else { - respMeta.codec = "" + // Invalid codec, try capture the error in DecodeMessage. + return nil + } + + // Encode connect trailers as HTTP trailers. + trailers := connectExtractUnaryTrailers(meta.Header) + for key, values := range trailers { + meta.Header.Del(key) + for _, value := range values { + meta.Header.Add(http.TrailerPrefix+key, value) } - respMeta.end = &responseEnd{ - wasCompressed: respMeta.compression != "", - trailers: trailers, + } + + // Resolve codecs, only if the response is successful. + // Otherwise we default to JSON using a connect compatible JSON codec. + if c.statusCode == 200 { + codec, err := c.config.GetServerCodec(meta.CodecName) + if err != nil { + return err } - endUnmarshaller = func(_ Codec, buf *bytes.Buffer, end *responseEnd) { - var wireErr connectWireError - if err := json.Unmarshal(buf.Bytes(), &wireErr); err != nil { - end.err = connect.NewError(connect.CodeInternal, err) - return - } - end.err = wireErr.toConnectError() + meta.Server.Codec = codec + } + if meta.CompressionName != "" { + compression, err := c.config.GetServerCompressor(meta.CompressionName) + if err != nil { + return err } + meta.Server.Compressor = compression } - return respMeta, endUnmarshaller, nil -} - -func (c connectUnaryServerProtocol) extractEndFromTrailers(_ *operation, _ http.Header) (responseEnd, error) { - return responseEnd{}, nil + return nil } -func (c connectUnaryServerProtocol) requestNeedsPrep(op *operation) bool { - return c.useGet(op) +func (c *connectUnaryServerProtocol) PrepareResponseMessage(msg *messageBuffer, meta *responseMeta) error { + msg.Src.ReadMode = readModeEOF // Read until EOF + msg.Src.IsCompressed = meta.Server.Compressor != nil + if meta.Server.Codec == nil || c.statusCode/100 != 2 { + msg.Src.IsTrailer = true + } + return nil } -func (c connectUnaryServerProtocol) useGet(op *operation) bool { - methodOptions, _ := op.methodConf.descriptor.Options().(*descriptorpb.MethodOptions) - _, isStable := op.server.codec.(StableCodec) - return op.request.Method == http.MethodGet && isStable && - methodOptions.GetIdempotencyLevel() == descriptorpb.MethodOptions_NO_SIDE_EFFECTS -} +// DecodeResponseTrailer decodes the error from the last message, if any. +// Trailers are set in the HTTP response headers. +func (c *connectUnaryServerProtocol) DecodeResponseTrailer(buf *bytes.Buffer, _ *responseMeta) error { + if c.statusCode/100 == 2 { + return nil + } + if buf.Len() == 0 { + code := httpStatusCodeToRPC(c.statusCode) + return connect.NewWireError(code, errors.New("no error details")) + } -func (c connectUnaryServerProtocol) prepareMarshalledRequest(_ *operation, _ []byte, _ proto.Message, _ http.Header) ([]byte, error) { - // NB: This would be called when requestNeedsPrep returns true, for GET requests. - // In that case, there is no request body, so we can return a nil result. - // The request data will actually be put into the URL in that case. - // See the requestLine method below. - return nil, nil + var wireErr connectWireError + if err := json.Unmarshal(buf.Bytes(), &wireErr); err != nil { + return connect.NewError(connect.CodeInternal, fmt.Errorf("failed to unmarshal error: %w", err)) + } + return wireErr.toConnectError() } -func (c connectUnaryServerProtocol) responseNeedsPrep(_ *operation) bool { - return false +func (c *connectUnaryServerProtocol) useGet(meta *requestMeta) bool { + methodOptions, _ := c.config.descriptor.Options().(*descriptorpb.MethodOptions) + noSideEffects := methodOptions.GetIdempotencyLevel() == descriptorpb.MethodOptions_NO_SIDE_EFFECTS + _, isStable := meta.Server.Codec.(StableCodec) + return isStable && noSideEffects } -func (c connectUnaryServerProtocol) prepareUnmarshalledResponse(_ *operation, _ []byte, _ proto.Message) error { - return errors.New("response does not need preparation") +type connectGetRequestCodec struct { + config *methodConfig + codec StableCodec + compressor compressor + meta *requestMeta } -func (c connectUnaryServerProtocol) requiresMessageToProvideRequestLine(op *operation) bool { - return c.useGet(op) +func (c connectGetRequestCodec) Name() string { + return "connect-get+" + c.codec.Name() } - -func (c connectUnaryServerProtocol) requestLine(op *operation, msg proto.Message) (urlPath, queryParams, method string, includeBody bool, err error) { - if !c.useGet(op) { - return op.methodConf.methodPath, "", http.MethodPost, true, nil - } +func (c connectGetRequestCodec) MarshalAppend(dst []byte, msg proto.Message) ([]byte, error) { vals := make(url.Values, 5) vals.Set("connect", "v1") + vals.Set("encoding", c.codec.Name()) - vals.Set("encoding", op.server.codec.Name()) - buf := op.bufferPool.Get() - stableMarshaler, _ := op.server.codec.(StableCodec) // c.useGet called above already checked this - data, err := stableMarshaler.MarshalAppendStable(buf.Bytes(), msg) + dst, err := c.codec.MarshalAppendStable(dst, msg) if err != nil { - op.bufferPool.Put(buf) - return "", "", "", false, err - } - buf = op.bufferPool.Wrap(data, buf) - defer op.bufferPool.Put(buf) - - encoded := op.bufferPool.Get() - defer op.bufferPool.Put(encoded) - if op.server.reqCompression != nil { - vals.Set("compression", op.server.reqCompression.Name()) - if err := op.server.reqCompression.compress(encoded, buf); err != nil { - return "", "", "", false, err + return nil, err + } + + msgRaw := dst + + // TODO: move the compression logic outside of the codec. + if c.compressor != nil { + var tmp bytes.Buffer + src := bytes.NewBuffer(msgRaw) + vals.Set("compression", c.compressor.Name()) + if err := c.compressor.compress(&tmp, src); err != nil { + return nil, err } - // for the next step, we want encoded empty and data to be the message source - buf, encoded = encoded, buf // swap so writing to encoded doesn't mutate data - encoded.Reset() - data = buf.Bytes() + msgRaw = tmp.Bytes() } var msgStr string - if stableMarshaler.IsBinary() || op.server.reqCompression != nil { - b64encodedLen := base64.RawURLEncoding.EncodedLen(len(data)) + if c.codec.IsBinary() || c.compressor != nil { vals.Set("base64", "1") - encoded.Grow(b64encodedLen) - encodedBytes := encoded.Bytes()[:b64encodedLen] - base64.RawURLEncoding.Encode(encodedBytes, data) - msgStr = string(encodedBytes) + msgStr = base64.RawURLEncoding.EncodeToString(msgRaw) } else { - msgStr = string(data) + msgStr = string(dst) } vals.Set("message", msgStr) - queryString := vals.Encode() - if uint32(len(op.methodConf.methodPath)+len(queryString)+1) > op.methodConf.maxGetURLBytes { + c.meta.URL.RawQuery = vals.Encode() + c.meta.Method = http.MethodGet + size := len(c.config.methodPath) + len(c.meta.URL.RawQuery) + if size >= int(c.config.maxGetURLBytes) { // URL is too big; fall back to POST - return op.methodConf.methodPath, "", http.MethodPost, true, nil + // TODO: should we try to compress the message? + c.meta.Header.Set("Content-Type", "application/"+c.codec.Name()) + c.meta.Header.Set("Connect-Protocol-Version", "1") + c.meta.Header.Del("Content-Encoding") + c.meta.Method = http.MethodPost + c.meta.URL.RawQuery = "" + return dst, nil } - return op.methodConf.methodPath, vals.Encode(), http.MethodGet, false, nil + // Successfully encoded as GET request. + return dst[:0], nil } - -func (c connectUnaryServerProtocol) String() string { - return protocolNameConnectUnary +func (c connectGetRequestCodec) Unmarshal(_ []byte, _ proto.Message) error { + // This should never be called. + return fmt.Errorf("unimplemented") } // connectStreamClientProtocol implements the Connect protocol for // processing streaming RPCs received from the client. -type connectStreamClientProtocol struct{} +type connectStreamClientProtocol struct { + config *methodConfig +} var _ clientProtocolHandler = connectStreamClientProtocol{} -var _ envelopedProtocolHandler = connectStreamClientProtocol{} -func (c connectStreamClientProtocol) protocol() Protocol { +func (c connectStreamClientProtocol) Protocol() Protocol { return ProtocolConnect } @@ -420,182 +459,226 @@ func (c connectStreamClientProtocol) acceptsStreamType(_ *operation, streamType return streamType != connect.StreamTypeUnary } -func (c connectStreamClientProtocol) extractProtocolRequestHeaders(_ *operation, headers http.Header) (requestMeta, error) { - var reqMeta requestMeta - if err := connectExtractTimeout(headers, &reqMeta); err != nil { - return reqMeta, err +func (c connectStreamClientProtocol) DecodeRequestHeader(meta *requestMeta) error { + if meta.Method != http.MethodPost { + return protocolError("expected POST method") } - reqMeta.codec = strings.TrimPrefix(headers.Get("Content-Type"), "application/connect+") - headers.Del("Content-Type") - reqMeta.compression = headers.Get("Connect-Content-Encoding") - headers.Del("Connect-Content-Encoding") - reqMeta.acceptCompression = parseMultiHeader(headers.Values("Connect-Accept-Encoding")) - headers.Del("Connect-Accept-Encoding") - return reqMeta, nil -} + timeout, err := connectExtractTimeout(meta.Header) + if err != nil { + return err + } + meta.Timeout = timeout -func (c connectStreamClientProtocol) addProtocolResponseHeaders(meta responseMeta, headers http.Header) int { - headers.Set("Content-Type", "application/connect+"+meta.codec) - if meta.compression != "" { - headers.Set("Connect-Content-Encoding", meta.compression) + meta.CodecName = strings.TrimPrefix(meta.Header.Get("Content-Type"), "application/connect+") + meta.Header.Del("Content-Type") + meta.CompressionName = meta.Header.Get("Connect-Content-Encoding") + meta.Header.Del("Connect-Content-Encoding") + meta.AcceptCompression = parseMultiHeader(meta.Header.Values("Connect-Accept-Encoding")) + meta.Header.Del("Connect-Accept-Encoding") + // Resolve codecs + codec, err := c.config.GetClientCodec(meta.CodecName) + if err != nil { + return err } - if len(meta.acceptCompression) > 0 { - headers.Set("Connect-Accept-Encoding", strings.Join(meta.acceptCompression, ", ")) + meta.Client.Codec = codec + if meta.CompressionName != "" { + compressor, err := c.config.GetClientCompressor(meta.CompressionName) + if err != nil { + return err + } + meta.Client.Compressor = compressor } - return http.StatusOK + return nil } -func (c connectStreamClientProtocol) encodeEnd(op *operation, end *responseEnd, writer io.Writer, _ bool) http.Header { - streamEnd := &connectStreamEnd{Metadata: end.trailers} - if end.err != nil { - streamEnd.Error = connectErrorToWireError(end.err, op.methodConf.resolver) +func (c connectStreamClientProtocol) PrepareRequestMessage(msg *messageBuffer, meta *requestMeta) error { + if msg.Buf.Len() < 5 { + return io.ErrShortBuffer // ask for more data + } + flags, size, err := readEnvelope(msg.Buf) + if err != nil { + return err } - buffer := op.bufferPool.Get() - defer op.bufferPool.Put(buffer) - enc := json.NewEncoder(buffer) - if err := enc.Encode(streamEnd); err != nil { - buffer.WriteString(`{"error": {"code": "internal", "message": ` + strconv.Quote(err.Error()) + `}}`) + msg.Src.Size = size + if flags&flagEnvelopeCompressed != 0 { + msg.Src.IsCompressed = true } - // TODO: compress? - env := envelope{trailer: true, length: uint32(buffer.Len())} - envBytes := c.encodeEnvelope(env) - _, _ = writer.Write(envBytes[:]) - _, _ = buffer.WriteTo(writer) return nil } -func (c connectStreamClientProtocol) decodeEnvelope(envBytes envelopeBytes) (envelope, error) { - flags := envBytes[0] - if flags != 0 && flags != 1 { - return envelope{}, fmt.Errorf("invalid compression flag: must be 0 or 1; instead got %d", flags) +func (c connectStreamClientProtocol) EncodeRequestHeader(meta *responseMeta) error { + meta.Header.Set("Content-Type", "application/connect+"+meta.CodecName) + if meta.CompressionName != "" { + meta.Header.Set("Connect-Content-Encoding", meta.CompressionName) + meta.Header.Set("Connect-Accept-Encoding", meta.CompressionName) + } + // Resolve codecs + codec, err := c.config.GetClientCodec(meta.CodecName) + if err != nil { + return err + } + meta.Client.Codec = codec + if meta.CompressionName != "" { + compressor, err := c.config.GetClientCompressor(meta.CompressionName) + if err != nil { + return err + } + meta.Client.Compressor = compressor } - return envelope{ - compressed: flags == 1, - length: binary.BigEndian.Uint32(envBytes[1:]), - }, nil + return nil } -func (c connectStreamClientProtocol) encodeEnvelope(env envelope) envelopeBytes { - var envBytes envelopeBytes - if env.compressed { - envBytes[0] = 1 +func (c connectStreamClientProtocol) PrepareResponseMessage(msg *messageBuffer, meta *responseMeta) error { + msg.Dst.IsEnvelope = true + msg.Dst.IsCompressed = meta.Client.Compressor != nil + if msg.Dst.IsCompressed { + msg.Dst.Flags |= flagEnvelopeCompressed } - if env.trailer { - envBytes[0] |= 2 + return nil +} +func (c connectStreamClientProtocol) EncodeResponseTrailer(buf *bytes.Buffer, meta *responseMeta) error { + metadata := http.Header(httpExtractTrailers(meta.Header)) + streamEnd := &connectStreamEnd{ + Metadata: metadata, } - binary.BigEndian.PutUint32(envBytes[1:], env.length) - return envBytes + connectEncodeStreamEnd(buf, streamEnd) + return nil } -func (c connectStreamClientProtocol) String() string { - return protocolNameConnectStream +func (c connectStreamClientProtocol) EncodeError(buf *bytes.Buffer, meta *responseMeta, err error) { + // Encode the error as uncompressed JSON. + cerr := asConnectError(err) + wireErr := connectErrorToWireError(cerr, c.config.resolver) + metadata := http.Header(httpExtractTrailers(meta.Header)) + streamEnd := &connectStreamEnd{ + Metadata: metadata, + Error: wireErr, + } + connectEncodeStreamEnd(buf, streamEnd) } // connectStreamServerProtocol implements the Connect protocol for // sending streaming RPCs to the server handler. -type connectStreamServerProtocol struct{} +type connectStreamServerProtocol struct { + config *methodConfig +} var _ serverProtocolHandler = connectStreamServerProtocol{} -var _ serverEnvelopedProtocolHandler = connectStreamServerProtocol{} -func (c connectStreamServerProtocol) protocol() Protocol { +func (c connectStreamServerProtocol) Protocol() Protocol { return ProtocolConnect } -func (c connectStreamServerProtocol) addProtocolRequestHeaders(meta requestMeta, headers http.Header) { - headers.Set("Content-Type", "application/connect+"+meta.codec) - if meta.compression != "" { - headers.Set("Connect-Content-Encoding", meta.compression) +func (c connectStreamServerProtocol) EncodeRequestHeader(meta *requestMeta) error { + meta.Header.Set("Content-Type", "application/connect+"+meta.CodecName) + if meta.CompressionName != "" { + meta.Header.Set("Connect-Content-Encoding", meta.CompressionName) + meta.Header.Set("Connect-Accept-Encoding", meta.CompressionName) + } + if meta.Timeout > 0 { + meta.Header.Set("Connect-Timeout-Ms", connectEncodeTimeout(meta.Timeout)) + } + // Resovlve Codecs + codec, err := c.config.GetServerCodec(meta.CodecName) + if err != nil { + return err } - if len(meta.acceptCompression) > 0 { - headers.Set("Connect-Accept-Encoding", strings.Join(meta.acceptCompression, ", ")) + meta.Server.Codec = codec + if meta.CompressionName != "" { + compression, err := c.config.GetServerCompressor(meta.CompressionName) + if err != nil { + return err + } + meta.Server.Compressor = compression } - if meta.hasTimeout { - headers.Set("Connect-Timeout-Ms", connectEncodeTimeout(meta.timeout)) + meta.URL.Path = c.config.methodPath + meta.Method = http.MethodPost + return nil +} + +func (c connectStreamServerProtocol) PrepareRequestMessage(msg *messageBuffer, meta *requestMeta) error { + msg.Dst.IsEnvelope = true + if meta.Server.Compressor != nil { + msg.Dst.IsCompressed = true + msg.Dst.Flags |= flagEnvelopeCompressed } + return nil } -func (c connectStreamServerProtocol) extractProtocolResponseHeaders(statusCode int, headers http.Header) (responseMeta, responseEndUnmarshaller, error) { - var respMeta responseMeta - contentType := headers.Get("Content-Type") +func (c connectStreamServerProtocol) DecodeRequestHeader(meta *responseMeta) error { + contentType := meta.Header.Get("Content-Type") switch { case strings.HasPrefix(contentType, "application/connect+"): - respMeta.codec = strings.TrimPrefix(contentType, "application/connect+") + meta.CodecName = strings.TrimPrefix(contentType, "application/connect+") default: - respMeta.codec = contentType + "?" - } - headers.Del("Content-Type") - respMeta.compression = headers.Get("Connect-Content-Encoding") - headers.Del("Connect-Content-Encoding") - respMeta.acceptCompression = parseMultiHeader(headers.Values("Connect-Accept-Encoding")) - headers.Del("Connect-Accept-Encoding") - - // See if RPC is already over (unexpected HTTP error or trailers-only response) - if statusCode != http.StatusOK { - if respMeta.end == nil { - respMeta.end = &responseEnd{} - } - if respMeta.end.err == nil { - // TODO: map HTTP status code to an RPC error (opposite of httpStatusCodeFromRPC) - respMeta.end.err = connect.NewError(connect.CodeInternal, fmt.Errorf("unexpected HTTP error: %d %s", statusCode, http.StatusText(statusCode))) + meta.CodecName = contentType + "?" + } + meta.Header.Del("Content-Type") + meta.CompressionName = meta.Header.Get("Connect-Content-Encoding") + meta.Header.Del("Connect-Content-Encoding") + meta.AcceptCompression = parseMultiHeader(meta.Header.Values("Connect-Accept-Encoding")) + meta.Header.Del("Connect-Accept-Encoding") + if meta.StatusCode != http.StatusOK { + return protocolError("expected HTTP status OK (200) but got %d", meta.StatusCode) + } + // Relove Codecs + codec, err := c.config.GetServerCodec(meta.CodecName) + if err != nil { + return err + } + meta.Server.Codec = codec + if meta.CompressionName != "" { + compression, err := c.config.GetServerCompressor(meta.CompressionName) + if err != nil { + return err } + meta.Server.Compressor = compression } - return respMeta, nil, nil -} - -func (c connectStreamServerProtocol) extractEndFromTrailers(_ *operation, _ http.Header) (responseEnd, error) { - return responseEnd{}, errors.New("connect streaming protocol does not use HTTP trailers") + return nil } -func (c connectStreamServerProtocol) decodeEnvelope(envBytes envelopeBytes) (envelope, error) { - flags := envBytes[0] - if flags&0b1111_1100 != 0 { - // invalid bits are set - return envelope{}, fmt.Errorf("invalid frame flags: only lowest two bits may be set; instead got %d", flags) +func (c connectStreamServerProtocol) PrepareResponseMessage(msg *messageBuffer, meta *responseMeta) error { + if msg.Buf.Len() < 5 { + return io.ErrShortBuffer // ask for more data } - return envelope{ - compressed: flags&1 != 0, - trailer: flags&2 != 0, - length: binary.BigEndian.Uint32(envBytes[1:]), - }, nil -} - -func (c connectStreamServerProtocol) encodeEnvelope(env envelope) envelopeBytes { - var envBytes envelopeBytes - if env.compressed { - envBytes[0] = 1 + flags, size, err := readEnvelope(msg.Buf) + if err != nil { + return err + } + msg.Src.Size = size + if flags&flagEnvelopeCompressed != 0 { + msg.Src.IsCompressed = true } - binary.BigEndian.PutUint32(envBytes[1:], env.length) - return envBytes + if flags&connectFlagEnvelopeEndStream != 0 { + msg.Src.IsTrailer = true + } + return nil } -func (c connectStreamServerProtocol) decodeEndFromMessage(op *operation, reader io.Reader) (responseEnd, error) { - // TODO: buffer size limit for headers/trailers; should use http.DefaultMaxHeaderBytes if not configured - buffer := op.bufferPool.Get() - defer op.bufferPool.Put(buffer) - _, err := buffer.ReadFrom(reader) - if err != nil { - return responseEnd{}, err +func (c connectStreamServerProtocol) DecodeResponseTrailer(buf *bytes.Buffer, meta *responseMeta) error { + if buf.Len() == 0 { + return protocolError("expected trailer") } + var streamEnd connectStreamEnd - if err := json.Unmarshal(buffer.Bytes(), &streamEnd); err != nil { - return responseEnd{}, err + if err := json.Unmarshal(buf.Bytes(), &streamEnd); err != nil { + return protocolError("failed to unmarshal trailer: %w", err) + } + // Encode connect trailers as HTTP trailers. + for key, values := range streamEnd.Metadata { + meta.Header.Del(key) + key = http.TrailerPrefix + key + for _, value := range values { + meta.Header.Add(key, value) + } } - var cerr *connect.Error if streamEnd.Error != nil { - cerr = streamEnd.Error.toConnectError() + return streamEnd.Error.toConnectError() } - return responseEnd{ - err: cerr, - trailers: streamEnd.Metadata, - }, nil -} - -func (c connectStreamServerProtocol) String() string { - return protocolNameConnectStream + return nil } -func connectExtractUnaryTrailers(headers http.Header) http.Header { +func connectExtractUnaryTrailers(headers httpHeader) http.Header { var count int for k := range headers { if strings.HasPrefix(k, "Trailer-") { @@ -612,27 +695,25 @@ func connectExtractUnaryTrailers(headers http.Header) http.Header { return result } -func connectExtractTimeout(headers http.Header, meta *requestMeta) error { +func connectExtractTimeout(headers httpHeader) (time.Duration, error) { str := headers.Get("Connect-Timeout-Ms") headers.Del("Connect-Timeout-Ms") if str == "" { - return nil + return 0, nil } timeoutInt, err := strconv.ParseInt(str, 10, 64) if err != nil { - return err + return 0, err } if timeoutInt < 0 { - return fmt.Errorf("timeout header indicated invalid negative value: %d", timeoutInt) + return 0, fmt.Errorf("timeout header indicated invalid negative value: %d", timeoutInt) } timeout := time.Millisecond * time.Duration(timeoutInt) if timeout.Milliseconds() != timeoutInt { // overflow timeout = time.Duration(math.MaxInt64) } - meta.timeout = timeout - meta.hasTimeout = true - return nil + return timeout, nil } func connectEncodeTimeout(timeout time.Duration) string { @@ -707,3 +788,13 @@ type connectStreamEnd struct { Error *connectWireError `json:"error,omitempty"` Metadata http.Header `json:"metadata,omitempty"` } + +func connectEncodeStreamEnd(dst *bytes.Buffer, streamEnd *connectStreamEnd) { + dst.Write([]byte{0, 0, 0, 0, 0}) // empty message + if err := json.NewEncoder(dst).Encode(streamEnd); err != nil { + dst.WriteString(`{"error": {"code": "internal", "message": ` + strconv.Quote(err.Error()) + `}}`) + } + dst.Bytes()[0] |= connectFlagEnvelopeEndStream // set trailer flag + size := uint32(dst.Len() - 5) + binary.BigEndian.PutUint32(dst.Bytes()[1:], size) +} diff --git a/protocol_grpc.go b/protocol_grpc.go index 7ae6015..8a690e4 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -15,6 +15,7 @@ package vanguard import ( + "bufio" "bytes" "encoding/binary" "errors" @@ -32,369 +33,486 @@ import ( "google.golang.org/protobuf/types/known/anypb" ) +const ( + // flagEnvelopeCompressed indicates that the data is compressed. It has the + // same meaning in the gRPC-Web, gRPC-HTTP2, and Connect protocols. + flagEnvelopeCompressed = 0b00000001 + + grpcWebFlagEnvelopeTrailer = 0b10000000 +) + // grpcClientProtocol implements the gRPC protocol for // processing RPCs received from the client. -type grpcClientProtocol struct{} +type grpcClientProtocol struct { + config *methodConfig +} var _ clientProtocolHandler = grpcClientProtocol{} -var _ envelopedProtocolHandler = grpcClientProtocol{} -func (g grpcClientProtocol) protocol() Protocol { +func (g grpcClientProtocol) Protocol() Protocol { return ProtocolGRPC } -func (g grpcClientProtocol) acceptsStreamType(_ *operation, _ connect.StreamType) bool { - return true -} - -func (g grpcClientProtocol) extractProtocolRequestHeaders(_ *operation, headers http.Header) (requestMeta, error) { - headers.Del("Te") // no need to propagate "te: trailers" to requests in different protocols - return grpcExtractRequestMeta("application/grpc", "application/grpc+", headers) -} +func (g grpcClientProtocol) DecodeRequestHeader(meta *requestMeta) error { + if meta.Method != http.MethodPost { + return protocolError("invalid method %q", meta.Method) + } + if meta.ProtoMajor != 2 && meta.ProtoMinor != 0 { + return protocolError("invalid HTTP version %q.%q", meta.ProtoMajor, meta.ProtoMinor) + } + meta.Header.Del("Te") // no need to propagate "te: trailers" to requests in different protocols -func (g grpcClientProtocol) addProtocolResponseHeaders(meta responseMeta, headers http.Header) int { - statusCode := grpcAddResponseMeta("application/grpc+", meta, headers) - if len(meta.pendingTrailers) > 0 { - if meta.pendingTrailerKeys == nil { - meta.pendingTrailerKeys = make(headerKeys, len(meta.pendingTrailers)) - } - for k := range meta.pendingTrailers { - meta.pendingTrailerKeys.add(k) + if err := grpcExtractRequestMeta("application/grpc", "application/grpc+", meta); err != nil { + return err + } + // Resolve Codecs + codec, err := g.config.GetClientCodec(meta.CodecName) + if err != nil { + return err + } + meta.Client.Codec = codec + if meta.CompressionName != "" { + comp, err := g.config.GetClientCompressor(meta.CompressionName) + if err != nil { + return err } + meta.Client.Compressor = comp } - for k := range meta.pendingTrailerKeys { - headers.Add("Trailer", textproto.CanonicalMIMEHeaderKey(k)) + return nil +} + +func (g grpcClientProtocol) PrepareRequestMessage(msg *messageBuffer, _ *requestMeta) error { + if msg.Buf.Len() < 5 { + return io.ErrShortBuffer // ask for more data } - if !meta.pendingTrailerKeys.contains("Grpc-Status") { - headers.Add("Trailer", "Grpc-Status") + flags, size, err := readEnvelope(msg.Buf) + if err != nil { + return err } - if !meta.pendingTrailerKeys.contains("Grpc-Message") { - headers.Add("Trailer", "Grpc-Message") + msg.Src.Size = size + if flags&1 != 0 { + msg.Src.IsCompressed = true } - return statusCode + return nil } -func (g grpcClientProtocol) encodeEnd(_ *operation, end *responseEnd, _ io.Writer, wasInHeaders bool) http.Header { - if wasInHeaders { - // already recorded this in call to addProtocolResponseHeaders - return nil +func (g grpcClientProtocol) EncodeRequestHeader(meta *responseMeta) error { + grpcEncodeResponseHeader("application/grpc+", meta) + meta.Header.Set("Trailer", "Grpc-Status, Grpc-Message, Grpc-Status-Details-Bin") + + codec, err := g.config.GetClientCodec(meta.CodecName) + if err != nil { + return err + } + meta.Client.Codec = codec + if meta.CompressionName != "" { + comp, err := g.config.GetClientCompressor(meta.CompressionName) + if err != nil { + return err + } + meta.Client.Compressor = comp } - trailers := make(http.Header, len(end.trailers)+3) - grpcWriteEndToTrailers(end, trailers) - return trailers + return nil } -func (g grpcClientProtocol) decodeEnvelope(bytes envelopeBytes) (envelope, error) { - return grpcServerProtocol{}.decodeEnvelope(bytes) +func (g grpcClientProtocol) PrepareResponseMessage(msg *messageBuffer, meta *responseMeta) error { + msg.Dst.IsEnvelope = true + msg.Dst.IsCompressed = meta.Client.Compressor != nil + if msg.Dst.IsCompressed { + msg.Dst.IsCompressed = true + msg.Dst.Flags |= flagEnvelopeCompressed + } + return nil } -func (g grpcClientProtocol) encodeEnvelope(env envelope) envelopeBytes { - return grpcServerProtocol{}.encodeEnvelope(env) +func (g grpcClientProtocol) EncodeResponseTrailer(_ *bytes.Buffer, meta *responseMeta) error { + meta.Header.Set("Grpc-Status", "0") + meta.Header.Set("Grpc-Message", "") + return nil } -func (g grpcClientProtocol) String() string { - return protocolNameGRPC +func (g grpcClientProtocol) EncodeError(_ *bytes.Buffer, meta *responseMeta, err error) { + cerr := asConnectError(err) + grpcEncodeError(cerr, meta.Header) + meta.StatusCode = http.StatusOK // gRPC errors are always HTTP 200 } // grpcServerProtocol implements the gRPC protocol for // sending RPCs to the server handler. -type grpcServerProtocol struct{} +type grpcServerProtocol struct { + config *methodConfig +} var _ serverProtocolHandler = grpcServerProtocol{} -var _ serverEnvelopedProtocolHandler = grpcServerProtocol{} -func (g grpcServerProtocol) protocol() Protocol { +func (g grpcServerProtocol) Protocol() Protocol { return ProtocolGRPC } -func (g grpcServerProtocol) addProtocolRequestHeaders(meta requestMeta, headers http.Header) { - grpcAddRequestMeta("application/grpc+", meta, headers) - headers.Set("Te", "trailers") -} - -func (g grpcServerProtocol) extractProtocolResponseHeaders(statusCode int, headers http.Header) (responseMeta, responseEndUnmarshaller, error) { - return grpcExtractResponseMeta("application/grpc", "application/grpc+", statusCode, headers), nil, nil -} - -func (g grpcServerProtocol) extractEndFromTrailers(_ *operation, trailers http.Header) (responseEnd, error) { - return responseEnd{ - err: grpcExtractErrorFromTrailer(trailers), - trailers: trailers, - }, nil +func (g grpcServerProtocol) EncodeRequestHeader(meta *requestMeta) error { + codec, err := g.config.GetServerCodec(meta.CodecName) + if err != nil { + return err + } + meta.Server.Codec = codec + if meta.CompressionName != "" { + compressor, err := g.config.GetServerCompressor(meta.CompressionName) + if err != nil { + return err + } + meta.Server.Compressor = compressor + } + grpcEncodeRequestHeader("application/grpc+", meta) + meta.URL.Path = g.config.methodPath + meta.Header.Set("Te", "trailers") + meta.Method = http.MethodPost + meta.ProtoMajor = 2 + meta.ProtoMinor = 0 + return nil } -func (g grpcServerProtocol) decodeEnvelope(envBytes envelopeBytes) (envelope, error) { - flags := envBytes[0] - if flags != 0 && flags != 1 { - return envelope{}, fmt.Errorf("invalid compression flag: must be 0 or 1; instead got %d", flags) +func (g grpcServerProtocol) PrepareRequestMessage(msg *messageBuffer, meta *requestMeta) error { + msg.Dst.IsEnvelope = true + if meta.Server.Compressor != nil { + msg.Dst.IsCompressed = true + msg.Dst.Flags |= 1 } - return envelope{ - compressed: flags == 1, - length: binary.BigEndian.Uint32(envBytes[1:]), - }, nil + return nil } -func (g grpcServerProtocol) encodeEnvelope(env envelope) envelopeBytes { - var envBytes envelopeBytes - if env.compressed { - envBytes[0] = 1 +func (g grpcServerProtocol) DecodeRequestHeader(meta *responseMeta) error { + if err := grpcExtractResponseMeta("application/grpc", "application/grpc+", meta); err != nil { + return err } - binary.BigEndian.PutUint32(envBytes[1:], env.length) - return envBytes + codec, err := g.config.GetServerCodec(meta.CodecName) + if err != nil { + return err + } + meta.Server.Codec = codec + if meta.CompressionName != "" { + compressor, err := g.config.GetServerCompressor(meta.CompressionName) + if err != nil { + return err + } + meta.Server.Compressor = compressor + } + return nil } -func (g grpcServerProtocol) decodeEndFromMessage(_ *operation, _ io.Reader) (responseEnd, error) { - return responseEnd{}, errors.New("gRPC protocol does not allow embedding result/trailers in body") +func (g grpcServerProtocol) PrepareResponseMessage(msg *messageBuffer, meta *responseMeta) error { + if msg.Buf.Len() < 5 { + return io.ErrShortBuffer // ask for more data + } + flags, size, err := readEnvelope(msg.Buf) + if err != nil { + return err + } + msg.Src.Size = size + isCompressed := flags&1 != 0 + if isCompressed { + if meta.Server.Compressor == nil { + return protocolError("server sent compressed message but client did not request compression") + } + msg.Src.IsCompressed = true + } + return nil } -func (g grpcServerProtocol) String() string { - return protocolNameGRPC +func (g grpcServerProtocol) DecodeResponseTrailer(_ *bytes.Buffer, meta *responseMeta) error { + if err := grpcExtractErrorFromTrailer(meta.Header); err != nil { + return err // Connect error + } + return nil } // grpcClientProtocol implements the gRPC protocol for // processing RPCs received from the client. -type grpcWebClientProtocol struct{} +type grpcWebClientProtocol struct { + config *methodConfig +} var _ clientProtocolHandler = grpcWebClientProtocol{} -var _ envelopedProtocolHandler = grpcWebClientProtocol{} -func (g grpcWebClientProtocol) protocol() Protocol { +func (g grpcWebClientProtocol) Protocol() Protocol { return ProtocolGRPCWeb } -func (g grpcWebClientProtocol) acceptsStreamType(_ *operation, _ connect.StreamType) bool { - return true +func (g grpcWebClientProtocol) DecodeRequestHeader(meta *requestMeta) error { + if meta.Method != http.MethodPost { + return protocolError("invalid method %q", meta.Method) + } + if err := grpcExtractRequestMeta("application/grpc-web", "application/grpc-web+", meta); err != nil { + return err + } + // Resolve Codecs + codec, err := g.config.GetClientCodec(meta.CodecName) + if err != nil { + return err + } + meta.Client.Codec = codec + if meta.CompressionName != "" { + compressor, err := g.config.GetClientCompressor(meta.CompressionName) + if err != nil { + return err + } + meta.Client.Compressor = compressor + } + return nil } -func (g grpcWebClientProtocol) extractProtocolRequestHeaders(_ *operation, headers http.Header) (requestMeta, error) { - return grpcExtractRequestMeta("application/grpc-web", "application/grpc-web+", headers) +func (g grpcWebClientProtocol) PrepareRequestMessage(msg *messageBuffer, meta *requestMeta) error { + if msg.Buf.Len() < 5 { + return io.ErrShortBuffer // ask for more data + } + flags, size, err := readEnvelope(msg.Buf) + if err != nil { + return err + } + msg.Src.Size = size + if flags&1 != 0 { + msg.Src.IsCompressed = true + if meta.Client.Compressor == nil { + return protocolError("server sent compressed message but client did not request compression") + } + } + return nil } -func (g grpcWebClientProtocol) addProtocolResponseHeaders(meta responseMeta, headers http.Header) int { - return grpcAddResponseMeta("application/grpc-web+", meta, headers) -} +func (g grpcWebClientProtocol) EncodeRequestHeader(meta *responseMeta) error { + grpcEncodeResponseHeader("application/grpc-web+", meta) -func (g grpcWebClientProtocol) encodeEnd(op *operation, end *responseEnd, writer io.Writer, wasInHeaders bool) http.Header { - if wasInHeaders { - // already recorded this in call to addProtocolResponseHeaders - return nil + codec, err := g.config.GetClientCodec(meta.CodecName) + if err != nil { + return err + } + meta.Client.Codec = codec + if meta.CompressionName != "" { + comp, err := g.config.GetClientCompressor(meta.CompressionName) + if err != nil { + return err + } + meta.Client.Compressor = comp } - trailers := make(http.Header, len(end.trailers)+3) - grpcWriteEndToTrailers(end, trailers) - buffer := op.bufferPool.Get() - defer op.bufferPool.Put(buffer) - _ = trailers.Write(buffer) - // TODO: compress? - env := envelope{trailer: true, length: uint32(buffer.Len())} - envBytes := g.encodeEnvelope(env) - _, _ = writer.Write(envBytes[:]) - _, _ = buffer.WriteTo(writer) return nil } -func (g grpcWebClientProtocol) decodeEnvelope(bytes envelopeBytes) (envelope, error) { - return grpcServerProtocol{}.decodeEnvelope(bytes) +func (g grpcWebClientProtocol) PrepareResponseMessage(msg *messageBuffer, meta *responseMeta) error { + msg.Dst.IsEnvelope = true + if meta.Client.Compressor != nil { + msg.Dst.Flags |= 1 + msg.Dst.IsCompressed = true + } + return nil } -func (g grpcWebClientProtocol) encodeEnvelope(env envelope) envelopeBytes { - var envBytes envelopeBytes - if env.compressed { - envBytes[0] = 1 - } - if env.trailer { - envBytes[0] |= 0x80 - } - binary.BigEndian.PutUint32(envBytes[1:], env.length) - return envBytes +func (g grpcWebClientProtocol) EncodeResponseTrailer(buf *bytes.Buffer, meta *responseMeta) error { + trailer := httpExtractTrailers(meta.Header) + trailer["grpc-status"] = []string{"0"} + trailer["grpc-message"] = []string{""} + grpcWebEncodeTrailers(buf, trailer) + return nil } -func (g grpcWebClientProtocol) String() string { - return protocolNameGRPCWeb +func (g grpcWebClientProtocol) EncodeError(buf *bytes.Buffer, meta *responseMeta, err error) { + cerr := asConnectError(err) + if !meta.WroteStatus { + grpcEncodeError(cerr, meta.Header) + meta.StatusCode = http.StatusOK // gRPC errors are always HTTP 200 + } else { + trailer := httpExtractTrailers(meta.Header) + grpcEncodeError(cerr, trailer) + grpcWebEncodeTrailers(buf, trailer) + } } // grpcServerProtocol implements the gRPC-Web protocol for // sending RPCs to the server handler. -type grpcWebServerProtocol struct{} +type grpcWebServerProtocol struct { + config *methodConfig +} var _ serverProtocolHandler = grpcWebServerProtocol{} -var _ serverEnvelopedProtocolHandler = grpcWebServerProtocol{} -func (g grpcWebServerProtocol) protocol() Protocol { +func (g grpcWebServerProtocol) Protocol() Protocol { return ProtocolGRPCWeb } -func (g grpcWebServerProtocol) addProtocolRequestHeaders(meta requestMeta, headers http.Header) { - grpcAddRequestMeta("application/grpc-web+", meta, headers) -} - -func (g grpcWebServerProtocol) extractProtocolResponseHeaders(statusCode int, headers http.Header) (responseMeta, responseEndUnmarshaller, error) { - return grpcExtractResponseMeta("application/grpc-web", "application/grpc-web+", statusCode, headers), nil, nil -} - -func (g grpcWebServerProtocol) extractEndFromTrailers(_ *operation, _ http.Header) (responseEnd, error) { - return responseEnd{}, errors.New("gRPC-Web protocol does not use HTTP trailers") +func (g grpcWebServerProtocol) EncodeRequestHeader(meta *requestMeta) error { + codec, err := g.config.GetServerCodec(meta.CodecName) + if err != nil { + return err + } + meta.Server.Codec = codec + if meta.CompressionName != "" { + comp, err := g.config.GetServerCompressor(meta.CompressionName) + if err != nil { + return err + } + meta.Server.Compressor = comp + } + grpcEncodeRequestHeader("application/grpc-web+", meta) + meta.URL.Path = g.config.methodPath + meta.Method = http.MethodPost + return nil } -func (g grpcWebServerProtocol) decodeEnvelope(envBytes envelopeBytes) (envelope, error) { - flags := envBytes[0] - if flags&0b0111_1110 != 0 { - // invalid bits are set - return envelope{}, fmt.Errorf("invalid frame flags: only highest and lowest bits may be set; instead got %d", flags) +func (g grpcWebServerProtocol) PrepareRequestMessage(msg *messageBuffer, meta *requestMeta) error { + msg.Dst.IsEnvelope = true + if meta.Server.Compressor != nil { + msg.Dst.Flags |= 1 + msg.Dst.IsCompressed = true } - return envelope{ - compressed: flags&1 != 0, - trailer: flags&0x80 != 0, - length: binary.BigEndian.Uint32(envBytes[1:]), - }, nil + return nil } -func (g grpcWebServerProtocol) encodeEnvelope(env envelope) envelopeBytes { - // Request streams don't have trailers, so we can re-use the gRPC implementation - // without worrying about gRPC-Web's in-body trailers. - return grpcServerProtocol{}.encodeEnvelope(env) +func (g grpcWebServerProtocol) DecodeRequestHeader(meta *responseMeta) error { + if err := grpcExtractResponseMeta("application/grpc-web", "application/grpc-web+", meta); err != nil { + return err + } + codec, err := g.config.GetServerCodec(meta.CodecName) + if err != nil { + return err + } + meta.Server.Codec = codec + if meta.CompressionName != "" { + compressor, err := g.config.GetServerCompressor(meta.CompressionName) + if err != nil { + return err + } + meta.Server.Compressor = compressor + } + return err } -func (g grpcWebServerProtocol) decodeEndFromMessage(op *operation, reader io.Reader) (responseEnd, error) { - // TODO: buffer size limit for headers/trailers; should use http.DefaultMaxHeaderBytes if not configured - buffer := op.bufferPool.Get() - defer op.bufferPool.Put(buffer) - _, err := buffer.ReadFrom(reader) +func (g grpcWebServerProtocol) PrepareResponseMessage(msg *messageBuffer, meta *responseMeta) error { + if msg.Buf.Len() < 5 { + return io.ErrShortBuffer // ask for more data + } + flags, size, err := readEnvelope(msg.Buf) if err != nil { - return responseEnd{}, err + return err } - headerLines := bytes.Split(buffer.Bytes(), []byte{'\r', '\n'}) - trailers := make(http.Header, len(headerLines)) - for i, headerLine := range headerLines { - // may have trailing newline, so ignore resulting trailing empty line - if len(headerLine) == 0 { - continue + msg.Src.Size = size + if flags&1 != 0 { + msg.Src.IsCompressed = true + if meta.Server.Compressor == nil { + return protocolError("server sent compressed message but client did not request compression") } - pos := bytes.IndexByte(headerLine, ':') - if pos == -1 { - return responseEnd{}, fmt.Errorf("response body included malformed trailer at line %d", i+1) - } - trailers.Add(string(headerLine[:pos]), strings.TrimSpace(string(headerLine[pos+1:]))) } - return responseEnd{ - err: grpcExtractErrorFromTrailer(trailers), - trailers: trailers, - }, nil + if flags&0b0111_1110 != 0 { + return protocolError("invalid frame flags: only highest and lowest bits may be set; instead got %d", flags) + } + msg.Src.IsTrailer = flags&grpcWebFlagEnvelopeTrailer != 0 + return nil } -func (g grpcWebServerProtocol) String() string { - return protocolNameGRPCWeb +func (g grpcWebServerProtocol) DecodeResponseTrailer(buf *bytes.Buffer, meta *responseMeta) error { + // Check for grpc-Web trailers in headers, otherwise look for trailers in + // the message. + trailer := meta.Header + if meta.Header.Get("Grpc-Status") == "" { + // Per the gRPC-Web specification, trailers should be encoded as an HTTP/1 + // headers block _without_ the terminating newline. To make the headers + // parseable by net/textproto, we need to add the newline. + buf.WriteByte('\n') + bufferedReader := bufio.NewReader(buf) + mimeReader := textproto.NewReader(bufferedReader) + mimeHeader, mimeErr := mimeReader.ReadMIMEHeader() + if mimeErr != nil { + return protocolError("invalid trailer: %w", mimeErr) + } + mimeHeader.Del("Content-Type") + for key, vals := range mimeHeader { + if !strings.HasPrefix(key, "Grpc-") { + key = http.TrailerPrefix + key + } + trailer[key] = vals + } + } + if err := grpcExtractErrorFromTrailer(trailer); err != nil { + return err + } + return nil } -func grpcExtractRequestMeta(contentTypeShort, contentTypePrefix string, headers http.Header) (requestMeta, error) { - var reqMeta requestMeta - if err := grpcExtractTimeoutFromHeaders(headers, &reqMeta); err != nil { - return reqMeta, err +func grpcExtractRequestMeta(contentTypeShort, contentTypePrefix string, meta *requestMeta) error { + if err := grpcExtractTimeoutFromHeaders(meta.Header, meta); err != nil { + return err } - contentType := headers.Get("Content-Type") + contentType := meta.Header.Get("Content-Type") if contentType == contentTypeShort { - reqMeta.codec = CodecProto + meta.CodecName = CodecProto } else { - reqMeta.codec = strings.TrimPrefix(contentType, contentTypePrefix) + meta.CodecName = strings.TrimPrefix(contentType, contentTypePrefix) } - headers.Del("Content-Type") - reqMeta.compression = headers.Get("Grpc-Encoding") - headers.Del("Grpc-Encoding") - reqMeta.acceptCompression = parseMultiHeader(headers.Values("Grpc-Accept-Encoding")) - headers.Del("Grpc-Accept-Encoding") - return reqMeta, nil + meta.Header.Del("Content-Type") + meta.CompressionName = meta.Header.Get("Grpc-Encoding") + meta.Header.Del("Grpc-Encoding") + meta.AcceptCompression = parseMultiHeader(meta.Header.Values("Grpc-Accept-Encoding")) + meta.Header.Del("Grpc-Accept-Encoding") + return nil } -func grpcExtractResponseMeta(contentTypeShort, contentTypePrefix string, statusCode int, headers http.Header) responseMeta { - var respMeta responseMeta - contentType := headers.Get("Content-Type") +func grpcExtractResponseMeta(contentTypeShort, contentTypePrefix string, meta *responseMeta) error { + contentType := meta.Header.Get("Content-Type") switch { case contentType == contentTypeShort: - respMeta.codec = CodecProto + meta.CodecName = CodecProto case strings.HasPrefix(contentType, contentTypePrefix): - respMeta.codec = strings.TrimPrefix(contentType, contentTypePrefix) + meta.CodecName = strings.TrimPrefix(contentType, contentTypePrefix) default: - respMeta.codec = contentType + "?" + meta.CodecName = contentType + "?" } - headers.Del("Content-Type") - respMeta.compression = headers.Get("Grpc-Encoding") - headers.Del("Grpc-Encoding") - respMeta.acceptCompression = parseMultiHeader(headers.Values("Grpc-Accept-Encoding")) - headers.Del("Grpc-Accept-Encoding") + meta.Header.Del("Content-Type") + meta.CompressionName = meta.Header.Get("Grpc-Encoding") + meta.Header.Del("Grpc-Encoding") + meta.AcceptCompression = parseMultiHeader(meta.Header.Values("Grpc-Accept-Encoding")) + meta.Header.Del("Grpc-Accept-Encoding") // See if RPC is already over (unexpected HTTP error or trailers-only response) - if len(headers.Values("Grpc-Status")) > 0 { - connErr := grpcExtractErrorFromTrailer(headers) - respMeta.end = &responseEnd{ - err: connErr, - httpCode: statusCode, - } - headers.Del("Grpc-Status") - headers.Del("Grpc-Message") - headers.Del("Grpc-Status-Details-Bin") - if contentType == "" { - // no need to report "?" codec if no content-type on a trailers-only response - respMeta.codec = "" + if len(meta.Header.Values("Grpc-Status")) > 0 { + if err := grpcExtractErrorFromTrailer(meta.Header); err != nil { + return err } } - if statusCode != http.StatusOK { - if respMeta.end == nil { - respMeta.end = &responseEnd{} - } - if respMeta.end.err == nil { - // TODO: map HTTP status code to an RPC error (opposite of httpStatusCodeFromRPC) - respMeta.end.err = connect.NewError(connect.CodeInternal, fmt.Errorf("unexpected HTTP error: %d %s", statusCode, http.StatusText(statusCode))) - } + if meta.StatusCode != http.StatusOK { + return connect.NewError(connect.CodeUnknown, + fmt.Errorf("unexpected HTTP error: %d %s", + meta.StatusCode, http.StatusText(meta.StatusCode))) } - return respMeta + return nil } -func grpcAddRequestMeta(contentTypePrefix string, meta requestMeta, headers http.Header) { - headers.Set("Content-Type", contentTypePrefix+meta.codec) - if meta.compression != "" { - headers.Set("Grpc-Encoding", meta.compression) - } - if len(meta.acceptCompression) > 0 { - headers.Set("Grpc-Accept-Encoding", strings.Join(meta.acceptCompression, ", ")) +func grpcEncodeRequestHeader(contentTypePrefix string, meta *requestMeta) { + meta.Header.Set("Content-Type", contentTypePrefix+meta.CodecName) + if meta.Server.Compressor != nil { + meta.Header.Set("Grpc-Encoding", meta.CompressionName) + meta.Header.Set("Grpc-Accept-Encoding", meta.CompressionName) } - if meta.hasTimeout { - timeoutStr := grpcEncodeTimeout(meta.timeout) - headers.Set("Grpc-Timeout", timeoutStr) + if meta.Timeout > 0 { + timeoutStr := grpcEncodeTimeout(meta.Timeout) + meta.Header.Set("Grpc-Timeout", timeoutStr) } } -func grpcAddResponseMeta(contentTypePrefix string, meta responseMeta, headers http.Header) int { - if meta.end != nil { - grpcWriteEndToTrailers(meta.end, headers) - return http.StatusOK +func grpcEncodeResponseHeader(contentTypePrefix string, meta *responseMeta) { + meta.Header.Set("Content-Type", contentTypePrefix+meta.CodecName) + if meta.CompressionName != "" { + meta.Header.Set("Grpc-Encoding", meta.CompressionName) + meta.Header.Set("Grpc-Accept-Encoding", meta.CompressionName) } - headers.Set("Content-Type", contentTypePrefix+meta.codec) - if meta.compression != "" { - headers.Set("Grpc-Encoding", meta.compression) - } - if len(meta.acceptCompression) > 0 { - headers.Set("Grpc-Accept-Encoding", strings.Join(meta.acceptCompression, ", ")) - } - return http.StatusOK } -func grpcWriteEndToTrailers(respEnd *responseEnd, trailers http.Header) { - for k, v := range respEnd.trailers { - trailers[k] = v +func grpcEncodeError(cerr *connect.Error, trailers httpHeader) { + trailers.Set("Grpc-Status", strconv.Itoa(int(cerr.Code()))) + trailers.Set("Grpc-Message", grpcPercentEncode(cerr.Message())) + if len(cerr.Details()) == 0 { + return } - if respEnd.err == nil { - trailers.Set("Grpc-Status", "0") - trailers.Set("Grpc-Message", "") - } else { - trailers.Set("Grpc-Status", strconv.Itoa(int(respEnd.err.Code()))) - trailers.Set("Grpc-Message", grpcPercentEncode(respEnd.err.Message())) - if len(respEnd.err.Details()) == 0 { - return - } - stat := grpcStatusFromError(respEnd.err) - bin, err := proto.Marshal(stat) - if err == nil { - trailers.Set("Grpc-Status-Details-Bin", connect.EncodeBinaryHeader(bin)) - } + stat := grpcStatusFromError(cerr) + bin, err := proto.Marshal(stat) + if err == nil { + trailers.Set("Grpc-Status-Details-Bin", connect.EncodeBinaryHeader(bin)) } } @@ -495,7 +613,7 @@ func grpcShouldEscape(char byte) bool { // binary Protobuf format, even if the messages in the request/response stream // use a different codec. Consequently, this function needs a Protobuf codec to // unmarshal error information in the headers. -func grpcExtractErrorFromTrailer(trailers http.Header) *connect.Error { +func grpcExtractErrorFromTrailer(trailers httpHeader) *connect.Error { grpcStatus := trailers.Get("Grpc-Status") grpcMsg := trailers.Get("Grpc-Message") grpcDetails := trailers.Get("Grpc-Status-Details-Bin") @@ -563,7 +681,7 @@ func grpcExtractErrorFromTrailer(trailers http.Header) *connect.Error { return trailerErr } -func grpcExtractTimeoutFromHeaders(headers http.Header, meta *requestMeta) error { +func grpcExtractTimeoutFromHeaders(headers httpHeader, meta *requestMeta) error { timeoutStr := headers.Get("Grpc-Timeout") headers.Del("Grpc-Timeout") if timeoutStr == "" { @@ -573,8 +691,7 @@ func grpcExtractTimeoutFromHeaders(headers http.Header, meta *requestMeta) error if err != nil { return err } - meta.timeout = timeout - meta.hasTimeout = true + meta.Timeout = timeout return nil } @@ -647,3 +764,19 @@ func grpcTimeoutUnitLookup(unit byte) time.Duration { return 0 } } + +func grpcWebEncodeTrailers(dst *bytes.Buffer, trailer httpHeader) { + for key, values := range trailer { + lowerKey := strings.ToLower(key) + if lowerKey == key { + continue + } + delete(trailer, key) + trailer[lowerKey] = values + } + dst.Write([]byte{0, 0, 0, 0, 0}) // empty message + _ = http.Header(trailer).Write(dst) + dst.Bytes()[0] |= grpcWebFlagEnvelopeTrailer // set trailer flag + size := uint32(dst.Len() - 5) + binary.BigEndian.PutUint32(dst.Bytes()[1:], size) +} diff --git a/protocol_grpc_test.go b/protocol_grpc_test.go index bed9d4d..a31f8a3 100644 --- a/protocol_grpc_test.go +++ b/protocol_grpc_test.go @@ -36,7 +36,7 @@ func TestGRPCErrorWriter(t *testing.T) { err := fmt.Errorf("test error: %s", "Hello, 世界") cerr := connect.NewWireError(connect.CodeUnauthenticated, err) rec := httptest.NewRecorder() - grpcWriteEndToTrailers(&responseEnd{err: cerr}, rec.Header()) + grpcEncodeError(cerr, httpHeader(rec.Header())) assert.Equal(t, "16", rec.Header().Get("Grpc-Status")) assert.Equal(t, "test error: Hello, %E4%B8%96%E7%95%8C", rec.Header().Get("Grpc-Message")) @@ -44,7 +44,7 @@ func TestGRPCErrorWriter(t *testing.T) { assert.Equal(t, "", rec.Header().Get("Grpc-Status-Details-Bin")) assert.Len(t, rec.Body.Bytes(), 0) - got := grpcExtractErrorFromTrailer(rec.Header()) + got := grpcExtractErrorFromTrailer(httpHeader(rec.Header())) assert.Equal(t, cerr, got) // Now again, but this time an error with details @@ -52,14 +52,14 @@ func TestGRPCErrorWriter(t *testing.T) { require.NoError(t, err) cerr.AddDetail(errDetail) rec = httptest.NewRecorder() - grpcWriteEndToTrailers(&responseEnd{err: cerr}, rec.Header()) + grpcEncodeError(cerr, httpHeader(rec.Header())) assert.Equal(t, "16", rec.Header().Get("Grpc-Status")) assert.Equal(t, "test error: Hello, %E4%B8%96%E7%95%8C", rec.Header().Get("Grpc-Message")) assert.Equal(t, "CBASGXRlc3QgZXJyb3I6IEhlbGxvLCDkuJbnlYwaOAovdHlwZS5nb29nbGVhcGlzLmNvbS9nb29nbGUucHJvdG9idWYuU3RyaW5nVmFsdWUSBQoDZm9v", rec.Header().Get("Grpc-Status-Details-Bin")) assert.Len(t, rec.Body.Bytes(), 0) - got = grpcExtractErrorFromTrailer(rec.Header()) + got = grpcExtractErrorFromTrailer(httpHeader(rec.Header())) assert.Equal(t, cerr, got) } diff --git a/protocol_http.go b/protocol_http.go index 4d2d639..ce836d7 100644 --- a/protocol_http.go +++ b/protocol_http.go @@ -290,35 +290,51 @@ func httpEncodePathValues(input protoreflect.Message, target *routeTarget) ( return path, query, nil } -func httpExtractTrailers(headers http.Header, knownTrailerKeys headerKeys) http.Header { - trailers := make(http.Header, len(knownTrailerKeys)) - for key, vals := range headers { - if strings.HasPrefix(key, http.TrailerPrefix) { - trailers[strings.TrimPrefix(key, http.TrailerPrefix)] = vals - delete(headers, key) - continue - } - if _, expected := knownTrailerKeys[key]; expected { - trailers[key] = vals - delete(headers, key) - continue - } +type httpHeader map[string][]string + +func (h httpHeader) Get(key string) string { + if vals := h[key]; len(vals) > 0 { + return vals[0] + } else if vals := h[http.TrailerPrefix+key]; len(vals) > 0 { + return vals[0] } - return trailers + return "" +} +func (h httpHeader) Set(key, value string) { + h[key] = []string{value} +} +func (h httpHeader) Add(key, value string) { + h[key] = append(h[key], value) +} +func (h httpHeader) Del(key string) { + delete(h, key) + delete(h, http.TrailerPrefix+key) +} +func (h httpHeader) Values(key string) []string { + return h[key] } -func httpMergeTrailers(header http.Header, trailer http.Header) { - for key, vals := range trailer { - if !strings.HasPrefix(key, http.TrailerPrefix) { - key = http.TrailerPrefix + key - } - for _, val := range vals { - header.Add(key, val) +func httpExtractTrailers(header httpHeader) httpHeader { + // Parse the trailer keys from the Trailer header. + keys := parseMultiHeader(header["Trailer"]) + trailerKeys := make(map[string]struct{}, len(keys)) + for _, key := range keys { + trailerKeys[key] = struct{}{} + } + trailer := make(httpHeader, len(trailerKeys)) + for key, vals := range header { + if strings.HasPrefix(key, http.TrailerPrefix) { + key = strings.TrimPrefix(key, http.TrailerPrefix) + } else if _, ok := trailerKeys[key]; !ok { + continue } + trailer[key] = vals } + return trailer } -func httpExtractContentLength(headers http.Header) (int, error) { +// TODO: fix content-length. +func httpExtractContentLength(headers http.Header) (int, error) { //nolint:deadcode,unused contentLenStr := headers.Get("Content-Length") if contentLenStr == "" { return -1, nil diff --git a/protocol_rest.go b/protocol_rest.go index c1864a1..522c314 100644 --- a/protocol_rest.go +++ b/protocol_rest.go @@ -17,8 +17,7 @@ package vanguard import ( "bytes" "fmt" - "io" - "net/http" + "net/url" "strconv" "strings" "time" @@ -28,385 +27,493 @@ import ( "google.golang.org/protobuf/reflect/protoreflect" ) -type restClientProtocol struct{} +type restClientProtocol struct { + target *routeTarget + vars []routeTargetVarMatch +} var _ clientProtocolHandler = restClientProtocol{} -var _ clientBodyPreparer = restClientProtocol{} -var _ clientProtocolEndMustBeInHeaders = restClientProtocol{} // restClientProtocol implements the REST protocol for // processing RPCs received from the client. -func (r restClientProtocol) protocol() Protocol { +func (r restClientProtocol) Protocol() Protocol { return ProtocolREST } -func (r restClientProtocol) acceptsStreamType(op *operation, streamType connect.StreamType) bool { - switch streamType { - case connect.StreamTypeUnary: - return true - case connect.StreamTypeClient: - return restHTTPBodyRequest(op) - case connect.StreamTypeServer: - // TODO: support server streams even when body is not google.api.HttpBody - return restHTTPBodyResponse(op) - default: - return false +func (r restClientProtocol) DecodeRequestHeader(meta *requestMeta) error { + if !r.acceptsStreamType() { + return protocolError("REST client protocol does not support %s stream", r.target.config.streamType) } -} - -func (r restClientProtocol) endMustBeInHeaders() bool { - // TODO: when we support server streams over REST, this should return false when streaming - return true -} - -func (r restClientProtocol) extractProtocolRequestHeaders(op *operation, headers http.Header) (requestMeta, error) { - var reqMeta requestMeta - reqMeta.compression = headers.Get("Content-Encoding") - headers.Del("Content-Encoding") + meta.CompressionName = meta.Header.Get("Content-Encoding") + meta.Header.Del("Content-Encoding") // TODO: A REST client could use "q" weights in the `Accept-Encoding` header, which // would currently cause the middleware to not recognize the compression. // We may want to address this. We'd need to sort the values by their weight // since other protocols don't allow weights with acceptable encodings. - reqMeta.acceptCompression = parseMultiHeader(headers.Values("Accept-Encoding")) - headers.Del("Accept-Encoding") + meta.AcceptCompression = parseMultiHeader(meta.Header.Values("Accept-Encoding")) + meta.Header.Del("Accept-Encoding") - reqMeta.codec = CodecJSON // if actually a custom content-type, handled by body preparer methods - contentType := headers.Get("Content-Type") + meta.CodecName = CodecJSON // if actually a custom content-type, handled by body preparer methods + contentType := meta.Header.Get("Content-Type") if contentType != "" && contentType != "application/json" && contentType != "application/json; charset=utf-8" && - !restHTTPBodyRequest(op) { - // invalid content-type - reqMeta.codec = contentType + "?" + !restIsHTTPBody(r.target.config.descriptor.Input(), r.target.requestBodyFields) { + meta.CodecName = contentType + "?" // invalid } - headers.Del("Content-Type") + meta.Header.Del("Content-Type") - if timeoutStr := headers.Get("X-Server-Timeout"); timeoutStr != "" { + if timeoutStr := meta.Header.Get("X-Server-Timeout"); timeoutStr != "" { timeout, err := strconv.ParseFloat(timeoutStr, 64) if err != nil { - return requestMeta{}, err + return err } - reqMeta.timeout = time.Duration(timeout * float64(time.Second)) + meta.Timeout = time.Duration(timeout * float64(time.Second)) } - return reqMeta, nil -} -func (r restClientProtocol) addProtocolResponseHeaders(meta responseMeta, headers http.Header) int { - isErr := meta.end != nil && meta.end.err != nil - // TODO: this formulation might only be valid when meta.codec is JSON; support other codecs. - // Headers are only set if they are not already set, specially to allow - // for google.api.HttpBody payloads. - if headers["Content-Type"] == nil { - headers["Content-Type"] = []string{"application/" + meta.codec} - } - // TODO: Content-Encoding to compress error, too? - if !isErr && meta.compression != "" { - headers["Content-Encoding"] = []string{meta.compression} + // Resolve codecs + codec, err := r.target.config.GetClientCodec(meta.CodecName) + if err != nil { + return err } - if len(meta.acceptCompression) != 0 { - headers["Accept-Encoding"] = []string{strings.Join(meta.acceptCompression, ", ")} + urlClone := *meta.URL + meta.Client.Codec = &restRequestCodec{ + target: r.target, + vars: r.vars, + url: &urlClone, + codec: codec, + contentType: contentType, } - if isErr { - return httpStatusCodeFromRPC(meta.end.err.Code()) + if meta.CompressionName != "" { + comp, err := r.target.config.GetClientCompressor(meta.CompressionName) + if err != nil { + return err + } + meta.Client.Compressor = comp } - return http.StatusOK + return nil } -func (r restClientProtocol) encodeEnd(op *operation, end *responseEnd, writer io.Writer, wasInHeaders bool) http.Header { - cerr := end.err - if cerr != nil && !wasInHeaders { - // TODO: Uh oh. We already flushed headers and started writing body. What can we do? - // Should this log? If we are using http/2, is there some way we could send - // a "goaway" frame to the client, to indicate abnormal end of stream? - return nil - } - if cerr == nil { - return nil - } - stat := grpcStatusFromError(cerr) - bin, err := op.client.codec.MarshalAppend(nil, stat) - if err != nil { - // TODO: This is always uses JSON whereas above we use the given codec. - // If/when we support codecs for REST other than JSON, what should - // we do here? - bin = []byte(`{"code": 13, "message": ` + strconv.Quote("failed to marshal end error: "+err.Error()) + `}`) +func (r restClientProtocol) acceptsStreamType() bool { + switch r.target.config.streamType { + case connect.StreamTypeUnary: + return true + case connect.StreamTypeClient: + return restIsHTTPBody(r.target.config.descriptor.Input(), r.target.requestBodyFields) + case connect.StreamTypeServer: + return restIsHTTPBody(r.target.config.descriptor.Output(), r.target.responseBodyFields) + case connect.StreamTypeBidi: + return false } - // TODO: compress? - _, _ = writer.Write(bin) - return nil + return false } -func (r restClientProtocol) requestNeedsPrep(op *operation) bool { - return len(op.restTarget.vars) != 0 || - len(op.request.URL.Query()) != 0 || - op.restTarget.requestBodyFields != nil || - restHTTPBodyRequest(op) +func (r restClientProtocol) PrepareRequestMessage(msg *messageBuffer, meta *requestMeta) error { + msg.Src.IsCompressed = meta.Client.Compressor != nil + msg.Src.IsTrailer = false + msg.Src.ReadMode = readModeEOF + if r.target.config.streamType&connect.StreamTypeClient != 0 { + msg.Src.ReadMode = readModeChunk + } + return nil } -func (r restClientProtocol) prepareUnmarshalledRequest(op *operation, src []byte, target proto.Message) error { - if err := r.prepareUnmarshalledRequestFromBody(op, src, target); err != nil { - return err - } - // Now pull in the fields from the URI path: - msg := target.ProtoReflect() - for i := len(op.restVars) - 1; i >= 0; i-- { - variable := op.restVars[i] - if err := setParameter(msg, variable.fields, variable.value); err != nil { +func (r restClientProtocol) EncodeRequestHeader(meta *responseMeta) error { + var baseCodec Codec + if meta.CodecName == CodecJSON { + codec, err := r.target.config.GetClientCodec(CodecJSON) + if err != nil { return err } + baseCodec = codec + meta.Header.Set("Content-Type", "application/json") } - // And finally from the query string: - for fieldPath, values := range op.queryValues() { - fields, err := resolvePathToDescriptors(msg.Descriptor(), fieldPath) + meta.Client.Codec = &restResponseCodec{ + target: r.target, + meta: meta, // Encode URL values. + codec: baseCodec, + } + if meta.CompressionName != "" { + comp, err := r.target.config.GetClientCompressor(meta.CompressionName) if err != nil { return err } - for _, value := range values { - if err := setParameter(msg, fields, value); err != nil { - return err - } - } + meta.Client.Compressor = comp + meta.Header.Set("Content-Encoding", meta.CompressionName) + meta.Header.Set("Accept-Encoding", meta.CompressionName) } return nil } -func (r restClientProtocol) prepareUnmarshalledRequestFromBody(op *operation, src []byte, target proto.Message) error { - if op.restTarget.requestBodyFields == nil { - if len(src) > 0 { - return fmt.Errorf("request should have no body; instead got %d bytes", len(src)) - } - return nil - } +func (r restClientProtocol) PrepareResponseMessage(msg *messageBuffer, meta *responseMeta) error { + msg.Dst.Flags = 0 + msg.Dst.IsEnvelope = false + msg.Dst.IsCompressed = meta.Client.Compressor != nil - msg, leafField, err := getBodyField(op.restTarget.requestBodyFields, target.ProtoReflect(), protoreflect.Message.Mutable) - if err != nil { - return err - } - - if leafField == nil && restIsHTTPBody(msg.Descriptor(), nil) { - fields := msg.Descriptor().Fields() - contentType := op.reqContentType - msg.Set(fields.ByName("content_type"), protoreflect.ValueOfString(contentType)) - msg.Set(fields.ByName("data"), protoreflect.ValueOfBytes(src)) - return nil + // TODO: support httpBody compression + isHTTPBody := restIsHTTPBody(r.target.config.descriptor.Input(), r.target.requestBodyFields) + if isHTTPBody { + // TODO: support full stream compression + msg.Dst.IsCompressed = false } - - if len(src) == 0 { - // No data to unmarshal. - return nil - } - if leafField == nil { - return op.client.codec.Unmarshal(src, msg.Interface()) - } - restCodec, ok := op.client.codec.(RESTCodec) - if !ok { - return fmt.Errorf("codec %q (%T) does not implement RESTCodec, so non-message request body cannot be unmarshalled", - op.client.codec.Name(), op.client.codec) - } - return restCodec.UnmarshalField(src, msg.Interface(), leafField) + return nil } -func (r restClientProtocol) responseNeedsPrep(op *operation) bool { - return len(op.restTarget.responseBodyFields) != 0 || - restHTTPBodyResponse(op) +func (r restClientProtocol) EncodeResponseTrailer(_ *bytes.Buffer, _ *responseMeta) error { + return nil } -func (r restClientProtocol) prepareMarshalledResponse(op *operation, base []byte, src proto.Message, headers http.Header) ([]byte, error) { - if restHTTPBodyResponse(op) { - msg := src.ProtoReflect() - for _, field := range op.restTarget.responseBodyFields { - msg = msg.Get(field).Message() - } - if !msg.IsValid() { - return base, nil - } - desc := msg.Descriptor() - dataField := desc.Fields().ByName("data") - contentField := desc.Fields().ByName("content_type") - contentType := msg.Get(contentField).String() - bytes := msg.Get(dataField).Bytes() - if contentType != "" { - headers.Set("Content-Type", contentType) - } - return bytes, nil - } +func (r restClientProtocol) EncodeError(buf *bytes.Buffer, meta *responseMeta, err error) { + // Encode the error as uncompressed JSON. + meta.Header.Del("Content-Encoding") + meta.Header.Set("Content-Type", "application/json") - msg, leafField, err := getBodyField(op.restTarget.responseBodyFields, src.ProtoReflect(), protoreflect.Message.Get) + cerr := asConnectError(err) + stat := grpcStatusFromError(cerr) + meta.StatusCode = httpStatusCodeFromRPC(cerr.Code()) + + codec, err := r.target.config.GetClientCodec(CodecJSON) if err != nil { - return nil, err + codec = DefaultJSONCodec(r.target.config.resolver) } - if leafField == nil { - return op.client.codec.MarshalAppend(base, msg.Interface()) + if err = marshal(buf, stat, codec); err != nil { + bin := []byte(`{"code": 13, "message": ` + strconv.Quote("failed to marshal end error: "+err.Error()) + `}`) + _, _ = buf.Write(bin) } - restCodec, ok := op.client.codec.(RESTCodec) - if !ok { - return nil, - fmt.Errorf("codec %q (%T) does not implement RESTCodec, so non-message response body cannot be marshalled", - op.client.codec.Name(), op.client.codec) - } - return restCodec.MarshalAppendField(base, msg.Interface(), leafField) -} - -func (r restClientProtocol) String() string { - return protocolNameREST } // restServerProtocol implements the REST protocol for // sending RPCs to the server handler. -type restServerProtocol struct{} +type restServerProtocol struct { + target *routeTarget + + // For error encoding. + statusCode int + contentType string +} -var _ serverProtocolHandler = restServerProtocol{} -var _ requestLineBuilder = restServerProtocol{} -var _ serverBodyPreparer = restServerProtocol{} +var _ serverProtocolHandler = (*restServerProtocol)(nil) -func (r restServerProtocol) protocol() Protocol { +func (r *restServerProtocol) Protocol() Protocol { return ProtocolREST } -func (r restServerProtocol) addProtocolRequestHeaders(meta requestMeta, headers http.Header) { - // TODO: don't set content-type on no body requests. - headers["Content-Type"] = []string{"application/" + meta.codec} - if meta.compression != "" { - headers["Content-Encoding"] = []string{meta.compression} +func (r *restServerProtocol) EncodeRequestHeader(meta *requestMeta) error { + codec, err := r.target.config.GetServerCodec(meta.CodecName) + if err != nil { + return err } - if len(meta.acceptCompression) != 0 { - headers["Accept-Encoding"] = []string{strings.Join(meta.acceptCompression, ", ")} + meta.Server.Codec = &restRequestCodec{ + target: r.target, + url: meta.URL, + header: meta.Header, + codec: codec, } - if meta.timeout != 0 { + meta.Method = r.target.method + + // Encode header values for the request body. + if r.target.requestBodyFieldPath != "" { + // Content-Type may be overridden by the body preparer. + meta.Header.Set("Content-Type", "application/"+meta.CodecName) + + if meta.CompressionName != "" { + compressor, err := r.target.config.GetServerCompressor(meta.CompressionName) + if err != nil { + return err + } + meta.Server.Compressor = compressor + meta.Header.Set("Content-Encoding", meta.CompressionName) + meta.Header.Set("Accept-Encoding", meta.CompressionName) + } + } + if meta.Timeout != 0 { // Encode timeout as a float in seconds. - value := strconv.FormatFloat(meta.timeout.Seconds(), 'E', -1, 64) - headers["X-Server-Timeout"] = []string{value} + value := strconv.FormatFloat(meta.Timeout.Seconds(), 'E', -1, 64) + meta.Header.Set("X-Server-Timeout", value) } + // Require request body to encode URL values. + meta.RequiresBody = true + return nil } -func (r restServerProtocol) extractProtocolResponseHeaders(statusCode int, headers http.Header) (responseMeta, responseEndUnmarshaller, error) { - contentType := headers.Get("Content-Type") - if statusCode/100 != 2 { - return responseMeta{ - end: &responseEnd{httpCode: statusCode}, - }, func(_ Codec, buf *bytes.Buffer, end *responseEnd) { - if err := httpErrorFromResponse(statusCode, contentType, buf); err != nil { - end.err = err - end.httpCode = httpStatusCodeFromRPC(err.Code()) - } - }, nil - } - var meta responseMeta +func (r *restServerProtocol) PrepareRequestMessage(msg *messageBuffer, meta *requestMeta) error { + msg.Dst.IsCompressed = meta.Server.Compressor != nil + return nil +} + +func (r *restServerProtocol) DecodeRequestHeader(meta *responseMeta) error { + r.statusCode = meta.StatusCode + r.contentType = meta.Header.Get("Content-Type") + meta.Header.Del("Content-Type") switch { - case contentType == "application/json": - meta.codec = CodecJSON - case strings.HasPrefix(contentType, "application/"): - meta.codec = strings.TrimPrefix(contentType, "application/") - if n := strings.Index(meta.codec, ";"); n != -1 { - meta.codec = meta.codec[:n] + case r.contentType == "application/json": + meta.CodecName = CodecJSON + case strings.HasPrefix(r.contentType, "application/"): + meta.CodecName = strings.TrimPrefix(r.contentType, "application/") + if n := strings.Index(meta.CodecName, ";"); n != -1 { + meta.CodecName = meta.CodecName[:n] + } + } + codec, err := r.target.config.GetServerCodec(meta.CodecName) + if err != nil { + return err + } + meta.Server.Codec = &restResponseCodec{ + target: r.target, + codec: codec, + meta: meta, + contentType: r.contentType, + } + + meta.CompressionName = meta.Header.Get("Content-Encoding") + meta.Header.Del("Content-Encoding") + if meta.CompressionName != "" { + compressor, err := r.target.config.GetServerCompressor(meta.CompressionName) + if err != nil { + return err } - default: - meta.codec = "" + meta.Server.Compressor = compressor } - headers.Del("Content-Type") - meta.compression = headers.Get("Content-Encoding") - headers.Del("Content-Encoding") + meta.AcceptCompression = parseMultiHeader(meta.Header.Values("Accept-Encoding")) + meta.Header.Del("Accept-Encoding") - meta.acceptCompression = parseMultiHeader(headers.Values("Accept-Encoding")) - headers.Del("Accept-Encoding") - return meta, nil, nil + return nil } -func (r restServerProtocol) extractEndFromTrailers(_ *operation, _ http.Header) (responseEnd, error) { - return responseEnd{}, nil +func (r *restServerProtocol) PrepareResponseMessage(msg *messageBuffer, meta *responseMeta) error { + msg.Src.IsCompressed = meta.Server.Compressor != nil + msg.Src.ReadMode = readModeEOF + if r.target.config.streamType&connect.StreamTypeServer != 0 { + msg.Src.ReadMode = readModeChunk + } + if r.statusCode/100 != 2 { + msg.Src.IsTrailer = true + msg.Src.ReadMode = readModeEOF + } + return nil } -func (r restServerProtocol) requestNeedsPrep(op *operation) bool { - if op.restTarget == nil { - return false // no REST bindings +func (r *restServerProtocol) DecodeResponseTrailer(buf *bytes.Buffer, _ *responseMeta) error { + if r.statusCode/100 == 2 { + return nil + } + if cerr := httpErrorFromResponse(r.statusCode, r.contentType, buf); cerr != nil { + return cerr } - return len(op.restTarget.vars) != 0 || - len(op.request.URL.Query()) != 0 || - op.restTarget.requestBodyFields != nil + return nil } -func (r restServerProtocol) prepareMarshalledRequest(op *operation, base []byte, src proto.Message, headers http.Header) ([]byte, error) { - if op.restTarget.requestBodyFields == nil { - return base, nil +type restRequestCodec struct { + target *routeTarget + vars []routeTargetVarMatch + codec Codec + url *url.URL // Clone or original URL. + header httpHeader // Empty or original header. + contentType string + count int +} + +func (c *restRequestCodec) Name() string { + if c.codec == nil { + return "rest" + } + return "rest-" + c.codec.Name() +} +func (c *restRequestCodec) MarshalAppend(dst []byte, msg proto.Message) ([]byte, error) { + defer func() { c.count++ }() + dst, err := c.marshalBody(dst, msg) + if err != nil { + return nil, err + } + if c.count > 0 { + return dst, nil } - msg, leafField, err := getBodyField(op.restTarget.requestBodyFields, src.ProtoReflect(), protoreflect.Message.Get) + path, query, err := httpEncodePathValues(msg.ProtoReflect(), c.target) if err != nil { return nil, err } - if restHTTPBodyRequest(op) { + c.url.Path = path + c.url.RawQuery = query.Encode() + return dst, nil +} +func (c *restRequestCodec) Unmarshal(src []byte, dst proto.Message) error { + defer func() { c.count++ }() + if err := c.unmarshalBody(src, dst); err != nil { + return err + } + if c.count > 0 { + return nil + } + // Now pull in the fields from the URI path: + msg := dst.ProtoReflect() + for i := len(c.vars) - 1; i >= 0; i-- { + variable := c.vars[i] + if err := setParameter(msg, variable.fields, variable.value); err != nil { + return err + } + } + // And finally from the query string: + for fieldPath, values := range c.url.Query() { + fields, err := resolvePathToDescriptors(msg.Descriptor(), fieldPath) + if err != nil { + return err + } + for _, value := range values { + if err := setParameter(msg, fields, value); err != nil { + return err + } + } + } + return nil +} + +func (c *restRequestCodec) marshalBody(dst []byte, src proto.Message) ([]byte, error) { + if c.target.requestBodyFields == nil { + return dst, nil + } + msg, leafField, err := getBodyField(c.target.requestBodyFields, src.ProtoReflect(), protoreflect.Message.Get) + if err != nil { + return nil, err + } + if leafField == nil && restIsHTTPBody(msg.Descriptor(), nil) { fields := msg.Descriptor().Fields() contentType := msg.Get(fields.ByName("content_type")).String() bytes := msg.Get(fields.ByName("data")).Bytes() - headers.Set("Content-Type", contentType) + if c.count == 0 && contentType != "" { + c.header.Set("Content-Type", contentType) + } return bytes, nil } if leafField == nil { - return op.server.codec.MarshalAppend(base, msg.Interface()) + return c.codec.MarshalAppend(dst, msg.Interface()) } - restCodec, ok := op.server.codec.(RESTCodec) - if !ok { - return nil, - fmt.Errorf("codec %q (%T) does not implement RESTCodec, so non-message request body cannot be marshalled", - op.server.codec.Name(), op.server.codec) + restCodec, err := asRESTCodec(c.codec) + if err != nil { + return nil, err } - return restCodec.MarshalAppendField(base, msg.Interface(), leafField) + return restCodec.MarshalAppendField(dst, msg.Interface(), leafField) } -func (r restServerProtocol) responseNeedsPrep(op *operation) bool { - return len(op.restTarget.responseBodyFieldPath) != 0 || - restHTTPBodyResponse(op) -} +func (c *restRequestCodec) unmarshalBody(src []byte, dst proto.Message) error { + if c.target.requestBodyFields == nil { + if len(src) > 0 { + return fmt.Errorf("request should have no body; instead got %d bytes", len(src)) + } + return nil + } + // Reset the message to clear any existing values, since we may only + // be partially encoding it. + proto.Reset(dst) -func (r restServerProtocol) prepareUnmarshalledResponse(op *operation, src []byte, target proto.Message) error { - msg, leafField, err := getBodyField(op.restTarget.responseBodyFields, target.ProtoReflect(), protoreflect.Message.Mutable) + msg, leafField, err := getBodyField(c.target.requestBodyFields, dst.ProtoReflect(), protoreflect.Message.Mutable) if err != nil { return err } - if restHTTPBodyResponse(op) { + + if leafField == nil && restIsHTTPBody(msg.Descriptor(), nil) { fields := msg.Descriptor().Fields() - contentType := op.rspContentType - msg.Set(fields.ByName("content_type"), protoreflect.ValueOfString(contentType)) - msg.Set(fields.ByName("data"), protoreflect.ValueOfBytes(src)) + if c.count == 0 { + msg.Set(fields.ByName("content_type"), protoreflect.ValueOfString(c.contentType)) + } + // Take ownership of the bytes. + data := make([]byte, len(src)) + copy(data, src) + msg.Set(fields.ByName("data"), protoreflect.ValueOfBytes(data)) return nil } if leafField == nil { - return op.server.codec.Unmarshal(src, msg.Interface()) + return c.codec.Unmarshal(src, msg.Interface()) } - restCodec, ok := op.server.codec.(RESTCodec) - if !ok { - return fmt.Errorf("codec %q (%T) does not implement RESTCodec, so non-message response body cannot be unmarshalled", - op.server.codec.Name(), op.server.codec) + restCodec, err := asRESTCodec(c.codec) + if err != nil { + return err } return restCodec.UnmarshalField(src, msg.Interface(), leafField) } -func (r restServerProtocol) requiresMessageToProvideRequestLine(_ *operation) bool { - return true +type restResponseCodec struct { + target *routeTarget + codec Codec + meta *responseMeta + contentType string + count int } -func (r restServerProtocol) requestLine(op *operation, req proto.Message) (urlPath, queryParams, method string, includeBody bool, err error) { - path, query, err := httpEncodePathValues(req.ProtoReflect(), op.restTarget) - if err != nil { - return "", "", "", false, err +func (c *restResponseCodec) Name() string { + if c.codec == nil { + return "rest" } - urlPath = path - queryParams = query.Encode() - includeBody = op.restTarget.requestBodyFields != nil // can be len(0) if body is '*' - // TODO: Should this return an error if URL (path + query string) is greater than op.methodConf.maxGetURLSz? - return urlPath, queryParams, op.restTarget.method, includeBody, nil + return "rest-" + c.codec.Name() } -func (r restServerProtocol) String() string { - return protocolNameREST +func (c *restResponseCodec) MarshalAppend(dst []byte, src proto.Message) ([]byte, error) { + defer func() { c.count++ }() + msg, leafField, err := getBodyField(c.target.responseBodyFields, src.ProtoReflect(), protoreflect.Message.Get) + if err != nil { + return nil, err + } + if leafField == nil && restIsHTTPBody(msg.Descriptor(), nil) { + if !msg.IsValid() { + return dst, nil + } + fields := msg.Descriptor().Fields() + dataField := fields.ByName("data") + contentField := fields.ByName("content_type") + contentType := msg.Get(contentField).String() + bytes := msg.Get(dataField).Bytes() + if c.count == 0 && contentType != "" { + c.meta.Header.Set("Content-Type", contentType) + } + return bytes, nil + } + if leafField == nil { + return c.codec.MarshalAppend(dst, msg.Interface()) + } + restCodec, err := asRESTCodec(c.codec) + if err != nil { + return nil, err + } + return restCodec.MarshalAppendField(dst, msg.Interface(), leafField) } - -func restHTTPBodyRequest(op *operation) bool { - return restIsHTTPBody(op.methodConf.descriptor.Input(), op.restTarget.requestBodyFields) +func (c *restResponseCodec) Unmarshal(src []byte, dst proto.Message) error { + defer func() { c.count++ }() + msg, leafField, err := getBodyField(c.target.responseBodyFields, dst.ProtoReflect(), protoreflect.Message.Mutable) + if err != nil { + return err + } + if leafField == nil && restIsHTTPBody(msg.Descriptor(), nil) { + if len(src) == 0 { + return nil + } + fields := msg.Descriptor().Fields() + dataField := fields.ByName("data") + contentField := fields.ByName("content_type") + if c.count == 0 { + msg.Set(contentField, protoreflect.ValueOfString(c.contentType)) + } + // Take ownership of the bytes. + data := make([]byte, len(src)) + copy(data, src) + msg.Set(dataField, protoreflect.ValueOfBytes(data)) + return nil + } + if leafField == nil { + return c.codec.Unmarshal(src, msg.Interface()) + } + restCodec, err := asRESTCodec(c.codec) + if err != nil { + return err + } + return restCodec.UnmarshalField(src, msg.Interface(), leafField) } -func restHTTPBodyResponse(op *operation) bool { - return restIsHTTPBody(op.methodConf.descriptor.Output(), op.restTarget.responseBodyFields) +func asRESTCodec(codec Codec) (RESTCodec, error) { + restCodec, ok := codec.(RESTCodec) + if !ok { + return nil, fmt.Errorf("codec %q (%T) does not implement RESTCodec", codec.Name(), codec) + } + return restCodec, nil } func restIsHTTPBody(msg protoreflect.MessageDescriptor, bodyPath []protoreflect.FieldDescriptor) bool { diff --git a/router.go b/router.go index fc66b45..a36944d 100644 --- a/router.go +++ b/router.go @@ -219,8 +219,8 @@ type routeTarget struct { path []string verb string requestBodyFieldPath string - requestBodyFields []protoreflect.FieldDescriptor responseBodyFieldPath string + requestBodyFields []protoreflect.FieldDescriptor responseBodyFields []protoreflect.FieldDescriptor vars []routeTargetVar } diff --git a/vanguard.go b/vanguard.go index d31cd89..9195ee0 100644 --- a/vanguard.go +++ b/vanguard.go @@ -15,16 +15,13 @@ package vanguard import ( - "context" "fmt" "math" "net/http" "strings" - "time" "connectrpc.com/connect" "google.golang.org/genproto/googleapis/api/annotations" - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/reflect/protoregistry" ) @@ -149,7 +146,7 @@ type Mux struct { // the HooksCallback and the UnknownHandler are defined, the HooksCallback will be // invoked first. The UnknownHandler will only be invoked if the callback returns // a nil error. - HooksCallback func(context.Context, Operation) (Hooks, error) + // HooksCallback func(context.Context, Operation) (Hooks, error) bufferPool bufferPool codecs codecMap @@ -243,9 +240,9 @@ func (m *Mux) RegisterService(handler http.Handler, serviceDesc protoreflect.Ser svcOpts.maxMsgBufferBytes = DefaultMaxMessageBufferBytes } - if svcOpts.hooksCallback == nil { - svcOpts.hooksCallback = m.HooksCallback - } + // if svcOpts.hooksCallback == nil { + // svcOpts.hooksCallback = m.HooksCallback + // } if svcOpts.resolver == nil { svcOpts.resolver = m.TypeResolver @@ -349,18 +346,37 @@ func (m *Mux) registerMethod(handler http.Handler, methodDesc protoreflect.Metho if _, ok := m.methods[methodPath]; ok { return fmt.Errorf("duplicate registration: method %s has already been configured", methodDesc.FullName()) } + codecs := make(map[string]Codec, len(opts.codecNames)) + allCodecs := make(map[string]Codec, len(m.codecs)) + for name, codecFn := range m.codecs { + codec := codecFn(opts.resolver) + allCodecs[name] = codec + if _, ok := opts.codecNames[name]; ok { + codecs[name] = codec + } + } + compressors := make(map[string]compressor, len(opts.compressorNames)) + allCompressors := make(map[string]compressor, len(m.compressors)) + for name, comp := range m.compressors { + allCompressors[name] = comp + if _, ok := opts.compressorNames[name]; ok { + compressors[name] = comp + } + } methodConf := &methodConfig{ descriptor: methodDesc, methodPath: methodPath, handler: handler, resolver: opts.resolver, protocols: opts.protocols, - codecNames: opts.codecNames, + allCodecs: allCodecs, + allCompressors: allCompressors, + codecs: codecs, + compressors: compressors, preferredCodec: opts.preferredCodec, - compressorNames: opts.compressorNames, maxMsgBufferBytes: opts.maxMsgBufferBytes, maxGetURLBytes: opts.maxGetURLBytes, - hooksCallback: opts.hooksCallback, + //hooksCallback: opts.hooksCallback, } if m.methods == nil { m.methods = make(map[string]*methodConfig, 1) @@ -534,7 +550,7 @@ func WithMaxGetURLBytes(limit uint32) ServiceOption { }) } -// WithHooksCallback sets the given callback for hooking into the RPC flow. +/* // WithHooksCallback sets the given callback for hooking into the RPC flow. // This overrides any callback defined on the Mux. // // See Mux.HooksCallback for more information. @@ -687,7 +703,7 @@ func (h Hooks) isEmpty() bool { h.OnServerResponseMessage == nil && h.OnOperationFinish == nil && h.OnOperationFail == nil -} +}*/ // TypeResolver can resolve message and extension types and is used to instantiate // messages as needed for the middleware to serialize/de-serialize request and @@ -713,22 +729,68 @@ type serviceOptions struct { preferredCodec string maxMsgBufferBytes uint32 maxGetURLBytes uint32 - hooksCallback func(context.Context, Operation) (Hooks, error) + // hooksCallback func(context.Context, Operation) (Hooks, error) } type methodConfig struct { - descriptor protoreflect.MethodDescriptor - methodPath string - streamType connect.StreamType - handler http.Handler - resolver TypeResolver - protocols map[Protocol]struct{} - codecNames, compressorNames map[string]struct{} - preferredCodec string - httpRule *routeTarget // First HTTP rule, if any. - maxMsgBufferBytes uint32 - maxGetURLBytes uint32 - hooksCallback func(context.Context, Operation) (Hooks, error) + descriptor protoreflect.MethodDescriptor + methodPath string + streamType connect.StreamType + handler http.Handler + resolver TypeResolver + protocols map[Protocol]struct{} + codecs map[string]Codec // Supported by the server. + compressors map[string]compressor // Supported by the server. + allCodecs map[string]Codec // Supported by the client. + allCompressors map[string]compressor // Supported by the client. + preferredCodec string + httpRule *routeTarget // First HTTP rule, if any. + maxMsgBufferBytes uint32 + maxGetURLBytes uint32 + // hooksCallback func(context.Context, Operation) (Hooks, error) +} + +func (c *methodConfig) GetClientCodec(name string) (Codec, error) { + if codec, ok := c.allCodecs[name]; ok { + return codec, nil + } + return nil, fmt.Errorf("codec %s is not known", name) +} +func (c *methodConfig) GetServerCodec(name string) (Codec, error) { + if codec, ok := c.codecs[name]; ok { + return codec, nil + } + return nil, fmt.Errorf("codec %s is not known", name) +} +func (c *methodConfig) GetClientCompressor(name string) (compressor, error) { + if name == "" || name == CompressionIdentity { + return nil, nil //nolint:nilnil // nil is a valid compressor + } + if comp, ok := c.allCompressors[name]; ok { + return comp, nil + } + return nil, fmt.Errorf("compression algorithm %s is not known", name) +} +func (c *methodConfig) GetServerCompressor(name string) (compressor, error) { + if name == "" || name == CompressionIdentity { + return nil, nil //nolint:nilnil // nil is a valid compressor + } + if comp, ok := c.compressors[name]; ok { + return comp, nil + } + return nil, fmt.Errorf("compression algorithm %s is not known", name) +} +func (c *methodConfig) ResolveServerCodecName(clientName string) string { + if c.codecs[clientName] != nil { + return clientName + } + return c.preferredCodec +} +func (c *methodConfig) ResolveServerCompressorName(clientName string) string { + if c.compressors[clientName] != nil { + return clientName + } + return CompressionIdentity } // computeSet returns a resolved set of values of type T, preferring the given values if diff --git a/vanguard_restxrpc_test.go b/vanguard_restxrpc_test.go index e2e63d4..bcf8016 100644 --- a/vanguard_restxrpc_test.go +++ b/vanguard_restxrpc_test.go @@ -16,6 +16,7 @@ package vanguard import ( "bytes" + "crypto/rand" "encoding/json" "fmt" "io" @@ -107,6 +108,8 @@ func TestMux_RESTxRPC(t *testing.T) { } } } + largePayload := make([]byte /* 1MB */, maxRecycleBufferSize) + _, _ = rand.Read(largePayload) type input struct { method string @@ -166,10 +169,11 @@ func TestMux_RESTxRPC(t *testing.T) { meta http.Header } type testRequest struct { - name string - input input - stream testStream - output output + name string + input input + stream testStream + output output + isLarge bool } testRequests := []testRequest{{ name: "GetBook", @@ -428,6 +432,59 @@ func TestMux_RESTxRPC(t *testing.T) { code: http.StatusOK, body: &emptypb.Empty{}, }, + }, { + name: "Upload-Large", + isLarge: true, + input: input{ + method: http.MethodPost, + path: "/message.txt:upload", + body: &httpbody.HttpBody{ + ContentType: "text/plain", + Data: largePayload, + }, + }, + stream: testStream{ + method: testv1connect.ContentServiceUploadProcedure, + msgs: func() []testMsg { + var testMsgs []testMsg + payload := largePayload + for i := 0; len(payload) > 0; i++ { + size := len(payload) + if size > chunkMessageSize { + size = chunkMessageSize + } + part := payload[:size] + payload = payload[size:] + + msg := &testv1.UploadRequest{ + File: &httpbody.HttpBody{ + Data: part, + }, + } + if i == 0 { + msg.Filename = "message.txt" + msg.File.ContentType = "text/plain" + } + testMsgs = append(testMsgs, + testMsg{in: &testMsgIn{ + msg: msg, + }}, + ) + t.Log(msg.Filename) + } + testMsgs = append(testMsgs, + testMsg{out: &testMsgOut{ + msg: &emptypb.Empty{}, + }}, + ) + + return testMsgs + }(), + }, + output: output{ + code: http.StatusOK, + body: &emptypb.Empty{}, + }, }, { name: "Download", input: input{ @@ -512,7 +569,7 @@ func TestMux_RESTxRPC(t *testing.T) { req.Header.Set("Test", t.Name()) // for interceptor t.Log(req.Method, req.URL.String()) - debug, _ := httputil.DumpRequest(req, true) + debug, _ := httputil.DumpRequest(req, !testCase.isLarge) t.Log("req:", string(debug)) rsp := httptest.NewRecorder() @@ -520,7 +577,7 @@ func TestMux_RESTxRPC(t *testing.T) { result := rsp.Result() defer result.Body.Close() - debug, _ = httputil.DumpResponse(result, true) + debug, _ = httputil.DumpResponse(result, !testCase.isLarge) t.Log("rsp:", string(debug)) // Check response diff --git a/vanguard_test.go b/vanguard_test.go index 1f7f317..ff2a10a 100644 --- a/vanguard_test.go +++ b/vanguard_test.go @@ -37,12 +37,12 @@ import ( "github.com/stretchr/testify/require" "google.golang.org/genproto/googleapis/api/annotations" "google.golang.org/genproto/googleapis/api/httpbody" + "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" - "google.golang.org/protobuf/types/known/emptypb" ) -func TestMux_BufferTooLargeFails(t *testing.T) { +/*func TestMux_BufferTooLargeFails(t *testing.T) { t.Parallel() // Cases where we buffer: @@ -548,7 +548,7 @@ func TestMux_BufferTooLargeFails(t *testing.T) { } }) } -} +}*/ func TestMux_ConnectGetUsesPostIfRequestTooLarge(t *testing.T) { t.Parallel() @@ -652,7 +652,7 @@ func TestMux_ConnectGetUsesPostIfRequestTooLarge(t *testing.T) { } } -func TestMux_MessageHooks(t *testing.T) { +/*func TestMux_MessageHooks(t *testing.T) { t.Parallel() // NB: These cases are identical to the pass-through cases, but should // not just pass through when a request or response hook is configured. @@ -1565,7 +1565,7 @@ func TestMux_HookOrder(t *testing.T) { } }) } -} +}*/ func TestRuleSelector(t *testing.T) { t.Parallel() @@ -1576,7 +1576,9 @@ func TestRuleSelector(t *testing.T) { testv1connect.UnimplementedLibraryServiceHandler{}, connect.WithInterceptors(&interceptor), )) - mux := &Mux{} + mux := &Mux{ + Protocols: []Protocol{ProtocolGRPC}, + } assert.NoError(t, mux.RegisterServiceByName(serveMux, testv1connect.LibraryServiceName)) assert.ErrorContains(t, mux.RegisterRules(&annotations.HttpRule{ @@ -1648,6 +1650,10 @@ func TestRuleSelector(t *testing.T) { assert.Equal(t, http.StatusOK, result.StatusCode) assert.Equal(t, "application/json", result.Header.Get("Content-Type")) assert.Equal(t, "world", result.Header.Get("Message")) + + var book testv1.Book + assert.NoError(t, protojson.Unmarshal(rsp.Body.Bytes(), &book)) + assert.Equal(t, "shelves/123/books/456", book.Name) } type testStream struct { @@ -1833,9 +1839,8 @@ func (i *testInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { if !ok { return nil, fmt.Errorf("expected proto.Message, got %T", req.Any()) } - diff := cmp.Diff(msg, inn.msg, protocmp.Transform()) - if diff != "" { - return nil, fmt.Errorf("message didn't match: %s", diff) + if !assertEqual(stream.T, inn.msg, msg) { + return nil, fmt.Errorf("message didn't match") } if out.err != nil { @@ -1911,10 +1916,8 @@ func (i *testInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc case err != nil: return err // not expecting an error default: - diff := cmp.Diff(got, msg.msg, protocmp.Transform()) - assert.Empty(stream.T, diff, "message didn't match") - if diff != "" { - return fmt.Errorf("message didn't match: %s", diff) + if !assertEqual(stream.T, got, msg.msg) { + return fmt.Errorf("message didn't match") } } case *testMsgOut: @@ -1971,7 +1974,7 @@ func (i *testInterceptor) restUnaryHandler( if err != nil { return err } - if comp != nil && len(body) > 0 && encoding != "" { + if comp != nil && len(body) > 0 && encoding != "" && encoding != "identity" { assert.Equal(stream.T, comp.Name(), encoding, "expected encoding") var dst bytes.Buffer if err := comp.decompress(&dst, bytes.NewBuffer(body)); err != nil { @@ -1996,8 +1999,7 @@ func (i *testInterceptor) restUnaryHandler( } } } - diff := cmp.Diff(got, inn.msg, protocmp.Transform()) - assert.Empty(stream.T, diff, "message didn't match") + assertEqual(stream.T, got, inn.msg) // Write headers. for key, values := range stream.rspHeader { @@ -2025,7 +2027,7 @@ func (i *testInterceptor) restUnaryHandler( if err != nil { return err } - if comp != nil && acceptEncoding != "" { + if comp != nil && acceptEncoding != "" && acceptEncoding != "identity" { assert.Equal(stream.T, comp.Name(), acceptEncoding, "expected encoding") rsp.Header().Set("Content-Encoding", comp.Name()) var dst bytes.Buffer @@ -2317,16 +2319,21 @@ func protocolAssertMiddleware( "Grpc-Encoding": allowedCompression, } case ProtocolConnect: - if strings.HasPrefix(req.Header.Get("Content-Type"), "application/connect") { + switch { + case strings.HasPrefix(req.Header.Get("Content-Type"), "application/connect"): wantHdr = map[string][]string{ "Content-Type": {fmt.Sprintf("application/connect+%s", codec)}, "Connect-Content-Encoding": allowedCompression, } - } else { + case req.Method == http.MethodPost: wantHdr = map[string][]string{ "Content-Type": {fmt.Sprintf("application/%s", codec)}, "Content-Encoding": allowedCompression, } + case req.Method == http.MethodGet: + // GET requests are not allowed to have a body, so we can't + // have a Content-Type header. + wantHdr = nil } default: http.Error(rsp, "unknown protocol", http.StatusInternalServerError) @@ -2400,6 +2407,8 @@ func runRPCTestCase[Client any]( if receivedErr == nil { assert.NoError(t, err) } else { + t.Log("expected error:", receivedErr) + t.Log("actual error:", err) assert.Equal(t, receivedErr.Code(), connect.CodeOf(err)) } // Also check the error observed by the server. @@ -2438,7 +2447,7 @@ func runRPCTestCase[Client any]( require.Len(t, responses, len(expectedResponses)) for i, msg := range responses { want := expectedResponses[i] - assert.Empty(t, cmp.Diff(want, msg, protocmp.Transform())) + assertEqual(t, want, msg) } } @@ -2449,7 +2458,7 @@ func disableCompression(server *httptest.Server) { type testCaseNameContextKey struct{} -type hookKind int +/*type hookKind int const ( hookKindInit = hookKind(iota + 1) @@ -2554,7 +2563,7 @@ func (h *testHooks) getEvents(t *testing.T) (Operation, []hookKind) { } t.Fatal("unreachable") return nil, nil -} +}*/ func newConnectError(code connect.Code, msg string) *connect.Error { err := connect.NewError(code, errors.New(msg)) @@ -2563,3 +2572,15 @@ func newConnectError(code connect.Code, msg string) *connect.Error { err.Meta() return err } + +func assertEqual(t *testing.T, want, got proto.Message) bool { + // Cmp can show non-linear behaviour for large messages + // so check proto.Equal before diffing. + // See: https://github.com/google/go-cmp/issues/335 + ok := proto.Equal(want, got) + if !ok { + diff := cmp.Diff(want, got, protocmp.Transform()) + assert.Empty(t, diff, "message didn't match") + } + return ok +}