From 72857a071b53117af528ae192850718c78e16a17 Mon Sep 17 00:00:00 2001 From: Edward McFarlane <3036610+emcfarlane@users.noreply.github.com> Date: Wed, 27 Sep 2023 16:00:08 +0100 Subject: [PATCH] Cleanup receiver names and remove most nolints (#76) --- .golangci.yml | 18 +- buffer_pool.go | 15 +- compression.go | 4 +- handler.go | 1103 ++++++++--------- handler_test.go | 7 +- .../examples/connect+grpc/cmd/server/main.go | 6 +- internal/examples/fileserver/main.go | 17 +- params.go | 2 - protocol.go | 8 +- protocol_connect.go | 8 +- protocol_grpc.go | 190 +-- protocol_grpc_test.go | 17 +- protocol_http.go | 3 +- protocol_rest.go | 5 +- router.go | 46 +- router_test.go | 31 +- vanguard.go | 26 +- vanguard_examples_test.go | 25 +- vanguard_rpcxrest_test.go | 18 +- vanguard_rpcxrpc_test.go | 19 +- vanguard_test.go | 131 +- 21 files changed, 848 insertions(+), 851 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index db5de77..e853e72 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -65,6 +65,22 @@ issues: # checks from err113 are useful. - "err113: do not define dynamic errors.*" exclude-rules: - - path: internal/examples/pets/.*\.go + - path: internal/examples/.*/.*\.go linters: - forbidigo # log.Fatal, fmt.Printf used in example programs + - gosec + - gochecknoglobals + - path: ".*_test.go" + linters: + - dupl # allow duplicate string literals for testing + - forcetypeassert + - nilerr # allow encoding error and returning nil + - path: vanguard_examples_test.go + linters: + - gocritic # allow log.Fatal for examples + - path: handler.go + linters: + - contextcheck # use request context + - path: params.go + linters: + - goconst # allow string literals for WKT names diff --git a/buffer_pool.go b/buffer_pool.go index c0a3399..abd8e19 100644 --- a/buffer_pool.go +++ b/buffer_pool.go @@ -29,24 +29,21 @@ type bufferPool struct { } func newBufferPool() *bufferPool { - return &bufferPool{ - Pool: sync.Pool{ - New: func() any { - return bytes.NewBuffer(make([]byte, 0, initialBufferSize)) - }, - }, - } + return &bufferPool{} } func (b *bufferPool) Get() *bytes.Buffer { - return b.Pool.Get().(*bytes.Buffer) //nolint:forcetypeassert + if buffer, ok := b.Pool.Get().(*bytes.Buffer); ok { + buffer.Reset() + return buffer + } + return bytes.NewBuffer(make([]byte, 0, initialBufferSize)) } func (b *bufferPool) Put(buffer *bytes.Buffer) { if buffer.Cap() > maxRecycleBufferSize { return } - buffer.Reset() b.Pool.Put(buffer) } diff --git a/compression.go b/compression.go index d62cf37..74b6a62 100644 --- a/compression.go +++ b/compression.go @@ -72,7 +72,7 @@ func (p *compressionPool) compress(dst, src *bytes.Buffer) error { if src.Len() == 0 { return nil } - comp := p.compressors.Get().(connect.Compressor) //nolint:forcetypeassert,errcheck + comp, _ := p.compressors.Get().(connect.Compressor) defer p.compressors.Put(comp) comp.Reset(dst) @@ -91,7 +91,7 @@ func (p *compressionPool) decompress(dst, src *bytes.Buffer) error { if src.Len() == 0 { return nil } - decomp := p.decompressors.Get().(connect.Decompressor) //nolint:forcetypeassert,errcheck + decomp, _ := p.decompressors.Get().(connect.Decompressor) defer p.decompressors.Put(decomp) if err := decomp.Reset(src); err != nil { diff --git a/handler.go b/handler.go index 099b33a..a75efe2 100644 --- a/handler.go +++ b/handler.go @@ -38,7 +38,6 @@ type handler struct { canDecompress []string } -//nolint:contextcheck func (h *handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { op := h.newOperation(writer, request) err := op.validate(h.mux, h.codecs) @@ -52,7 +51,7 @@ func (h *handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { } if callback != nil { var hookErr error - if op.hooks, hookErr = callback(op.request.Context(), op); hookErr != nil { //nolint:contextcheck + if op.hooks, hookErr = callback(op.request.Context(), op); hookErr != nil { useUnknownHandler = false err = hookErr } @@ -297,108 +296,108 @@ type operation struct { var _ Operation = (*operation)(nil) -func (op *operation) IsValid() bool { - return op.isValid +func (o *operation) IsValid() bool { + return o.isValid } -func (op *operation) HTTPRequestLine() (method, path, queryString, httpVersion string) { - return op.requestLine.method, op.requestLine.path, op.requestLine.queryString, op.requestLine.httpVersion +func (o *operation) HTTPRequestLine() (method, path, queryString, httpVersion string) { + return o.requestLine.method, o.requestLine.path, o.requestLine.queryString, o.requestLine.httpVersion } -func (op *operation) Method() protoreflect.MethodDescriptor { - if op.methodConf == nil { +func (o *operation) Method() protoreflect.MethodDescriptor { + if o.methodConf == nil { return nil } - return op.methodConf.descriptor + return o.methodConf.descriptor } -func (op *operation) Deadline() (time.Time, bool) { - return op.deadline, op.reqMeta.hasTimeout +func (o *operation) Deadline() (time.Time, bool) { + return o.deadline, o.reqMeta.hasTimeout } -func (op *operation) ClientInfo() PeerInfo { - return &op.client +func (o *operation) ClientInfo() PeerInfo { + return &o.client } -func (op *operation) HandlerInfo() PeerInfo { - return &op.server +func (o *operation) HandlerInfo() PeerInfo { + return &o.server } -func (op *operation) doNotImplement() {} +func (o *operation) doNotImplement() {} -func (op *operation) validate(mux *Mux, codecs map[codecKey]Codec) error { +func (o *operation) validate(mux *Mux, codecs map[codecKey]Codec) error { // Identify the protocol. - clientProtoHandler, queryVars := classifyRequest(op.request) + clientProtoHandler, queryVars := classifyRequest(o.request) if clientProtoHandler == nil { return newHTTPError(http.StatusUnsupportedMediaType, "could not classify protocol") } - op.client.protocol = clientProtoHandler + o.client.protocol = clientProtoHandler if queryVars != nil { // memoize this, so we don't have to parse query string again later - op.queryVars = queryVars + o.queryVars = queryVars } - op.originalHeaders = op.request.Header.Clone() - op.reqContentType = op.originalHeaders.Get("Content-Type") - op.contentLen = op.request.ContentLength - op.request.ContentLength = -1 // transforming it will likely change it + o.originalHeaders = o.request.Header.Clone() + o.reqContentType = o.originalHeaders.Get("Content-Type") + o.contentLen = o.request.ContentLength + o.request.ContentLength = -1 // transforming it will likely change it // Identify the method being invoked. - err := op.resolveMethod(mux) + err := o.resolveMethod(mux) if err != nil { return err } - if !op.client.protocol.acceptsStreamType(op, op.methodConf.streamType) { - return newHTTPError(http.StatusUnsupportedMediaType, "stream type %s not supported with %s protocol", op.methodConf.streamType, op.client.protocol) + if !o.client.protocol.acceptsStreamType(o, o.methodConf.streamType) { + return newHTTPError(http.StatusUnsupportedMediaType, "stream type %s not supported with %s protocol", o.methodConf.streamType, o.client.protocol) } - if op.methodConf.streamType == connect.StreamTypeBidi && op.request.ProtoMajor < 2 { + if o.methodConf.streamType == connect.StreamTypeBidi && o.request.ProtoMajor < 2 { return newHTTPError(http.StatusHTTPVersionNotSupported, "bidi streams require HTTP/2") } - if clientProtoHandler.protocol() == ProtocolGRPC && op.request.ProtoMajor != 2 { + if clientProtoHandler.protocol() == ProtocolGRPC && o.request.ProtoMajor != 2 { return newHTTPError(http.StatusHTTPVersionNotSupported, "gRPC requires HTTP/2") } // Identify the request encoding and compression. - reqMeta, err := clientProtoHandler.extractProtocolRequestHeaders(op, op.request.Header) + reqMeta, err := clientProtoHandler.extractProtocolRequestHeaders(o, o.request.Header) if err != nil { return newHTTPError(http.StatusBadRequest, err.Error()) } // Remove other headers that might mess up the next leg - if enc := op.request.Header.Get("Content-Encoding"); enc != "" && enc != CompressionIdentity { + if enc := o.request.Header.Get("Content-Encoding"); enc != "" && enc != CompressionIdentity { // If the protocol didn't remove the "Content-Encoding" header in above step, // that's because it models encoding in a different way. In that case, encoding // of the whole response with this header is not valid. return newHTTPError(http.StatusUnsupportedMediaType, "content-encoding %q not allowed for this protocol", enc) } - op.request.Header.Del("Content-Encoding") - op.request.Header.Del("Accept-Encoding") - op.request.Header.Del("Content-Length") + o.request.Header.Del("Content-Encoding") + o.request.Header.Del("Accept-Encoding") + o.request.Header.Del("Content-Length") - op.reqMeta = reqMeta + o.reqMeta = reqMeta if reqMeta.hasTimeout { - op.deadline = time.Now().Add(reqMeta.timeout) + o.deadline = time.Now().Add(reqMeta.timeout) } if reqMeta.compression == CompressionIdentity { reqMeta.compression = "" // normalize to empty string } if reqMeta.compression != "" { var ok bool - op.client.reqCompression, ok = op.compressionPools[reqMeta.compression] + o.client.reqCompression, ok = o.compressionPools[reqMeta.compression] if !ok { return newHTTPError(http.StatusUnsupportedMediaType, "%q compression not supported", reqMeta.compression) } } - op.client.codec = codecs[codecKey{res: op.methodConf.resolver, name: reqMeta.codec}] - if op.client.codec == nil { + o.client.codec = codecs[codecKey{res: o.methodConf.resolver, name: reqMeta.codec}] + if o.client.codec == nil { return newHTTPError(http.StatusUnsupportedMediaType, "%q sub-format not supported", reqMeta.codec) } // Now we can determine the destination protocol details - if _, supportsProtocol := op.methodConf.protocols[clientProtoHandler.protocol()]; supportsProtocol { - op.server.protocol = clientProtoHandler.protocol().serverHandler(op) + if _, supportsProtocol := o.methodConf.protocols[clientProtoHandler.protocol()]; supportsProtocol { + o.server.protocol = clientProtoHandler.protocol().serverHandler(o) } else { for protocol := protocolMin; protocol <= protocolMax; protocol++ { - if _, supportsProtocol := op.methodConf.protocols[protocol]; supportsProtocol { - op.server.protocol = protocol.serverHandler(op) + if _, supportsProtocol := o.methodConf.protocols[protocol]; supportsProtocol { + o.server.protocol = protocol.serverHandler(o) break } } @@ -406,11 +405,11 @@ func (op *operation) validate(mux *Mux, codecs map[codecKey]Codec) error { // Now that we've ruled out the use of bidi streaming above, it's safe to simulate HTTP/2 // for the benefit of gRPC handlers, which require HTTP/2. - if op.server.protocol.protocol() == ProtocolGRPC { - op.request.Proto, op.request.ProtoMajor, op.request.ProtoMinor = "HTTP/2", 2, 0 + if o.server.protocol.protocol() == ProtocolGRPC { + o.request.Proto, o.request.ProtoMajor, o.request.ProtoMinor = "HTTP/2", 2, 0 } - if op.server.protocol.protocol() == ProtocolREST { + if o.server.protocol.protocol() == ProtocolREST { // REST always uses JSON. // TODO: Allow non-JSON encodings with REST? Would require registering content-types with codecs. // Would also require figuring out how to (un)marshal things other than messages when a body @@ -419,63 +418,63 @@ func (op *operation) validate(mux *Mux, codecs map[codecKey]Codec) error { // NB: This is fine to set even if a custom content-type is used via // the use of google.api.HttpBody. The actual content-type and body // data will be written via serverBodyPreparer implementation. - op.server.codec = codecs[codecKey{res: op.methodConf.resolver, name: CodecJSON}] - } else if _, supportsCodec := op.methodConf.codecNames[reqMeta.codec]; supportsCodec { - op.server.codec = op.client.codec + o.server.codec = codecs[codecKey{res: o.methodConf.resolver, name: CodecJSON}] + } else if _, supportsCodec := o.methodConf.codecNames[reqMeta.codec]; supportsCodec { + o.server.codec = o.client.codec } else { - op.server.codec = codecs[codecKey{res: op.methodConf.resolver, name: op.methodConf.preferredCodec}] + o.server.codec = codecs[codecKey{res: o.methodConf.resolver, name: o.methodConf.preferredCodec}] } if reqMeta.compression != "" { - if _, supportsCompression := op.methodConf.compressorNames[reqMeta.compression]; supportsCompression { - op.server.reqCompression = op.client.reqCompression + if _, supportsCompression := o.methodConf.compressorNames[reqMeta.compression]; supportsCompression { + o.server.reqCompression = o.client.reqCompression } // else: we'll just decompress and not recompress // TODO: should we instead pick a supported compression scheme (if there is one)? } - op.isValid = true // Successfully validated! + o.isValid = true // Successfully validated! return nil } -func (op *operation) queryValues() url.Values { - if op.queryVars == nil && op.request.URL.RawQuery != "" { - op.queryVars = op.request.URL.Query() +func (o *operation) queryValues() url.Values { + if o.queryVars == nil && o.request.URL.RawQuery != "" { + o.queryVars = o.request.URL.Query() } - return op.queryVars + return o.queryVars } -func (op *operation) handle() { //nolint:gocyclo - if op.hooks.OnClientRequestHeaders != nil { - if err := op.hooks.OnClientRequestHeaders(op.request.Context(), op, op.request.Header); err != nil { - op.reportError(err) +func (o *operation) handle() { //nolint:gocyclo + if o.hooks.OnClientRequestHeaders != nil { + if err := o.hooks.OnClientRequestHeaders(o.request.Context(), o, o.request.Header); err != nil { + o.reportError(err) return } } - op.clientEnveloper, _ = op.client.protocol.(envelopedProtocolHandler) - op.clientPreparer, _ = op.client.protocol.(clientBodyPreparer) - if op.clientPreparer != nil { - op.clientReqNeedsPrep = op.clientPreparer.requestNeedsPrep(op) + o.clientEnveloper, _ = o.client.protocol.(envelopedProtocolHandler) + o.clientPreparer, _ = o.client.protocol.(clientBodyPreparer) + if o.clientPreparer != nil { + o.clientReqNeedsPrep = o.clientPreparer.requestNeedsPrep(o) } - op.serverEnveloper, _ = op.server.protocol.(serverEnvelopedProtocolHandler) - op.serverPreparer, _ = op.server.protocol.(serverBodyPreparer) - if op.serverPreparer != nil { - op.serverReqNeedsPrep = op.serverPreparer.requestNeedsPrep(op) + o.serverEnveloper, _ = o.server.protocol.(serverEnvelopedProtocolHandler) + o.serverPreparer, _ = o.server.protocol.(serverBodyPreparer) + if o.serverPreparer != nil { + o.serverReqNeedsPrep = o.serverPreparer.requestNeedsPrep(o) } - serverRequestBuilder, _ := op.server.protocol.(requestLineBuilder) + serverRequestBuilder, _ := o.server.protocol.(requestLineBuilder) var requireMessageForRequestLine bool if serverRequestBuilder != nil { - requireMessageForRequestLine = serverRequestBuilder.requiresMessageToProvideRequestLine(op) + requireMessageForRequestLine = serverRequestBuilder.requiresMessageToProvideRequestLine(o) } - sameRequestCompression := op.client.reqCompression.Name() == op.server.reqCompression.Name() - sameCodec := op.client.codec.Name() == op.server.codec.Name() + sameRequestCompression := o.client.reqCompression.Name() == o.server.reqCompression.Name() + sameCodec := o.client.codec.Name() == o.server.codec.Name() // even if body encoding uses same content type, we can't treat them as the same // (which means re-using encoded data) if either side needs to prep the data first - sameRequestCodec := sameCodec && !op.clientReqNeedsPrep && !op.serverReqNeedsPrep - mustDecodeRequest := !sameRequestCodec || requireMessageForRequestLine || op.hooks.OnClientRequestMessage != nil + sameRequestCodec := sameCodec && !o.clientReqNeedsPrep && !o.serverReqNeedsPrep + mustDecodeRequest := !sameRequestCodec || requireMessageForRequestLine || o.hooks.OnClientRequestMessage != nil reqMsg := message{ sameCompression: sameRequestCompression, @@ -484,34 +483,34 @@ func (op *operation) handle() { //nolint:gocyclo if mustDecodeRequest { // Need the message type to decode - messageType, err := op.methodConf.resolver.FindMessageByName(op.methodConf.descriptor.Input().FullName()) + messageType, err := o.methodConf.resolver.FindMessageByName(o.methodConf.descriptor.Input().FullName()) if err != nil { - op.reportError(err) + o.reportError(err) return } reqMsg.msg = messageType.New().Interface() } - if (op.hooks.OnClientRequestMessage != nil && op.methodConf.streamType == connect.StreamTypeUnary) || + if (o.hooks.OnClientRequestMessage != nil && o.methodConf.streamType == connect.StreamTypeUnary) || requireMessageForRequestLine { // Go ahead and process first request message - switch err := op.readRequestMessage(nil, op.request.Body, &reqMsg); { + switch err := o.readRequestMessage(nil, o.request.Body, &reqMsg); { case errors.Is(err, io.EOF): // okay for the first message: means empty message data reqMsg.markReady() case err != nil: - op.reportError(err) + o.reportError(err) return } - if err := reqMsg.advanceToStage(op, stageDecoded); err != nil { - op.reportError(err) + if err := reqMsg.advanceToStage(o, stageDecoded); err != nil { + o.reportError(err) return } - if op.hooks.OnClientRequestMessage != nil { - compressed := reqMsg.wasCompressed && op.client.reqCompression != nil - err := op.hooks.OnClientRequestMessage(op.request.Context(), op, reqMsg.msg, compressed, reqMsg.size) + if o.hooks.OnClientRequestMessage != nil { + compressed := reqMsg.wasCompressed && o.client.reqCompression != nil + err := o.hooks.OnClientRequestMessage(o.request.Context(), o, reqMsg.msg, compressed, reqMsg.size) if err != nil { - op.reportError(err) + o.reportError(err) return } } @@ -521,53 +520,53 @@ func (op *operation) handle() { //nolint:gocyclo if serverRequestBuilder != nil { var hasBody bool var err error - op.request.URL.Path, op.request.URL.RawQuery, op.request.Method, hasBody, err = - serverRequestBuilder.requestLine(op, reqMsg.msg) + o.request.URL.Path, o.request.URL.RawQuery, o.request.Method, hasBody, err = + serverRequestBuilder.requestLine(o, reqMsg.msg) if err != nil { - op.reportError(err) + o.reportError(err) return } skipBody = !hasBody // Recompute if the server needs to prep the request, now that we've modified // properties of op.request. - if op.serverPreparer != nil { - op.serverReqNeedsPrep = op.serverPreparer.requestNeedsPrep(op) + if o.serverPreparer != nil { + o.serverReqNeedsPrep = o.serverPreparer.requestNeedsPrep(o) } } else { // if no request line builder, use simple request layout - op.request.URL.Path = op.methodConf.methodPath - op.request.URL.RawQuery = "" - op.request.Method = http.MethodPost + o.request.URL.Path = o.methodConf.methodPath + o.request.URL.RawQuery = "" + o.request.Method = http.MethodPost } - op.request.URL.ForceQuery = false - svrReqMeta := op.reqMeta - svrReqMeta.codec = op.server.codec.Name() - svrReqMeta.compression = op.server.reqCompression.Name() - svrReqMeta.acceptCompression = intersect(op.reqMeta.acceptCompression, op.canDecompress) - op.server.protocol.addProtocolRequestHeaders(svrReqMeta, op.request.Header) + o.request.URL.ForceQuery = false + serverReqMeta := o.reqMeta + serverReqMeta.codec = o.server.codec.Name() + serverReqMeta.compression = o.server.reqCompression.Name() + serverReqMeta.acceptCompression = intersect(o.reqMeta.acceptCompression, o.canDecompress) + o.server.protocol.addProtocolRequestHeaders(serverReqMeta, o.request.Header) // Now we can define the transformed response writer (which delays // much of its logic until it sees the response headers). - flusher := asFlusher(op.writer) + flusher := asFlusher(o.writer) if flusher == nil { - op.reportError(errors.New("http.ResponseWriter must implement http.Flusher")) + o.reportError(errors.New("http.ResponseWriter must implement http.Flusher")) return } - rw := &responseWriter{op: op, delegate: op.writer, flusher: flusher} + rw := &responseWriter{op: o, delegate: o.writer, flusher: flusher} defer rw.close() - op.writer = rw + o.writer = rw // And finally we can define the transformed request bodies. switch { case skipBody: // drain any contents of body so downstream handler sees empty - op.drainBody(op.request.Body) + o.drainBody(o.request.Body) case sameRequestCompression && sameRequestCodec && !mustDecodeRequest: // we do not need to decompress or decode; just transforming envelopes - op.request.Body = &envelopingReader{rw: rw, r: op.request.Body} + o.request.Body = &envelopingReader{rw: rw, r: o.request.Body} default: - tw := &transformingReader{rw: rw, msg: &reqMsg, r: op.request.Body} - op.request.Body = tw + tw := &transformingReader{rw: rw, msg: &reqMsg, r: o.request.Body} + o.request.Body = tw if reqMsg.stage != stageEmpty { if err := tw.prepareMessage(); err != nil { tw.err = err @@ -575,17 +574,17 @@ func (op *operation) handle() { //nolint:gocyclo } } - op.methodConf.handler.ServeHTTP(op.writer, op.request) + o.methodConf.handler.ServeHTTP(o.writer, o.request) } -func (op *operation) resolveMethod(mux *Mux) error { - uriPath := op.request.URL.Path - switch op.client.protocol.protocol() { +func (o *operation) resolveMethod(mux *Mux) error { + uriPath := o.request.URL.Path + switch o.client.protocol.protocol() { case ProtocolREST: var methods routeMethods - op.restTarget, op.restVars, methods = mux.restRoutes.match(uriPath, op.request.Method) - if op.restTarget != nil { - op.methodConf = op.restTarget.config + o.restTarget, o.restVars, methods = mux.restRoutes.match(uriPath, o.request.Method) + if o.restTarget != nil { + o.methodConf = o.restTarget.config return nil } if len(methods) == 0 { @@ -611,9 +610,9 @@ func (op *operation) resolveMethod(mux *Mux) error { // a proper RPC error (encoded per protocol handler) with an Unimplemented code. return errNotFound } - op.restTarget = methodConf.httpRule - if op.request.Method != http.MethodPost { - mayAllowGet, ok := op.client.protocol.(clientProtocolAllowsGet) + o.restTarget = methodConf.httpRule + if o.request.Method != http.MethodPost { + mayAllowGet, ok := o.client.protocol.(clientProtocolAllowsGet) allowsGet := ok && mayAllowGet.allowsGetRequests(methodConf) if !allowsGet { return &httpError{ @@ -623,7 +622,7 @@ func (op *operation) resolveMethod(mux *Mux) error { }, } } - if allowsGet && op.request.Method != http.MethodGet { + if allowsGet && o.request.Method != http.MethodGet { return &httpError{ code: http.StatusMethodNotAllowed, header: http.Header{ @@ -632,7 +631,7 @@ func (op *operation) resolveMethod(mux *Mux) error { } } } - op.methodConf = methodConf + o.methodConf = methodConf return nil } } @@ -640,48 +639,48 @@ func (op *operation) resolveMethod(mux *Mux) error { // reportError handles an error that occurs while setting up the operation. It should not be used // once the underlying server handler has been invoked. For those errors, responseWriter.reportError // must be used instead. -func (op *operation) reportError(err error) { - defer op.cancel() +func (o *operation) reportError(err error) { + defer o.cancel() - if !op.isValid { + if !o.isValid { // We don't have enough operation details to render an RPC error, // so just send a simple HTTP error. - asHTTPError(err).Encode(op.writer) + asHTTPError(err).Encode(o.writer) return } - rw, ok := op.writer.(*responseWriter) + rw, ok := o.writer.(*responseWriter) if ok { rw.reportError(err) return } // No responseWriter created yet, so we duplicate some of its behavior to write an error. - if op.hooks.OnOperationFail != nil { - if hookErr := op.hooks.OnOperationFail(op.request.Context(), op, nil, err); hookErr != nil { + if o.hooks.OnOperationFail != nil { + if hookErr := o.hooks.OnOperationFail(o.request.Context(), o, nil, err); hookErr != nil { // report the error returned by the hook err = hookErr } } httpErr := asHTTPError(err) - httpErr.EncodeHeaders(op.writer.Header()) + httpErr.EncodeHeaders(o.writer.Header()) connErr := asConnectError(err) end := &responseEnd{err: connErr, httpCode: httpErr.code} - code := op.client.protocol.addProtocolResponseHeaders(responseMeta{end: end}, op.writer.Header()) - op.writer.WriteHeader(code) - trailers := op.client.protocol.encodeEnd(op, end, op.writer, true) - httpMergeTrailers(op.writer.Header(), trailers) + code := o.client.protocol.addProtocolResponseHeaders(responseMeta{end: end}, o.writer.Header()) + o.writer.WriteHeader(code) + trailers := o.client.protocol.encodeEnd(o, end, o.writer, true) + httpMergeTrailers(o.writer.Header(), trailers) } -func (op *operation) readRequestMessage(rw *responseWriter, reader io.Reader, msg *message) error { +func (o *operation) readRequestMessage(rw *responseWriter, reader io.Reader, msg *message) error { msgLen := -1 - compressed := op.client.reqCompression != nil - if op.clientEnveloper != nil { + compressed := o.client.reqCompression != nil + if o.clientEnveloper != nil { var envBuf envelopeBytes _, err := io.ReadFull(reader, envBuf[:]) if err != nil { return err } - msgLen, compressed, err = op.processRequestEnvelope(envBuf) + msgLen, compressed, err = o.processRequestEnvelope(envBuf) if err != nil { if rw != nil { rw.reportError(err) @@ -690,10 +689,10 @@ func (op *operation) readRequestMessage(rw *responseWriter, reader io.Reader, ms } } - buffer := msg.reset(op.bufferPool, true, compressed) + buffer := msg.reset(o.bufferPool, true, compressed) var err error if msgLen == -1 { //nolint:nestif - limit, grow, makeError, limitErr := op.determineReadLimit() + limit, grow, makeError, limitErr := o.determineReadLimit() if limitErr != nil { if rw != nil { rw.reportError(limitErr) @@ -722,40 +721,40 @@ func (op *operation) readRequestMessage(rw *responseWriter, reader io.Reader, ms return nil } -func (op *operation) processRequestEnvelope(envBuf envelopeBytes) (msgLen int, compressed bool, err error) { - env, err := op.clientEnveloper.decodeEnvelope(envBuf) +func (o *operation) processRequestEnvelope(envBuf envelopeBytes) (msgLen int, compressed bool, err error) { + env, err := o.clientEnveloper.decodeEnvelope(envBuf) if err != nil { return 0, false, malformedRequestError(err) } if env.trailer { return 0, false, malformedRequestError(fmt.Errorf("client stream cannot include status/trailer message")) } - if limit := op.methodConf.maxMsgBufferBytes; env.length > limit { + if limit := o.methodConf.maxMsgBufferBytes; env.length > limit { return 0, false, bufferLimitError(int64(limit)) } return int(env.length), env.compressed, nil } -func (op *operation) determineReadLimit() (limit int64, grow bool, makeError func(int64) error, err error) { - limit = int64(op.methodConf.maxMsgBufferBytes) - if op.contentLen == -1 { +func (o *operation) determineReadLimit() (limit int64, grow bool, makeError func(int64) error, err error) { + limit = int64(o.methodConf.maxMsgBufferBytes) + if o.contentLen == -1 { return limit, false, bufferLimitError, nil } - if op.contentLen > limit { + if o.contentLen > limit { // content-length header tells us that entity is too large err := bufferLimitError(limit) return 0, false, nil, err } - return op.contentLen, true, contentLengthError, nil + return o.contentLen, true, contentLengthError, nil } -func (op *operation) drainBody(body io.ReadCloser) { +func (o *operation) drainBody(body io.ReadCloser) { if wt, ok := body.(io.WriterTo); ok { _, _ = wt.WriteTo(io.Discard) return } - buf := op.bufferPool.Get() - defer op.bufferPool.Put(buf) + buf := o.bufferPool.Get() + defer o.bufferPool.Put(buf) b := buf.Bytes()[0:buf.Cap()] _, _ = io.CopyBuffer(io.Discard, body, b) } @@ -773,105 +772,105 @@ type envelopingReader struct { envRemain int } -func (er *envelopingReader) Read(data []byte) (n int, err error) { - if er.err != nil { - return 0, er.err +func (r *envelopingReader) Read(data []byte) (n int, err error) { + if r.err != nil { + return 0, r.err } - if er.current != nil { - bytesRead, err := er.current.Read(data) + if r.current != nil { + bytesRead, err := r.current.Read(data) isEOF := errors.Is(err, io.EOF) if bytesRead > 0 && (err == nil || isEOF) { return bytesRead, nil } if err != nil && !isEOF { - er.err = err + r.err = err return bytesRead, err } // otherwise EOF, fall through } - if err := er.prepareNext(); err != nil { - er.err = err + if err := r.prepareNext(); err != nil { + r.err = err return 0, err } - if len(data) < er.envRemain { - copy(data, er.env[envelopeLen-er.envRemain:]) - er.envRemain -= len(data) + if len(data) < r.envRemain { + copy(data, r.env[envelopeLen-r.envRemain:]) + r.envRemain -= len(data) return len(data), nil } var offset int - if er.envRemain > 0 { - copy(data, er.env[envelopeLen-er.envRemain:]) - offset = er.envRemain - er.envRemain = 0 + if r.envRemain > 0 { + copy(data, r.env[envelopeLen-r.envRemain:]) + offset = r.envRemain + r.envRemain = 0 } if len(data) > offset { - n, err = er.current.Read(data[offset:]) + n, err = r.current.Read(data[offset:]) } return offset + n, err } -func (er *envelopingReader) Close() error { - if er.mustReleaseCurrent { - buf, ok := er.current.(*bytes.Buffer) +func (r *envelopingReader) Close() error { + if r.mustReleaseCurrent { + buf, ok := r.current.(*bytes.Buffer) if ok { - er.rw.op.bufferPool.Put(buf) + r.rw.op.bufferPool.Put(buf) } - er.current = nil - er.mustReleaseCurrent = false + r.current = nil + r.mustReleaseCurrent = false } - er.err = errors.New("body is closed") - return er.r.Close() + r.err = errors.New("body is closed") + return r.r.Close() } -func (er *envelopingReader) prepareNext() error { +func (r *envelopingReader) prepareNext() error { var env envelope switch { - case er.rw.op.clientEnveloper == nil && er.rw.op.serverEnveloper == nil: + case r.rw.op.clientEnveloper == nil && r.rw.op.serverEnveloper == nil: // no envelopes to transform, just pass the body through w/ no change - er.current = er.r - er.envRemain = 0 + r.current = r.r + r.envRemain = 0 return nil - case er.rw.op.clientEnveloper == nil: - env.compressed = er.rw.op.client.reqCompression != nil - if er.rw.op.contentLen != -1 { - er.current = &hardLimitReader{r: er.r, rw: er.rw, limit: er.rw.op.contentLen, makeError: contentLengthError} - env.length = uint32(er.rw.op.contentLen) + case r.rw.op.clientEnveloper == nil: + env.compressed = r.rw.op.client.reqCompression != nil + if r.rw.op.contentLen != -1 { + r.current = &hardLimitReader{r: r.r, rw: r.rw, limit: r.rw.op.contentLen, makeError: contentLengthError} + env.length = uint32(r.rw.op.contentLen) } else { // Oof. We have to buffer entire request in order to measure it. - limit := int64(er.rw.op.methodConf.maxMsgBufferBytes) - buf := er.rw.op.bufferPool.Get() - _, err := io.Copy(buf, &hardLimitReader{r: er.r, rw: er.rw, limit: limit}) + limit := int64(r.rw.op.methodConf.maxMsgBufferBytes) + buf := r.rw.op.bufferPool.Get() + _, err := io.Copy(buf, &hardLimitReader{r: r.r, rw: r.rw, limit: limit}) if err != nil { - er.rw.op.bufferPool.Put(buf) - er.err = err + r.rw.op.bufferPool.Put(buf) + r.err = err return err } - er.current = buf - er.mustReleaseCurrent = true + r.current = buf + r.mustReleaseCurrent = true env.length = uint32(buf.Len()) } default: // clientEnveloper != nil var envBytes envelopeBytes - _, err := io.ReadFull(er.r, envBytes[:]) + _, err := io.ReadFull(r.r, envBytes[:]) if err != nil { return err } - env, err = er.rw.op.clientEnveloper.decodeEnvelope(envBytes) + env, err = r.rw.op.clientEnveloper.decodeEnvelope(envBytes) if err != nil { err = malformedRequestError(err) - er.rw.reportError(err) + r.rw.reportError(err) return err } - er.current = io.LimitReader(er.r, int64(env.length)) + r.current = io.LimitReader(r.r, int64(env.length)) } - if er.rw.op.serverEnveloper == nil { - er.envRemain = 0 + if r.rw.op.serverEnveloper == nil { + r.envRemain = 0 } else { - er.envRemain = envelopeLen - er.env = er.rw.op.serverEnveloper.encodeEnvelope(env) + r.envRemain = envelopeLen + r.env = r.rw.op.serverEnveloper.encodeEnvelope(env) } return nil } @@ -894,26 +893,26 @@ type transformingReader struct { envRemain int } -func (tr *transformingReader) Read(data []byte) (n int, err error) { - if tr.err != nil { - return 0, tr.err +func (r *transformingReader) Read(data []byte) (n int, err error) { + if r.err != nil { + return 0, r.err } for { - if len(data) < tr.envRemain { - copy(data, tr.env[envelopeLen-tr.envRemain:]) - tr.envRemain -= len(data) + if len(data) < r.envRemain { + copy(data, r.env[envelopeLen-r.envRemain:]) + r.envRemain -= len(data) return len(data), nil } var offset int - if tr.envRemain > 0 { - copy(data, tr.env[envelopeLen-tr.envRemain:]) - offset = tr.envRemain - tr.envRemain = 0 + if r.envRemain > 0 { + copy(data, r.env[envelopeLen-r.envRemain:]) + offset = r.envRemain + r.envRemain = 0 } var err error - if len(data) > offset && tr.buffer != nil { - n, err = tr.buffer.Read(data[offset:]) + if len(data) > offset && r.buffer != nil { + n, err = r.buffer.Read(data[offset:]) } if offset+n > 0 { return offset + n, err @@ -922,58 +921,58 @@ func (tr *transformingReader) Read(data []byte) (n int, err error) { // If we get here, there was nothing in tr.buffer to read, so // we need to prepare the next message and try again. - if err := tr.rw.op.readRequestMessage(tr.rw, tr.r, tr.msg); err != nil { + if err := r.rw.op.readRequestMessage(r.rw, r.r, r.msg); err != nil { // If this is the first request message, the error is EOF, and there's a body // preparer, we'll allow it and let the preparer produce a message from zero // request bytes. - if !tr.consumedFirst && errors.Is(err, io.EOF) && tr.rw.op.clientReqNeedsPrep { - tr.msg.markReady() + if !r.consumedFirst && errors.Is(err, io.EOF) && r.rw.op.clientReqNeedsPrep { + r.msg.markReady() } else { - tr.err = err + r.err = err return 0, err } } - if tr.rw.op.hooks.OnClientRequestMessage != nil { - if err := tr.msg.advanceToStage(tr.rw.op, stageDecoded); err != nil { - tr.err = err + if r.rw.op.hooks.OnClientRequestMessage != nil { + if err := r.msg.advanceToStage(r.rw.op, stageDecoded); err != nil { + r.err = err return 0, err } - compressed := tr.msg.wasCompressed && tr.rw.op.client.reqCompression != nil - if err := tr.rw.op.hooks.OnClientRequestMessage(tr.rw.op.request.Context(), tr.rw.op, tr.msg.msg, compressed, tr.msg.size); err != nil { - tr.rw.reportError(err) + compressed := r.msg.wasCompressed && r.rw.op.client.reqCompression != nil + if err := r.rw.op.hooks.OnClientRequestMessage(r.rw.op.request.Context(), r.rw.op, r.msg.msg, compressed, r.msg.size); err != nil { + r.rw.reportError(err) return 0, context.Canceled } } - if err := tr.prepareMessage(); err != nil { - tr.err = err + if err := r.prepareMessage(); err != nil { + r.err = err return 0, err } } } -func (tr *transformingReader) Close() error { - tr.err = errors.New("body is closed") - tr.msg.release(tr.rw.op.bufferPool) - return tr.r.Close() +func (r *transformingReader) Close() error { + r.err = errors.New("body is closed") + r.msg.release(r.rw.op.bufferPool) + return r.r.Close() } -func (tr *transformingReader) prepareMessage() error { - tr.consumedFirst = true - if err := tr.msg.advanceToStage(tr.rw.op, stageSend); err != nil { +func (r *transformingReader) prepareMessage() error { + r.consumedFirst = true + if err := r.msg.advanceToStage(r.rw.op, stageSend); err != nil { return err } - tr.buffer = tr.msg.sendBuffer() - if tr.rw.op.serverEnveloper == nil { - tr.envRemain = 0 + r.buffer = r.msg.sendBuffer() + if r.rw.op.serverEnveloper == nil { + r.envRemain = 0 return nil } // Need to prefix the buffer with an envelope env := envelope{ - compressed: tr.msg.wasCompressed && tr.rw.op.server.reqCompression != nil, - length: uint32(tr.buffer.Len()), + compressed: r.msg.wasCompressed && r.rw.op.server.reqCompression != nil, + length: uint32(r.buffer.Len()), } - tr.env = tr.rw.op.serverEnveloper.encodeEnvelope(env) - tr.envRemain = envelopeLen + r.env = r.rw.op.serverEnveloper.encodeEnvelope(env) + r.envRemain = envelopeLen return nil } @@ -1006,81 +1005,81 @@ type responseWriter struct { buf *bytes.Buffer } -func (rw *responseWriter) Header() http.Header { - return rw.delegate.Header() +func (w *responseWriter) Header() http.Header { + return w.delegate.Header() } -func (rw *responseWriter) Write(data []byte) (int, error) { - if !rw.headersWritten { - rw.WriteHeader(http.StatusOK) +func (w *responseWriter) Write(data []byte) (int, error) { + if !w.headersWritten { + w.WriteHeader(http.StatusOK) } - if rw.err != nil { - return 0, rw.err + if w.err != nil { + return 0, w.err } - return rw.w.Write(data) + return w.w.Write(data) } -func (rw *responseWriter) WriteHeader(statusCode int) { - if rw.headersWritten { +func (w *responseWriter) WriteHeader(statusCode int) { + if w.headersWritten { return } - rw.headersWritten = true - rw.code = statusCode + w.headersWritten = true + w.code = statusCode - if rw.endWritten { + if w.endWritten { // Nothing to do: we already sent RPC error to client. return } var err error - rw.contentLen, err = httpExtractContentLength(rw.Header()) + w.contentLen, err = httpExtractContentLength(w.Header()) if err != nil { - rw.reportError(err) + w.reportError(err) return } - rw.op.rspContentType = rw.Header().Get("Content-Type") - respMeta, processBody, err := rw.op.server.protocol.extractProtocolResponseHeaders(statusCode, rw.Header()) + w.op.rspContentType = w.Header().Get("Content-Type") + respMeta, processBody, err := w.op.server.protocol.extractProtocolResponseHeaders(statusCode, w.Header()) if err != nil { - rw.reportError(err) + w.reportError(err) return } // snapshot trailer keys - trailerKeys := parseMultiHeader(rw.Header().Values("Trailer")) + trailerKeys := parseMultiHeader(w.Header().Values("Trailer")) if len(trailerKeys) > 0 { respMeta.pendingTrailerKeys = make(headerKeys, len(trailerKeys)) for _, k := range trailerKeys { respMeta.pendingTrailerKeys.add(k) } - rw.Header().Del("Trailer") + w.Header().Del("Trailer") } // Remove other headers that might mess up the next leg - rw.Header().Del("Content-Encoding") - rw.Header().Del("Accept-Encoding") + w.Header().Del("Content-Encoding") + w.Header().Del("Accept-Encoding") - rw.respMeta = &respMeta + w.respMeta = &respMeta if respMeta.compression == CompressionIdentity { respMeta.compression = "" // normalize to empty string } if respMeta.compression != "" { - respCompression, ok := rw.op.compressionPools[respMeta.compression] + respCompression, ok := w.op.compressionPools[respMeta.compression] if !ok { - rw.reportError(fmt.Errorf("response indicates unsupported compression encoding %q", respMeta.compression)) + w.reportError(fmt.Errorf("response indicates unsupported compression encoding %q", respMeta.compression)) return } - rw.op.client.respCompression = respCompression - rw.op.server.respCompression = respCompression + w.op.client.respCompression = respCompression + w.op.server.respCompression = respCompression } - if respMeta.codec != "" && respMeta.codec != rw.op.server.codec.Name() && - !restHTTPBodyResponse(rw.op) { + if respMeta.codec != "" && respMeta.codec != w.op.server.codec.Name() && + !restHTTPBodyResponse(w.op) { // unexpected content-type for reply - rw.reportError(fmt.Errorf("response uses incorrect codec: expecting %q but instead got %q", rw.op.server.codec.Name(), respMeta.codec)) + w.reportError(fmt.Errorf("response uses incorrect codec: expecting %q but instead got %q", w.op.server.codec.Name(), respMeta.codec)) return } - if rw.op.hooks.OnServerResponseHeaders != nil { - if err := rw.op.hooks.OnServerResponseHeaders(rw.op.request.Context(), rw.op, statusCode, rw.Header()); err != nil { - rw.reportError(err) + if w.op.hooks.OnServerResponseHeaders != nil { + if err := w.op.hooks.OnServerResponseHeaders(w.op.request.Context(), w.op, statusCode, w.Header()); err != nil { + w.reportError(err) return } } @@ -1089,94 +1088,94 @@ func (rw *responseWriter) WriteHeader(statusCode int) { // RPC failed immediately. if processBody != nil { // We have to wait until we receive the body in order to process the error. - rw.w = &errorWriter{ - rw: rw, - respMeta: rw.respMeta, + w.w = &errorWriter{ + rw: w, + respMeta: w.respMeta, processBody: processBody, - buffer: rw.op.bufferPool.Get(), + buffer: w.op.bufferPool.Get(), } return } // We can send back error response immediately. - rw.flushHeaders() - rw.w = noResponseBodyWriter{} + w.flushHeaders() + w.w = noResponseBodyWriter{} return } - if rw.op.clientPreparer != nil { - rw.op.clientRespNeedsPrep = rw.op.clientPreparer.responseNeedsPrep(rw.op) + if w.op.clientPreparer != nil { + w.op.clientRespNeedsPrep = w.op.clientPreparer.responseNeedsPrep(w.op) } - if rw.op.serverPreparer != nil { - rw.op.serverRespNeedsPrep = rw.op.serverPreparer.responseNeedsPrep(rw.op) + if w.op.serverPreparer != nil { + w.op.serverRespNeedsPrep = w.op.serverPreparer.responseNeedsPrep(w.op) } - sameCodec := rw.op.client.codec.Name() == rw.op.server.codec.Name() + sameCodec := w.op.client.codec.Name() == w.op.server.codec.Name() // even if body encoding uses same content type, we can't treat them as the same // (which means re-using encoded data) if either side needs to prep the data first - sameResponseCodec := sameCodec && !rw.op.clientRespNeedsPrep && !rw.op.serverRespNeedsPrep - mustDecodeResponse := !sameResponseCodec || rw.op.hooks.OnServerResponseMessage != nil + sameResponseCodec := sameCodec && !w.op.clientRespNeedsPrep && !w.op.serverRespNeedsPrep + mustDecodeResponse := !sameResponseCodec || w.op.hooks.OnServerResponseMessage != nil respMsg := message{sameCompression: true, sameCodec: sameResponseCodec} if mustDecodeResponse { // We will have to decode and re-encode, so we need the message type. - messageType, err := rw.op.methodConf.resolver.FindMessageByName(rw.op.methodConf.descriptor.Output().FullName()) + messageType, err := w.op.methodConf.resolver.FindMessageByName(w.op.methodConf.descriptor.Output().FullName()) if err != nil { - rw.reportError(err) + w.reportError(err) return } respMsg.msg = messageType.New().Interface() } var endMustBeInHeaders bool - if mustBe, ok := rw.op.client.protocol.(clientProtocolEndMustBeInHeaders); ok { + if mustBe, ok := w.op.client.protocol.(clientProtocolEndMustBeInHeaders); ok { endMustBeInHeaders = mustBe.endMustBeInHeaders() } var delegate io.Writer if endMustBeInHeaders { // We must await the end before we can write headers, which means we have to // buffer the entire response. - rw.buf = rw.op.bufferPool.Get() - delegate = &limitWriter{buf: rw.buf, limit: rw.op.methodConf.maxMsgBufferBytes, rw: rw} + w.buf = w.op.bufferPool.Get() + delegate = &limitWriter{buf: w.buf, limit: w.op.methodConf.maxMsgBufferBytes, rw: w} } else { // We can go ahead and flush headers now. - rw.flushHeaders() - delegate = rw.delegate + w.flushHeaders() + delegate = w.delegate } // Now we can define the transformed response body. if sameResponseCodec && !mustDecodeResponse { // we do not need to decompress or decode - rw.w = &envelopingWriter{rw: rw, w: delegate} + w.w = &envelopingWriter{rw: w, w: delegate} } else { - rw.w = &transformingWriter{rw: rw, msg: &respMsg, w: delegate} + w.w = &transformingWriter{rw: w, msg: &respMsg, w: delegate} } } // Unwrap provides access to the underlying response writer. This plays nicely // with ResponseController functionality introduced in Go 1.21 without actually // depending on Go 1.21. -func (rw *responseWriter) Unwrap() http.ResponseWriter { - return rw.delegate +func (w *responseWriter) Unwrap() http.ResponseWriter { + return w.delegate } -func (rw *responseWriter) Flush() { +func (w *responseWriter) Flush() { // We expose this method so server can call it and won't panic // or blow-up when doing type conversion. But it's a no-op // since we automatically flush at message boundaries when // transforming the response body. } -func (rw *responseWriter) flushMessage() { - if rw.buf != nil { +func (w *responseWriter) flushMessage() { + if w.buf != nil { // we are buffering until we see trailers, so we don't // want to actually flush the underlying response writer yet return } - rw.flusher.Flush() + w.flusher.Flush() } -func (rw *responseWriter) reportError(err error) { +func (w *responseWriter) reportError(err error) { var end responseEnd if errors.As(err, &end.err) { end.httpCode = httpStatusCodeFromRPC(end.err.Code()) @@ -1185,110 +1184,110 @@ func (rw *responseWriter) reportError(err error) { end.err = connect.NewError(connect.CodeInternal, err) end.httpCode = http.StatusBadGateway } - rw.reportEnd(&end) + w.reportEnd(&end) } -func (rw *responseWriter) reportEnd(end *responseEnd) { - if rw.endWritten { +func (w *responseWriter) reportEnd(end *responseEnd) { + if w.endWritten { // It's possible this could be called in the event of a cascading error, // where various receivers all call reportEnd. We will only respect the // first such call and ignore the others. return } - if rw.respMeta != nil && len(rw.respMeta.pendingTrailers) > 0 && len(end.trailers) == 0 { + if w.respMeta != nil && len(w.respMeta.pendingTrailers) > 0 && len(end.trailers) == 0 { // add any pending trailers to the end - end.trailers = rw.respMeta.pendingTrailers + end.trailers = w.respMeta.pendingTrailers } switch { - case rw.headersFlushed: + case w.headersFlushed: // write error to body or trailers - rw.writeEnd(end, false) - case rw.respMeta != nil: - rw.respMeta.end = end - rw.flushHeaders() + w.writeEnd(end, false) + case w.respMeta != nil: + w.respMeta.end = end + w.flushHeaders() default: - rw.respMeta = &responseMeta{end: end} - rw.flushHeaders() + w.respMeta = &responseMeta{end: end} + w.flushHeaders() } - rw.flusher.Flush() + w.flusher.Flush() // response is done - rw.op.cancel() - rw.err = context.Canceled + w.op.cancel() + w.err = context.Canceled } -func (rw *responseWriter) flushHeaders() { - if rw.headersFlushed { +func (w *responseWriter) flushHeaders() { + if w.headersFlushed { return // already flushed } - cliRespMeta := *rw.respMeta - cliRespMeta.codec = rw.op.client.codec.Name() - cliRespMeta.compression = rw.op.client.respCompression.Name() - cliRespMeta.acceptCompression = intersect(rw.respMeta.acceptCompression, rw.op.canDecompress) - statusCode := rw.op.client.protocol.addProtocolResponseHeaders(cliRespMeta, rw.Header()) - hasErr := rw.respMeta.end != nil && rw.respMeta.end.err != nil + cliRespMeta := *w.respMeta + cliRespMeta.codec = w.op.client.codec.Name() + cliRespMeta.compression = w.op.client.respCompression.Name() + cliRespMeta.acceptCompression = intersect(w.respMeta.acceptCompression, w.op.canDecompress) + statusCode := w.op.client.protocol.addProtocolResponseHeaders(cliRespMeta, w.Header()) + hasErr := w.respMeta.end != nil && w.respMeta.end.err != nil // We only buffer full response for unary operations, so if we have an error, // we ignore anything already written to the buffer. - if rw.buf != nil && !hasErr { - rw.Header().Set("Content-Length", strconv.Itoa(rw.buf.Len())) + if w.buf != nil && !hasErr { + w.Header().Set("Content-Length", strconv.Itoa(w.buf.Len())) } // TODO: At this point, if the server was gRPC but the client is not, we may have "Trailer" // headers reserving the use of various metadata keys in trailers. It would be // cleaner if they were culled and only remained present for sneding to gRPC clients. - rw.delegate.WriteHeader(statusCode) - if rw.buf != nil { + w.delegate.WriteHeader(statusCode) + if w.buf != nil { if !hasErr { - _, _ = rw.buf.WriteTo(rw.delegate) + _, _ = w.buf.WriteTo(w.delegate) } - rw.op.bufferPool.Put(rw.buf) - rw.buf = nil + w.op.bufferPool.Put(w.buf) + w.buf = nil } - if rw.respMeta.end != nil { + if w.respMeta.end != nil { // response is done - rw.writeEnd(rw.respMeta.end, true) - rw.err = context.Canceled + w.writeEnd(w.respMeta.end, true) + w.err = context.Canceled } - rw.headersFlushed = true + w.headersFlushed = true } -func (rw *responseWriter) close() { - if !rw.headersWritten { +func (w *responseWriter) close() { + if !w.headersWritten { // treat as empty successful response - rw.WriteHeader(http.StatusOK) + w.WriteHeader(http.StatusOK) } - if rw.w != nil { - _ = rw.w.Close() + if w.w != nil { + _ = w.w.Close() } - if rw.endWritten { + if w.endWritten { return // all done } - if rw.respMeta.end != nil { + if w.respMeta.end != nil { // got end in headers - rw.reportEnd(rw.respMeta.end) + w.reportEnd(w.respMeta.end) return } // try to get end from trailers - trailer := httpExtractTrailers(rw.Header(), rw.respMeta.pendingTrailerKeys) - end, err := rw.op.server.protocol.extractEndFromTrailers(rw.op, trailer) + trailer := httpExtractTrailers(w.Header(), w.respMeta.pendingTrailerKeys) + end, err := w.op.server.protocol.extractEndFromTrailers(w.op, trailer) if err != nil { - rw.reportError(err) + w.reportError(err) return } - rw.reportEnd(&end) + w.reportEnd(&end) } -func (rw *responseWriter) writeEnd(end *responseEnd, wasInHeaders bool) { - if end.err == nil && rw.op.hooks.OnOperationFinish != nil { - rw.op.hooks.OnOperationFinish(rw.op.request.Context(), rw.op, end.trailers) - } else if end.err != nil && rw.op.hooks.OnOperationFail != nil { - if hookErr := rw.op.hooks.OnOperationFail(rw.op.request.Context(), rw.op, end.trailers, end.err); hookErr != nil { +func (w *responseWriter) writeEnd(end *responseEnd, wasInHeaders bool) { + if end.err == nil && w.op.hooks.OnOperationFinish != nil { + w.op.hooks.OnOperationFinish(w.op.request.Context(), w.op, end.trailers) + } else if end.err != nil && w.op.hooks.OnOperationFail != nil { + if hookErr := w.op.hooks.OnOperationFail(w.op.request.Context(), w.op, end.trailers, end.err); hookErr != nil { // report the error returned by the hook end.err = asConnectError(hookErr) } } - trailers := rw.op.client.protocol.encodeEnd(rw.op, end, rw.delegate, wasInHeaders) - httpMergeTrailers(rw.Header(), trailers) - rw.endWritten = true + trailers := w.op.client.protocol.encodeEnd(w.op, end, w.delegate, wasInHeaders) + httpMergeTrailers(w.Header(), trailers) + w.endWritten = true } // envelopingWriter will translate between envelope styles as data is @@ -1308,223 +1307,223 @@ type envelopingWriter struct { trailerIsCompressed bool } -func (ew *envelopingWriter) Write(data []byte) (int, error) { - ew.maybeInit() - if ew.err != nil { - return 0, ew.err +func (w *envelopingWriter) Write(data []byte) (int, error) { + w.maybeInit() + if w.err != nil { + return 0, w.err } - if ew.remainingBytes == -1 { - n, err := ew.current.Write(data) + if w.remainingBytes == -1 { + n, err := w.current.Write(data) if err != nil { - ew.err = err + w.err = err } return n, err } var written int for { - if ew.err != nil { - return written, ew.err + if w.err != nil { + return written, w.err } - if len(data) < ew.remainingBytes { + if len(data) < w.remainingBytes { // not enough data to trigger next action; ingest data and return - n, err := ew.writeBytes(data) - ew.remainingBytes -= n + n, err := w.writeBytes(data) + w.remainingBytes -= n written += n if err != nil { - ew.err = err + w.err = err } return written, err } // ingest remaining needed and trigger next action - n, err := ew.writeBytes(data[:ew.remainingBytes]) + n, err := w.writeBytes(data[:w.remainingBytes]) written += n - data = data[ew.remainingBytes:] - ew.remainingBytes -= n + data = data[w.remainingBytes:] + w.remainingBytes -= n if err != nil { - ew.err = err + w.err = err return written, err } - if ew.writingEnvelope { - if err := ew.handleEnvelopeWritten(); err != nil { + if w.writingEnvelope { + if err := w.handleEnvelopeWritten(); err != nil { return written, err } continue } - if ew.currentIsTrailer { - err := ew.handleTrailer() + if w.currentIsTrailer { + err := w.handleTrailer() if err != nil { return written, err } } else { // flush after each message and reset for next envelope - ew.rw.flushMessage() - ew.writingEnvelope = true - ew.remainingBytes = envelopeLen + w.rw.flushMessage() + w.writingEnvelope = true + w.remainingBytes = envelopeLen } } } -func (ew *envelopingWriter) writeBytes(data []byte) (int, error) { - if ew.writingEnvelope { - copy(ew.env[envelopeLen-ew.remainingBytes:], data) +func (w *envelopingWriter) writeBytes(data []byte) (int, error) { + if w.writingEnvelope { + copy(w.env[envelopeLen-w.remainingBytes:], data) return len(data), nil } - return ew.current.Write(data) + return w.current.Write(data) } -func (ew *envelopingWriter) handleEnvelopeWritten() error { - ew.writingEnvelope = false - env, err := ew.rw.op.serverEnveloper.decodeEnvelope(ew.env) +func (w *envelopingWriter) handleEnvelopeWritten() error { + w.writingEnvelope = false + env, err := w.rw.op.serverEnveloper.decodeEnvelope(w.env) if err != nil { err = malformedRequestError(err) - ew.rw.reportError(err) + w.rw.reportError(err) return err } if env.trailer { // buffer final message, so we can transform it to a responseEnd - if limit := ew.rw.op.methodConf.maxMsgBufferBytes; env.length > limit { + if limit := w.rw.op.methodConf.maxMsgBufferBytes; env.length > limit { err := bufferLimitError(int64(limit)) - ew.rw.reportError(err) + w.rw.reportError(err) return err } - buf := ew.rw.op.bufferPool.Get() + buf := w.rw.op.bufferPool.Get() buf.Grow(int(env.length)) - ew.current = buf - ew.mustReleaseCurrent = true - ew.currentIsTrailer = true - ew.trailerIsCompressed = env.compressed - ew.remainingBytes = int(env.length) + w.current = buf + w.mustReleaseCurrent = true + w.currentIsTrailer = true + w.trailerIsCompressed = env.compressed + w.remainingBytes = int(env.length) return nil } - if ew.rw.op.clientEnveloper != nil { - envBytes := ew.rw.op.clientEnveloper.encodeEnvelope(env) - _, err := ew.w.Write(envBytes[:]) + if w.rw.op.clientEnveloper != nil { + envBytes := w.rw.op.clientEnveloper.encodeEnvelope(env) + _, err := w.w.Write(envBytes[:]) if err != nil { - ew.err = err + w.err = err return err } } - ew.current = ew.w - ew.remainingBytes = int(env.length) + w.current = w.w + w.remainingBytes = int(env.length) return nil } -func (ew *envelopingWriter) Close() error { +func (w *envelopingWriter) Close() error { var buf *bytes.Buffer - if ew.mustReleaseCurrent { + if w.mustReleaseCurrent { var ok bool - buf, ok = ew.current.(*bytes.Buffer) + buf, ok = w.current.(*bytes.Buffer) if !ok { - lw, ok := ew.current.(*limitWriter) + lw, ok := w.current.(*limitWriter) if ok { buf = lw.buf } } if buf == nil { - return fmt.Errorf("current sink must be *limitWriter or *bytes.Buffer but instead is %T", ew.current) + return fmt.Errorf("current sink must be *limitWriter or *bytes.Buffer but instead is %T", w.current) } - defer ew.rw.op.bufferPool.Put(buf) + defer w.rw.op.bufferPool.Put(buf) } - if ew.remainingBytes == -1 && ew.mustReleaseCurrent && ew.err == nil { + if w.remainingBytes == -1 && w.mustReleaseCurrent && w.err == nil { // We were buffering in order to measure size and create envelope, // so do that now. - env := envelope{compressed: ew.rw.op.client.respCompression != nil, length: uint32(buf.Len())} - envBytes := ew.rw.op.clientEnveloper.encodeEnvelope(env) - _, err := ew.w.Write(envBytes[:]) + env := envelope{compressed: w.rw.op.client.respCompression != nil, length: uint32(buf.Len())} + envBytes := w.rw.op.clientEnveloper.encodeEnvelope(env) + _, err := w.w.Write(envBytes[:]) if err != nil { - ew.err = err + w.err = err return err } - _, err = buf.WriteTo(ew.w) + _, err = buf.WriteTo(w.w) if err != nil { - ew.err = err + w.err = err return err } } var normalEOF bool - if ew.writingEnvelope && ew.remainingBytes == envelopeLen { + if w.writingEnvelope && w.remainingBytes == envelopeLen { // We were looking for envelope of next message, but no next message in the stream normalEOF = true } - if ew.remainingBytes > 0 && !normalEOF { + if w.remainingBytes > 0 && !normalEOF { // Unfinished body! - if ew.writingEnvelope { - ew.rw.reportError(fmt.Errorf("handler only wrote %d out of %d bytes of message envelope", envelopeLen-ew.remainingBytes, envelopeLen)) + if w.writingEnvelope { + w.rw.reportError(fmt.Errorf("handler only wrote %d out of %d bytes of message envelope", envelopeLen-w.remainingBytes, envelopeLen)) } else { - ew.rw.reportError(fmt.Errorf("handler failed to write final %d bytes of message", ew.remainingBytes)) + w.rw.reportError(fmt.Errorf("handler failed to write final %d bytes of message", w.remainingBytes)) } } - ew.remainingBytes = 0 - ew.current = nil - ew.err = errors.New("body is closed") + w.remainingBytes = 0 + w.current = nil + w.err = errors.New("body is closed") return nil } -func (ew *envelopingWriter) maybeInit() { - if ew.initialized { +func (w *envelopingWriter) maybeInit() { + if w.initialized { return } - ew.initialized = true - if ew.rw.op.serverEnveloper != nil { - ew.writingEnvelope = true - ew.remainingBytes = envelopeLen + w.initialized = true + if w.rw.op.serverEnveloper != nil { + w.writingEnvelope = true + w.remainingBytes = envelopeLen return } - if ew.rw.op.clientEnveloper == nil { + if w.rw.op.clientEnveloper == nil { // just pass everything through - ew.remainingBytes = -1 - ew.current = ew.w + w.remainingBytes = -1 + w.current = w.w return } - if ew.rw.contentLen == -1 { + if w.rw.contentLen == -1 { // Oof, we have to buffer everything to measure the request size // to construct an envelope. - ew.remainingBytes = -1 - buf := ew.rw.op.bufferPool.Get() - ew.current = &limitWriter{buf: buf, limit: ew.rw.op.methodConf.maxMsgBufferBytes, rw: ew.rw} - ew.mustReleaseCurrent = true + w.remainingBytes = -1 + buf := w.rw.op.bufferPool.Get() + w.current = &limitWriter{buf: buf, limit: w.rw.op.methodConf.maxMsgBufferBytes, rw: w.rw} + w.mustReleaseCurrent = true return } // synthesize envelope var env envelope - env.compressed = ew.rw.op.client.respCompression != nil - env.length = uint32(ew.rw.contentLen) - envBytes := ew.rw.op.clientEnveloper.encodeEnvelope(envelope{}) - _, err := ew.w.Write(envBytes[:]) + env.compressed = w.rw.op.client.respCompression != nil + env.length = uint32(w.rw.contentLen) + envBytes := w.rw.op.clientEnveloper.encodeEnvelope(envelope{}) + _, err := w.w.Write(envBytes[:]) if err != nil { - ew.err = err + w.err = err return } - ew.current = ew.w - ew.remainingBytes = envelopeLen + w.current = w.w + w.remainingBytes = envelopeLen } -func (ew *envelopingWriter) handleTrailer() error { - data, ok := ew.current.(*bytes.Buffer) +func (w *envelopingWriter) handleTrailer() error { + data, ok := w.current.(*bytes.Buffer) if !ok { // should not be possible - return fmt.Errorf("trailer must be *limitWriter but instead is %T", ew.current) - } - defer ew.rw.op.bufferPool.Put(data) - ew.mustReleaseCurrent = false - if ew.trailerIsCompressed { - uncompressed := ew.rw.op.bufferPool.Get() - defer ew.rw.op.bufferPool.Put(uncompressed) - if err := ew.rw.op.server.respCompression.decompress(uncompressed, data); err != nil { + return fmt.Errorf("trailer must be *limitWriter but instead is %T", w.current) + } + defer w.rw.op.bufferPool.Put(data) + w.mustReleaseCurrent = false + if w.trailerIsCompressed { + uncompressed := w.rw.op.bufferPool.Get() + defer w.rw.op.bufferPool.Put(uncompressed) + if err := w.rw.op.server.respCompression.decompress(uncompressed, data); err != nil { return err } data = uncompressed } - end, err := ew.rw.op.serverEnveloper.decodeEndFromMessage(ew.rw.op, data) + end, err := w.rw.op.serverEnveloper.decodeEndFromMessage(w.rw.op, data) if err != nil { - ew.rw.reportError(err) + w.rw.reportError(err) return err } - end.wasCompressed = ew.trailerIsCompressed - ew.rw.reportEnd(&end) - ew.err = errors.New("final data already written") + end.wasCompressed = w.trailerIsCompressed + w.rw.reportEnd(&end) + w.err = errors.New("final data already written") return nil } @@ -1546,158 +1545,158 @@ type transformingWriter struct { latestEnvelope envelope } -func (tw *transformingWriter) Write(data []byte) (int, error) { - if tw.err != nil { - return 0, tw.err +func (w *transformingWriter) Write(data []byte) (int, error) { + if w.err != nil { + return 0, w.err } - if tw.buffer == nil { - tw.reset() + if w.buffer == nil { + w.reset() } - if tw.expectingBytes == -1 { - if limit := int64(tw.rw.op.methodConf.maxMsgBufferBytes); int64(len(data))+int64(tw.buffer.Len()) > limit { + if w.expectingBytes == -1 { + if limit := int64(w.rw.op.methodConf.maxMsgBufferBytes); int64(len(data))+int64(w.buffer.Len()) > limit { err := bufferLimitError(limit) - tw.rw.reportError(err) + w.rw.reportError(err) return 0, err } - return tw.buffer.Write(data) + return w.buffer.Write(data) } var written int // For enveloped protocols, it's possible that data contains // multiple messages, so we need to process in a loop. for { - if tw.err != nil { - return written, tw.err + if w.err != nil { + return written, w.err } - remainingBytes := tw.expectingBytes - tw.buffer.Len() + remainingBytes := w.expectingBytes - w.buffer.Len() if len(data) < remainingBytes { // not enough data to trigger next action; ingest data and return - tw.buffer.Write(data) + w.buffer.Write(data) written += len(data) break } // ingest remaining needed and trigger next action - tw.buffer.Write(data[:remainingBytes]) + w.buffer.Write(data[:remainingBytes]) written += remainingBytes data = data[remainingBytes:] - if tw.writingEnvelope { + if w.writingEnvelope { var envBytes envelopeBytes - _, _ = tw.buffer.Read(envBytes[:]) + _, _ = w.buffer.Read(envBytes[:]) var err error - tw.latestEnvelope, err = tw.rw.op.serverEnveloper.decodeEnvelope(envBytes) + w.latestEnvelope, err = w.rw.op.serverEnveloper.decodeEnvelope(envBytes) if err != nil { err = malformedRequestError(err) - tw.rw.reportError(err) + w.rw.reportError(err) return written, err } - if limit := tw.rw.op.methodConf.maxMsgBufferBytes; tw.latestEnvelope.length > limit { + if limit := w.rw.op.methodConf.maxMsgBufferBytes; w.latestEnvelope.length > limit { err = bufferLimitError(int64(limit)) - tw.rw.reportError(err) + w.rw.reportError(err) return written, err } - tw.buffer = tw.msg.reset(tw.rw.op.bufferPool, false, tw.latestEnvelope.compressed) - tw.buffer.Grow(int(tw.latestEnvelope.length)) - tw.expectingBytes = int(tw.latestEnvelope.length) - tw.writingEnvelope = false + w.buffer = w.msg.reset(w.rw.op.bufferPool, false, w.latestEnvelope.compressed) + w.buffer.Grow(int(w.latestEnvelope.length)) + w.expectingBytes = int(w.latestEnvelope.length) + w.writingEnvelope = false } else { - if err := tw.flushMessage(); err != nil { - tw.rw.reportError(err) + if err := w.flushMessage(); err != nil { + w.rw.reportError(err) return written, err } - tw.expectingBytes = envelopeLen - tw.writingEnvelope = true + w.expectingBytes = envelopeLen + w.writingEnvelope = true } } return written, nil } -func (tw *transformingWriter) Close() error { - if tw.expectingBytes == -1 { - if err := tw.flushMessage(); err != nil { - tw.rw.reportError(err) +func (w *transformingWriter) Close() error { + if w.expectingBytes == -1 { + if err := w.flushMessage(); err != nil { + w.rw.reportError(err) } - } else if tw.buffer != nil && tw.buffer.Len() > 0 { + } else if w.buffer != nil && w.buffer.Len() > 0 { // Unfinished body! - if tw.writingEnvelope { - tw.rw.reportError(fmt.Errorf("handler only wrote %d out of %d bytes of message envelope", tw.buffer.Len(), envelopeLen)) + if w.writingEnvelope { + w.rw.reportError(fmt.Errorf("handler only wrote %d out of %d bytes of message envelope", w.buffer.Len(), envelopeLen)) } else { - tw.rw.reportError(fmt.Errorf("handler only wrote %d out of %d bytes of message", tw.buffer.Len(), tw.expectingBytes)) + w.rw.reportError(fmt.Errorf("handler only wrote %d out of %d bytes of message", w.buffer.Len(), w.expectingBytes)) } } - tw.expectingBytes = 0 - tw.msg.release(tw.rw.op.bufferPool) - tw.buffer = nil - tw.err = errors.New("body is closed") + w.expectingBytes = 0 + w.msg.release(w.rw.op.bufferPool) + w.buffer = nil + w.err = errors.New("body is closed") return nil } -func (tw *transformingWriter) flushMessage() error { - if tw.latestEnvelope.trailer { - data := tw.buffer - if tw.latestEnvelope.compressed { - data = tw.rw.op.bufferPool.Get() - defer tw.rw.op.bufferPool.Put(data) - if err := tw.rw.op.server.respCompression.decompress(data, tw.buffer); err != nil { +func (w *transformingWriter) flushMessage() error { + if w.latestEnvelope.trailer { + data := w.buffer + if w.latestEnvelope.compressed { + data = w.rw.op.bufferPool.Get() + defer w.rw.op.bufferPool.Put(data) + if err := w.rw.op.server.respCompression.decompress(data, w.buffer); err != nil { return err } } - end, err := tw.rw.op.serverEnveloper.decodeEndFromMessage(tw.rw.op, data) + end, err := w.rw.op.serverEnveloper.decodeEndFromMessage(w.rw.op, data) if err != nil { - tw.rw.reportError(err) + w.rw.reportError(err) return err } - end.wasCompressed = tw.latestEnvelope.compressed - tw.rw.reportEnd(&end) - tw.err = errors.New("final data already written") + end.wasCompressed = w.latestEnvelope.compressed + w.rw.reportEnd(&end) + w.err = errors.New("final data already written") return nil } // We've finished reading the message, so we can manually set the stage - tw.msg.markReady() - if tw.rw.op.hooks.OnServerResponseMessage != nil { - if err := tw.msg.advanceToStage(tw.rw.op, stageDecoded); err != nil { + w.msg.markReady() + if w.rw.op.hooks.OnServerResponseMessage != nil { + if err := w.msg.advanceToStage(w.rw.op, stageDecoded); err != nil { return err } - compressed := tw.msg.wasCompressed && tw.rw.op.server.respCompression != nil - if err := tw.rw.op.hooks.OnServerResponseMessage(tw.rw.op.request.Context(), tw.rw.op, tw.msg.msg, compressed, tw.msg.size); err != nil { + compressed := w.msg.wasCompressed && w.rw.op.server.respCompression != nil + if err := w.rw.op.hooks.OnServerResponseMessage(w.rw.op.request.Context(), w.rw.op, w.msg.msg, compressed, w.msg.size); err != nil { return err } } - if err := tw.msg.advanceToStage(tw.rw.op, stageSend); err != nil { + if err := w.msg.advanceToStage(w.rw.op, stageSend); err != nil { return err } - buffer := tw.msg.sendBuffer() - if enveloper := tw.rw.op.clientEnveloper; enveloper != nil { + buffer := w.msg.sendBuffer() + if enveloper := w.rw.op.clientEnveloper; enveloper != nil { env := envelope{ - compressed: tw.msg.wasCompressed && tw.rw.op.client.respCompression != nil, + compressed: w.msg.wasCompressed && w.rw.op.client.respCompression != nil, length: uint32(buffer.Len()), } envBytes := enveloper.encodeEnvelope(env) - if _, err := tw.w.Write(envBytes[:]); err != nil { - tw.err = err + if _, err := w.w.Write(envBytes[:]); err != nil { + w.err = err return err } } - if _, err := buffer.WriteTo(tw.w); err != nil { - tw.err = err + if _, err := buffer.WriteTo(w.w); err != nil { + w.err = err return err } // flush after each message - tw.rw.flushMessage() + w.rw.flushMessage() - tw.reset() + w.reset() return nil } -func (tw *transformingWriter) reset() { - if tw.rw.op.serverEnveloper != nil { - tw.buffer = tw.msg.reset(tw.rw.op.bufferPool, false, false) - tw.expectingBytes = envelopeLen - tw.writingEnvelope = true +func (w *transformingWriter) reset() { + if w.rw.op.serverEnveloper != nil { + w.buffer = w.msg.reset(w.rw.op.bufferPool, false, false) + w.expectingBytes = envelopeLen + w.writingEnvelope = true } else { - isCompressed := tw.rw.respMeta.compression != "" - tw.buffer = tw.msg.reset(tw.rw.op.bufferPool, false, isCompressed) - tw.expectingBytes = -1 + isCompressed := w.rw.respMeta.compression != "" + w.buffer = w.msg.reset(w.rw.op.bufferPool, false, isCompressed) + w.expectingBytes = -1 } } @@ -1708,45 +1707,45 @@ type errorWriter struct { buffer *bytes.Buffer } -func (ew *errorWriter) Write(data []byte) (int, error) { - if ew.buffer == nil { +func (e *errorWriter) Write(data []byte) (int, error) { + if e.buffer == nil { return 0, errors.New("writer already closed") } - if limit := int64(ew.rw.op.methodConf.maxMsgBufferBytes); int64(len(data))+int64(ew.buffer.Len()) > limit { + if limit := int64(e.rw.op.methodConf.maxMsgBufferBytes); int64(len(data))+int64(e.buffer.Len()) > limit { err := bufferLimitError(limit) - ew.rw.reportError(err) + e.rw.reportError(err) return 0, err } - return ew.buffer.Write(data) + return e.buffer.Write(data) } -func (ew *errorWriter) Close() error { - if ew.respMeta.end == nil { - ew.respMeta.end = &responseEnd{} +func (e *errorWriter) Close() error { + if e.respMeta.end == nil { + e.respMeta.end = &responseEnd{} } - bufferPool := ew.rw.op.bufferPool - defer bufferPool.Put(ew.buffer) - body := ew.buffer - if compressPool := ew.rw.op.server.respCompression; compressPool != nil { + bufferPool := e.rw.op.bufferPool + defer bufferPool.Put(e.buffer) + body := e.buffer + if compressPool := e.rw.op.server.respCompression; compressPool != nil { uncompressed := bufferPool.Get() defer bufferPool.Put(uncompressed) if err := compressPool.decompress(uncompressed, body); err != nil { // can't really just return an error; we have to encode the // error into the RPC response, so we populate respMeta.end - if ew.respMeta.end.httpCode == 0 || ew.respMeta.end.httpCode == http.StatusOK { - ew.respMeta.end.httpCode = http.StatusInternalServerError + if e.respMeta.end.httpCode == 0 || e.respMeta.end.httpCode == http.StatusOK { + e.respMeta.end.httpCode = http.StatusInternalServerError } - ew.respMeta.end.err = connect.NewError(connect.CodeInternal, fmt.Errorf("failed to decompress body: %w", err)) + e.respMeta.end.err = connect.NewError(connect.CodeInternal, fmt.Errorf("failed to decompress body: %w", err)) body = nil } else { body = uncompressed } } if body != nil { - ew.processBody(ew.rw.op.server.codec, body, ew.respMeta.end) + e.processBody(e.rw.op.server.codec, body, e.respMeta.end) } - ew.rw.flushHeaders() - ew.buffer = nil + e.rw.flushHeaders() + e.buffer = nil return nil } @@ -2188,11 +2187,11 @@ type requestLine struct { method, path, queryString, httpVersion string } -func (line *requestLine) fromRequest(req *http.Request) { - line.method = req.Method - line.path = req.URL.Path - line.queryString = req.URL.RawQuery - line.httpVersion = req.Proto +func (l *requestLine) fromRequest(req *http.Request) { + l.method = req.Method + l.path = req.URL.Path + l.queryString = req.URL.RawQuery + l.httpVersion = req.Proto } func intersect(setA, setB []string) []string { diff --git a/handler_test.go b/handler_test.go index f037fb7..1f074b7 100644 --- a/handler_test.go +++ b/handler_test.go @@ -408,7 +408,6 @@ func TestHandler_Errors(t *testing.T) { } } -//nolint:dupl // some of these testStream literals are the same as in vanguard_rpcxrpc_test cases, but we don't need to share func TestHandler_PassThrough(t *testing.T) { t.Parallel() // These cases don't do any transformation and just pass through to the @@ -1056,7 +1055,7 @@ func checkStageRead(t *testing.T, msg *message, compressed bool) { func checkStageDecoded(t *testing.T, msg *message) { t.Helper() require.Equal(t, stageDecoded, msg.stage) - require.Equal(t, testDataString, msg.msg.(*wrapperspb.StringValue).Value) //nolint:forcetypeassert + require.Equal(t, testDataString, msg.msg.(*wrapperspb.StringValue).Value) // Should not be possible to go backwards. require.Error(t, msg.advanceToStage(nil, stageRead)) require.Error(t, msg.advanceToStage(nil, stageEmpty)) @@ -1094,13 +1093,13 @@ func (f *fakeCodec) Name() string { func (f *fakeCodec) MarshalAppend(b []byte, msg proto.Message) ([]byte, error) { f.marshalCalls++ - val := msg.(*wrapperspb.StringValue).Value //nolint:forcetypeassert + val := msg.(*wrapperspb.StringValue).Value return append(b, ([]byte)(val)...), nil } func (f *fakeCodec) Unmarshal(b []byte, msg proto.Message) error { f.unmarshalCalls++ - msg.(*wrapperspb.StringValue).Value = string(b) //nolint:forcetypeassert + msg.(*wrapperspb.StringValue).Value = string(b) return nil } diff --git a/internal/examples/connect+grpc/cmd/server/main.go b/internal/examples/connect+grpc/cmd/server/main.go index eb29666..6601625 100644 --- a/internal/examples/connect+grpc/cmd/server/main.go +++ b/internal/examples/connect+grpc/cmd/server/main.go @@ -33,13 +33,13 @@ import ( ) func main() { - svr := grpc.NewServer() - elizav1grpc.RegisterElizaServiceServer(svr, elizaImpl{}) + server := grpc.NewServer() + elizav1grpc.RegisterElizaServiceServer(server, elizaImpl{}) mux := vanguard.Mux{ Protocols: []vanguard.Protocol{vanguard.ProtocolGRPC}, Codecs: []string{vanguard.CodecProto}, } - err := mux.RegisterServiceByName(svr, "connectrpc.eliza.v1.ElizaService") + err := mux.RegisterServiceByName(server, "connectrpc.eliza.v1.ElizaService") if err != nil { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) diff --git a/internal/examples/fileserver/main.go b/internal/examples/fileserver/main.go index 468f1c6..34655d2 100644 --- a/internal/examples/fileserver/main.go +++ b/internal/examples/fileserver/main.go @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -//nolint:forbidigo,gosec,gochecknoglobals package main import ( @@ -65,14 +64,14 @@ var indexHTMLTemplate = template.Must(template.New("http").Parse(` {{.Title}} -
-{{- if ne .Title "."}}
-..
-{{- end}}
-{{- range $path, $name := .Files}}
-{{$name}}
-{{- end}}
-
+
+  {{- if ne .Title "."}}
+  ..
+  {{- end}}
+  {{- range $path, $name := .Files}}
+  {{$name}}
+  {{- end}}
+  
`)) diff --git a/params.go b/params.go index c757004..566a728 100644 --- a/params.go +++ b/params.go @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -//nolint:goconst package vanguard import ( @@ -275,7 +274,6 @@ func getParameter(msg protoreflect.Message, fields []protoreflect.FieldDescripto } func marshalFieldValue(field protoreflect.FieldDescriptor, value protoreflect.Value) ([]byte, error) { - //nolint:exhaustive switch kind := field.Kind(); kind { case protoreflect.BoolKind, protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind, diff --git a/protocol.go b/protocol.go index 8f8ad87..9ff93cd 100644 --- a/protocol.go +++ b/protocol.go @@ -377,12 +377,12 @@ type responseEnd struct { type headerKeys map[string]struct{} -func (keys headerKeys) add(k string) { - keys[textproto.CanonicalMIMEHeaderKey(k)] = struct{}{} +func (k headerKeys) add(key string) { + k[textproto.CanonicalMIMEHeaderKey(key)] = struct{}{} } -func (keys headerKeys) contains(k string) bool { - _, contains := keys[textproto.CanonicalMIMEHeaderKey(k)] +func (k headerKeys) contains(key string) bool { + _, contains := k[textproto.CanonicalMIMEHeaderKey(key)] return contains } diff --git a/protocol_connect.go b/protocol_connect.go index 85fa70a..6448b34 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -358,7 +358,7 @@ func (c connectUnaryServerProtocol) requestLine(op *operation, msg proto.Message vals.Set("encoding", op.server.codec.Name()) buf := op.bufferPool.Get() - stableMarshaler := op.server.codec.(StableCodec) //nolint:forcetypeassert,errcheck // c.useGet called above already checked this + stableMarshaler, _ := op.server.codec.(StableCodec) // c.useGet called above already checked this data, err := stableMarshaler.MarshalAppendStable(buf.Bytes(), msg) if err != nil { op.bufferPool.Put(buf) @@ -649,9 +649,9 @@ type connectWireError struct { Details []connectWireDetail `json:"details,omitempty"` } -func (err *connectWireError) toConnectError() *connect.Error { - cerr := connect.NewError(err.Code, errors.New(err.Message)) - for _, detail := range err.Details { +func (e *connectWireError) toConnectError() *connect.Error { + cerr := connect.NewError(e.Code, errors.New(e.Message)) + for _, detail := range e.Details { detailData, err := base64.RawStdEncoding.DecodeString(detail.Value) if err != nil { // seems a waste to fail or take other action here... diff --git a/protocol_grpc.go b/protocol_grpc.go index 125378e..7ae6015 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -20,13 +20,11 @@ import ( "errors" "fmt" "io" - "math" "net/http" "net/textproto" "strconv" "strings" "time" - "unicode/utf8" "connectrpc.com/connect" "google.golang.org/genproto/googleapis/rpc/status" @@ -34,34 +32,6 @@ import ( "google.golang.org/protobuf/types/known/anypb" ) -const ( - grpcTimeoutMaxHours = math.MaxInt64 / int64(time.Hour) - grpcMaxTimeoutChars = 8 -) - -var ( - //nolint:gochecknoglobals - grpcTimeoutUnits = []struct { - size time.Duration - char byte - }{ - {time.Nanosecond, 'n'}, - {time.Microsecond, 'u'}, - {time.Millisecond, 'm'}, - {time.Second, 'S'}, - {time.Minute, 'M'}, - {time.Hour, 'H'}, - } - //nolint:gochecknoglobals - grpcTimeoutUnitLookup = func() map[byte]time.Duration { - m := make(map[byte]time.Duration) - for _, pair := range grpcTimeoutUnits { - m[pair.char] = pair.size - } - return m - }() -) - // grpcClientProtocol implements the gRPC protocol for // processing RPCs received from the client. type grpcClientProtocol struct{} @@ -387,10 +357,8 @@ func grpcAddRequestMeta(contentTypePrefix string, meta requestMeta, headers http headers.Set("Grpc-Accept-Encoding", strings.Join(meta.acceptCompression, ", ")) } if meta.hasTimeout { - timeoutStr, ok := grpcEncodeTimeout(meta.timeout) - if ok { - headers.Set("Grpc-Timeout", timeoutStr) - } + timeoutStr := grpcEncodeTimeout(meta.timeout) + headers.Set("Grpc-Timeout", timeoutStr) } } @@ -460,63 +428,67 @@ func grpcStatusFromError(err *connect.Error) *status.Status { // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#responses // https://datatracker.ietf.org/doc/html/rfc3986#section-2.1 func grpcPercentEncode(msg string) string { + var hexCount int for i := 0; i < len(msg); i++ { - // Characters that need to be escaped are defined in gRPC's HTTP/2 spec. - // They're different from the generic set defined in RFC 3986. - if c := msg[i]; c < ' ' || c > '~' || c == '%' { - return grpcPercentEncodeSlow(msg, i) + if grpcShouldEscape(msg[i]) { + hexCount++ } } - return msg -} - -// msg needs some percent-escaping. Bytes before offset don't require -// percent-encoding, so they can be copied to the output as-is. -func grpcPercentEncodeSlow(msg string, offset int) string { + if hexCount == 0 { + return msg + } + // We need to escape some characters, so we'll need to allocate a new string. var out strings.Builder - out.Grow(2 * len(msg)) - out.WriteString(msg[:offset]) - for i := offset; i < len(msg); i++ { - c := msg[i] - if c < ' ' || c > '~' || c == '%' { - _, _ = fmt.Fprintf(&out, "%%%02X", c) - continue + out.Grow(len(msg) + 2*hexCount) + for i := 0; i < len(msg); i++ { + switch char := msg[i]; { + case grpcShouldEscape(char): + out.WriteByte('%') + out.WriteByte(upperhex[char>>4]) + out.WriteByte(upperhex[char&15]) + default: + out.WriteByte(char) } - out.WriteByte(c) } return out.String() } -func grpcPercentDecode(encoded string) string { - for i := 0; i < len(encoded); i++ { - if c := encoded[i]; c == '%' && i+2 < len(encoded) { - return grpcPercentDecodeSlow(encoded, i) +func grpcPercentDecode(input string) (string, error) { + percentCount := 0 + for i := 0; i < len(input); { + switch input[i] { + case '%': + percentCount++ + if err := validateHex(input[i:]); err != nil { + return "", err + } + i += 3 + default: + i++ } } - return encoded -} - -// Similar to percentEncodeSlow: encoded is percent-encoded, and needs to be -// decoded byte-by-byte starting at offset. -func grpcPercentDecodeSlow(encoded string, offset int) string { + if percentCount == 0 { + return input, nil + } + // We need to unescape some characters, so we'll need to allocate a new string. var out strings.Builder - out.Grow(len(encoded)) - out.WriteString(encoded[:offset]) - for i := offset; i < len(encoded); i++ { - c := encoded[i] - if c != '%' || i+2 >= len(encoded) { - out.WriteByte(c) - continue - } - parsed, err := strconv.ParseUint(encoded[i+1:i+3], 16 /* hex */, 8 /* bitsize */) - if err != nil { - out.WriteRune(utf8.RuneError) - } else { - out.WriteByte(byte(parsed)) + out.Grow(len(input) - 2*percentCount) + for i := 0; i < len(input); i++ { + switch input[i] { + case '%': + out.WriteByte(unhex(input[i+1])<<4 | unhex(input[i+2])) + i += 2 + default: + out.WriteByte(input[i]) } - i += 2 } - return out.String() + return out.String(), nil +} + +// Characters that need to be escaped are defined in gRPC's HTTP/2 spec. +// They're different from the generic set defined in RFC 3986. +func grpcShouldEscape(char byte) bool { + return char < ' ' || char > '~' || char == '%' } // The gRPC wire protocol specifies that errors should be serialized using the @@ -554,7 +526,13 @@ func grpcExtractErrorFromTrailer(trailers http.Header) *connect.Error { } if len(grpcDetails) == 0 { - message := grpcPercentDecode(grpcMsg) + message, err := grpcPercentDecode(grpcMsg) + if err != nil { + return connect.NewError( + connect.CodeInternal, + protocolError("invalid grpc-message trailer: %w", err), + ) + } return connect.NewWireError(connect.Code(code), errors.New(message)) } @@ -604,8 +582,8 @@ func grpcDecodeTimeout(timeout string) (time.Duration, error) { if timeout == "" { return 0, errNoTimeout } - unit, ok := grpcTimeoutUnitLookup[timeout[len(timeout)-1]] - if !ok { + unit := grpcTimeoutUnitLookup(timeout[len(timeout)-1]) + if unit == 0 { return 0, protocolError("timeout %q has invalid unit", timeout) } num, err := strconv.ParseInt(timeout[:len(timeout)-1], 10 /* base */, 64 /* bitsize */) @@ -615,6 +593,7 @@ func grpcDecodeTimeout(timeout string) (time.Duration, error) { if num > 99999999 { // timeout must be ASCII string of at most 8 digits return 0, protocolError("timeout %q is too long", timeout) } + const grpcTimeoutMaxHours = 8 if unit == time.Hour && num > grpcTimeoutMaxHours { // Timeout is effectively unbounded, so ignore it. The grpc-go // implementation does the same thing. @@ -623,17 +602,48 @@ func grpcDecodeTimeout(timeout string) (time.Duration, error) { return time.Duration(num) * unit, nil } -func grpcEncodeTimeout(timeout time.Duration) (string, bool) { +func grpcEncodeTimeout(timeout time.Duration) string { if timeout <= 0 { - return "0n", true + return "0n" } - for _, pair := range grpcTimeoutUnits { - digits := strconv.FormatInt(int64(timeout/pair.size), 10 /* base */) - if len(digits) < grpcMaxTimeoutChars { - return digits + string(pair.char), true - } + const grpcTimeoutMaxValue = 1e8 + var ( + size time.Duration + unit byte + ) + switch { + case timeout < time.Nanosecond*grpcTimeoutMaxValue: + size, unit = time.Nanosecond, 'n' + case timeout < time.Microsecond*grpcTimeoutMaxValue: + size, unit = time.Microsecond, 'u' + case timeout < time.Millisecond*grpcTimeoutMaxValue: + size, unit = time.Millisecond, 'm' + case timeout < time.Second*grpcTimeoutMaxValue: + size, unit = time.Second, 'S' + case timeout < time.Minute*grpcTimeoutMaxValue: + size, unit = time.Minute, 'M' + default: + size, unit = time.Hour, 'H' + } + value := timeout / size + return strconv.FormatInt(int64(value), 10 /* base */) + string(unit) +} + +func grpcTimeoutUnitLookup(unit byte) time.Duration { + switch unit { + case 'n': + return time.Nanosecond + case 'u': + return time.Microsecond + case 'm': + return time.Millisecond + case 'S': + return time.Second + case 'M': + return time.Minute + case 'H': + return time.Hour + default: + return 0 } - // The max time.Duration is smaller than the maximum expressible gRPC - // timeout, so we can't reach this case. - return "", false } diff --git a/protocol_grpc_test.go b/protocol_grpc_test.go index 61efad0..bed9d4d 100644 --- a/protocol_grpc_test.go +++ b/protocol_grpc_test.go @@ -67,8 +67,8 @@ func TestGRPCEncodeTimeoutQuick(t *testing.T) { t.Parallel() // Ensure that the error case is actually unreachable. encode := func(d time.Duration) bool { - _, ok := grpcEncodeTimeout(d) - return ok + v := grpcEncodeTimeout(d) + return v != "" } if err := quick.Check(encode, nil); err != nil { t.Error(err) @@ -82,7 +82,7 @@ func TestGRPCPercentEncodingQuick(t *testing.T) { return true } encoded := grpcPercentEncode(input) - decoded := grpcPercentDecode(encoded) + decoded, _ := grpcPercentDecode(encoded) return decoded == input } if err := quick.Check(roundtrip, nil /* config */); err != nil { @@ -104,7 +104,7 @@ func TestGRPCPercentEncoding(t *testing.T) { assert.True(t, utf8.ValidString(input), "input invalid UTF-8") encoded := grpcPercentEncode(input) t.Logf("%q encoded as %q", input, encoded) - decoded := grpcPercentDecode(encoded) + decoded, _ := grpcPercentDecode(encoded) assert.Equal(t, decoded, input) }) } @@ -137,13 +137,10 @@ func TestGRPCDecodeTimeout(t *testing.T) { func TestGRPCEncodeTimeout(t *testing.T) { t.Parallel() - timeout, ok := grpcEncodeTimeout(time.Hour + time.Second) - assert.True(t, ok) + timeout := grpcEncodeTimeout(time.Hour + time.Second) assert.Equal(t, timeout, "3601000m") - timeout, ok = grpcEncodeTimeout(time.Duration(math.MaxInt64)) - assert.True(t, ok) + timeout = grpcEncodeTimeout(time.Duration(math.MaxInt64)) assert.Equal(t, timeout, "2562047H") - timeout, ok = grpcEncodeTimeout(-1 * time.Hour) - assert.True(t, ok) + timeout = grpcEncodeTimeout(-1 * time.Hour) assert.Equal(t, timeout, "0n") } diff --git a/protocol_http.go b/protocol_http.go index 26018e8..4d2d639 100644 --- a/protocol_http.go +++ b/protocol_http.go @@ -181,8 +181,7 @@ func httpEncodePathValues(input protoreflect.Message, target *routeTarget) ( for i, part := range values { segmentIndex := variable.start + i if segmentIndex >= len(segments) { - //nolint:makezero - segments = append(segments, part) + segments = append(segments, part) // nozero continue } diff --git a/protocol_rest.go b/protocol_rest.go index e3a5485..c1864a1 100644 --- a/protocol_rest.go +++ b/protocol_rest.go @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -//nolint:revive // this is temporary, will be removed when implementation is complete package vanguard import ( @@ -314,7 +313,7 @@ func (r restServerProtocol) extractProtocolResponseHeaders(statusCode int, heade return meta, nil, nil } -func (r restServerProtocol) extractEndFromTrailers(o *operation, headers http.Header) (responseEnd, error) { +func (r restServerProtocol) extractEndFromTrailers(_ *operation, _ http.Header) (responseEnd, error) { return responseEnd{}, nil } @@ -382,7 +381,7 @@ func (r restServerProtocol) prepareUnmarshalledResponse(op *operation, src []byt return restCodec.UnmarshalField(src, msg.Interface(), leafField) } -func (r restServerProtocol) requiresMessageToProvideRequestLine(o *operation) bool { +func (r restServerProtocol) requiresMessageToProvideRequestLine(_ *operation) bool { return true } diff --git a/router.go b/router.go index 82ac7b2..fc66b45 100644 --- a/router.go +++ b/router.go @@ -39,7 +39,7 @@ type routeTrie struct { // HTTP rule. Only the rule itself is added. If the rule indicates additional // bindings, they are ignored. To add routes for all bindings, callers must // invoke this method for each rule. -func (trie *routeTrie) addRoute(config *methodConfig, rule *annotations.HttpRule) (*routeTarget, error) { +func (t *routeTrie) addRoute(config *methodConfig, rule *annotations.HttpRule) (*routeTarget, error) { var method, template string switch pattern := rule.Pattern.(type) { case *annotations.HttpRule_Get: @@ -71,39 +71,39 @@ func (trie *routeTrie) addRoute(config *methodConfig, rule *annotations.HttpRule if err != nil { return nil, err } - if err := trie.insert(method, target, segments); err != nil { + if err := t.insert(method, target, segments); err != nil { return nil, err } return target, nil } -func (trie *routeTrie) insertChild(segment string) *routeTrie { - child := trie.children[segment] +func (t *routeTrie) insertChild(segment string) *routeTrie { + child := t.children[segment] if child == nil { - if trie.children == nil { - trie.children = make(map[string]*routeTrie, 1) + if t.children == nil { + t.children = make(map[string]*routeTrie, 1) } child = &routeTrie{} - trie.children[segment] = child + t.children[segment] = child } return child } -func (trie *routeTrie) insertVerb(verb string) routeMethods { - methods := trie.verbs[verb] +func (t *routeTrie) insertVerb(verb string) routeMethods { + methods := t.verbs[verb] if methods == nil { - if trie.verbs == nil { - trie.verbs = make(map[string]routeMethods, 1) + if t.verbs == nil { + t.verbs = make(map[string]routeMethods, 1) } methods = make(routeMethods, 1) - trie.verbs[verb] = methods + t.verbs[verb] = methods } return methods } // insert the target into the trie using the given method and segment path. // The path is followed until the final segment is reached. -func (trie *routeTrie) insert(method string, target *routeTarget, segments pathSegments) error { - cursor := trie +func (t *routeTrie) insert(method string, target *routeTarget, segments pathSegments) error { + cursor := t for _, segment := range segments.path { cursor = cursor.insertChild(segment) } @@ -118,7 +118,7 @@ func (trie *routeTrie) insert(method string, target *routeTarget, segments pathS // match finds a route for the given request. If a match is found, the associated target and a map // of matched variable values is returned. -func (trie *routeTrie) match(uriPath, httpMethod string) (*routeTarget, []routeTargetVarMatch, routeMethods) { +func (t *routeTrie) match(uriPath, httpMethod string) (*routeTarget, []routeTargetVarMatch, routeMethods) { // TODO: Not checking if path ends with "/" means we accept missing final segment // for both * and **. Is that right? Makes sense for **, but maybe not for *. if len(uriPath) == 0 || uriPath[0] != '/' || uriPath[len(uriPath)-1] == ':' { @@ -142,7 +142,7 @@ func (trie *routeTrie) match(uriPath, httpMethod string) (*routeTarget, []routeT verb = lastElement[pos+1:] } } - target, methods := trie.findTarget(path, verb, httpMethod) + target, methods := t.findTarget(path, verb, httpMethod) if target == nil { return nil, nil, methods } @@ -162,21 +162,21 @@ func (trie *routeTrie) match(uriPath, httpMethod string) (*routeTarget, []routeT // is nil but methods are non-nil, the path and verb matched a route, but not // the method. This can be used to send back a well-formed "Allow" response // header. If both are nil, the path and verb did not match. -func (trie *routeTrie) findTarget(path []string, verb, method string) (*routeTarget, routeMethods) { +func (t *routeTrie) findTarget(path []string, verb, method string) (*routeTarget, routeMethods) { if len(path) == 0 { - return trie.getTarget(verb, method) + return t.getTarget(verb, method) } current := path[0] path = path[1:] - if child := trie.children[current]; child != nil { + if child := t.children[current]; child != nil { target, methods := child.findTarget(path, verb, method) if target != nil || methods != nil { return target, methods } } - if childAst := trie.children["*"]; childAst != nil { + if childAst := t.children["*"]; childAst != nil { target, methods := childAst.findTarget(path, verb, method) if target != nil || methods != nil { return target, methods @@ -185,7 +185,7 @@ func (trie *routeTrie) findTarget(path []string, verb, method string) (*routeTar // Double-asterisk must be the last element in pattern. // So it consumes all remaining path elements. - if childDblAst := trie.children["**"]; childDblAst != nil { + if childDblAst := t.children["**"]; childDblAst != nil { return childDblAst.findTarget(nil, verb, method) } return nil, nil @@ -194,8 +194,8 @@ func (trie *routeTrie) findTarget(path []string, verb, method string) (*routeTar // getTarget gets the target for the given verb and method from the // node trie. It is like findTarget, except that it does not use a // path to first descend into a sub-trie. -func (trie *routeTrie) getTarget(verb, method string) (*routeTarget, routeMethods) { - methods := trie.verbs[verb] +func (t *routeTrie) getTarget(verb, method string) (*routeTarget, routeMethods) { + methods := t.verbs[verb] if target := methods[method]; target != nil { return target, methods } diff --git a/router_test.go b/router_test.go index 99c50d7..6a32325 100644 --- a/router_test.go +++ b/router_test.go @@ -208,26 +208,23 @@ func BenchmarkTrieMatch(b *testing.B) { assert.Len(b, vars, 10) } -//nolint:gochecknoglobals -var routes = []string{ - "/foo/bar/baz/buzz", - "/foo/bar/{name}", - "/foo/bar/{name}/baz/{child}", - "/foo/bar/{name}/baz/{child.id}/buzz/{child.thing.id}", - "/foo/bar/*/{thing.id}/{cat=**}", - "/foo/bar/*/{thing.id}/{cat=**}:do", - "/foo/bar/*/{thing.id}/{cat=**}:cancel", - "/foo/bob/{book_id={author}/{isbn}/*}/details", - "/foo/blah/{longest_var={long_var.a={medium.a={short.aa}/*/{short.ab}/foo}/*}/{long_var.b={medium.b={short.ba}/*/{short.bb}/foo}/{last=**}}}:details", - "/foo%2Fbar/%2A/%2A%2a/{starstar=%2A%2a/**}:%2c", - "/trailing/**:slash", - "/verb", -} - func initTrie(tb testing.TB) *routeTrie { tb.Helper() var trie routeTrie - for _, route := range routes { + for _, route := range []string{ + "/foo/bar/baz/buzz", + "/foo/bar/{name}", + "/foo/bar/{name}/baz/{child}", + "/foo/bar/{name}/baz/{child.id}/buzz/{child.thing.id}", + "/foo/bar/*/{thing.id}/{cat=**}", + "/foo/bar/*/{thing.id}/{cat=**}:do", + "/foo/bar/*/{thing.id}/{cat=**}:cancel", + "/foo/bob/{book_id={author}/{isbn}/*}/details", + "/foo/blah/{longest_var={long_var.a={medium.a={short.aa}/*/{short.ab}/foo}/*}/{long_var.b={medium.b={short.ba}/*/{short.bb}/foo}/{last=**}}}:details", + "/foo%2Fbar/%2A/%2A%2a/{starstar=%2A%2a/**}:%2c", + "/trailing/**:slash", + "/verb", + } { segments, variables, err := parsePathTemplate(route) require.NoError(tb, err) diff --git a/vanguard.go b/vanguard.go index 47fa861..021bab9 100644 --- a/vanguard.go +++ b/vanguard.go @@ -216,13 +216,20 @@ func (m *Mux) RegisterService(handler http.Handler, serviceDesc protoreflect.Ser opt.apply(&svcOpts) } - svcOpts.protocols = computeSet(svcOpts.protocols, m.Protocols, defaultProtocols, false) + svcOpts.protocols = computeSet(svcOpts.protocols, m.Protocols, map[Protocol]struct{}{ + ProtocolConnect: {}, + ProtocolGRPC: {}, + ProtocolGRPCWeb: {}, + }, false) for protocol := range svcOpts.protocols { if protocol <= ProtocolUnknown || protocol > protocolMax { return fmt.Errorf("protocol %d is not a valid value", protocol) } } - svcOpts.codecNames = computeSet(svcOpts.codecNames, m.Codecs, defaultCodecs, false) + svcOpts.codecNames = computeSet(svcOpts.codecNames, m.Codecs, map[string]struct{}{ + CodecProto: {}, + CodecJSON: {}, + }, false) for codecName := range svcOpts.codecNames { if _, known := m.codecImpls[codecName]; !known { return fmt.Errorf("codec %s is not known; use mux.AddCodec to add known codecs first", codecName) @@ -236,7 +243,9 @@ func (m *Mux) RegisterService(handler http.Handler, serviceDesc protoreflect.Ser } } // empty is allowed here: non-nil but empty means do not send compressed data to handler - svcOpts.compressorNames = computeSet(svcOpts.compressorNames, m.Compressors, defaultCompressors, true) + svcOpts.compressorNames = computeSet(svcOpts.compressorNames, m.Compressors, map[string]struct{}{ + CompressionGzip: {}, + }, true) for compressorName := range svcOpts.compressorNames { if _, known := m.compressionPools[compressorName]; !known { return fmt.Errorf("compression algorithm %s is not known; use mux.AddCompression to add known algorithms first", compressorName) @@ -710,17 +719,6 @@ type TypeResolver interface { protoregistry.ExtensionTypeResolver } -//nolint:gochecknoglobals -var ( - defaultProtocols = map[Protocol]struct{}{ - ProtocolConnect: {}, - ProtocolGRPC: {}, - ProtocolGRPCWeb: {}, - } - defaultCodecs = map[string]struct{}{CodecProto: {}, CodecJSON: {}} - defaultCompressors = map[string]struct{}{CompressionGzip: {}} -) - type serviceOptionFunc func(*serviceOptions) func (f serviceOptionFunc) apply(opts *serviceOptions) { diff --git a/vanguard_examples_test.go b/vanguard_examples_test.go index a6505b5..828f270 100644 --- a/vanguard_examples_test.go +++ b/vanguard_examples_test.go @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -//nolint:gocritic package vanguard_test import ( @@ -60,13 +59,13 @@ func ExampleMux_connectToGRPC() { // Create the server. // (NB: This is a httptest.Server, but it could be any http.Server) - svr := httptest.NewUnstartedServer(mux.AsHandler()) - svr.EnableHTTP2 = true - svr.StartTLS() - defer svr.Close() + server := httptest.NewUnstartedServer(mux.AsHandler()) + server.EnableHTTP2 = true + server.StartTLS() + defer server.Close() // Create a connect client and call the service. - client := testv1connect.NewLibraryServiceClient(svr.Client(), svr.URL) + client := testv1connect.NewLibraryServiceClient(server.Client(), server.URL) // Call the service using Connect translated by the middleware to // gRPC. @@ -104,9 +103,9 @@ func ExampleMux_restToGRPC() { // Create the server. // (NB: This is a httptest.Server, but it could be any http.Server) - svr := httptest.NewServer(mux.AsHandler()) - defer svr.Close() - client := svr.Client() + server := httptest.NewServer(mux.AsHandler()) + defer server.Close() + client := server.Client() book := &testv1.Book{ Title: "2001: A Space Odyssey", @@ -121,7 +120,7 @@ func ExampleMux_restToGRPC() { // Create the POST request. req, _ := http.NewRequestWithContext( context.Background(), http.MethodPost, - svr.URL+"/v1/shelves/top/books", + server.URL+"/v1/shelves/top/books", bytes.NewReader(body), ) req.Header.Set("Content-Type", "application/json") @@ -164,11 +163,11 @@ func ExampleMux_connectToREST() { // Create the server. // (NB: This is a httptest.Server, but it could be any http.Server) - svr := httptest.NewServer(mux.AsHandler()) - defer svr.Close() + server := httptest.NewServer(mux.AsHandler()) + defer server.Close() // Create a connect client and call the service. - client := testv1connect.NewLibraryServiceClient(svr.Client(), svr.URL) + client := testv1connect.NewLibraryServiceClient(server.Client(), server.URL) rsp, err := client.GetBook( context.Background(), connect.NewRequest(&testv1.GetBookRequest{ diff --git a/vanguard_rpcxrest_test.go b/vanguard_rpcxrest_test.go index 891a0c9..5214210 100644 --- a/vanguard_rpcxrest_test.go +++ b/vanguard_rpcxrest_test.go @@ -98,7 +98,7 @@ func TestMux_RPCxREST(t *testing.T) { server.StartTLS() disableCompression(server) t.Cleanup(server.Close) - return testServer{name: name, svr: server} + return testServer{name: name, server: server} } servers := []testServer{} for _, compression := range compressions { @@ -106,9 +106,9 @@ func TestMux_RPCxREST(t *testing.T) { } type testOpt struct { - name string - svr *httptest.Server - opts []connect.ClientOption + name string + server *httptest.Server + opts []connect.ClientOption } testOpts := []testOpt{} for _, server := range servers { @@ -122,9 +122,9 @@ func TestMux_RPCxREST(t *testing.T) { copyOpts := make([]connect.ClientOption, len(opts)) copy(copyOpts, opts) testOpts = append(testOpts, testOpt{ - name: fmt.Sprintf("%s_%s_%s/%s", protocol, codec, compression, server.name), - svr: server.svr, - opts: copyOpts, + name: fmt.Sprintf("%s_%s_%s/%s", protocol, codec, compression, server.name), + server: server.server, + opts: copyOpts, }) } } @@ -411,10 +411,10 @@ func TestMux_RPCxREST(t *testing.T) { opts := opts clients := testClients{ libClient: testv1connect.NewLibraryServiceClient( - opts.svr.Client(), opts.svr.URL, opts.opts..., + opts.server.Client(), opts.server.URL, opts.opts..., ), contentClient: testv1connect.NewContentServiceClient( - opts.svr.Client(), opts.svr.URL, opts.opts..., + opts.server.Client(), opts.server.URL, opts.opts..., ), } t.Run(opts.name, func(t *testing.T) { diff --git a/vanguard_rpcxrpc_test.go b/vanguard_rpcxrpc_test.go index cd4ce48..ea8ce3a 100644 --- a/vanguard_rpcxrpc_test.go +++ b/vanguard_rpcxrpc_test.go @@ -32,7 +32,6 @@ import ( "google.golang.org/protobuf/types/known/emptypb" ) -//nolint:dupl // some of these testStream literals are the same as in handler_test cases, but we don't need to share func TestMux_RPCxRPC(t *testing.T) { t.Parallel() @@ -88,7 +87,7 @@ func TestMux_RPCxRPC(t *testing.T) { server.StartTLS() disableCompression(server) t.Cleanup(server.Close) - return testServer{name: name, svr: server} + return testServer{name: name, server: server} } var servers []testServer for _, protocol := range protocols { @@ -100,9 +99,9 @@ func TestMux_RPCxRPC(t *testing.T) { } type testOpt struct { - name string - svr *httptest.Server - opts []connect.ClientOption + name string + server *httptest.Server + opts []connect.ClientOption } var testOpts []testOpt for _, server := range servers { @@ -127,9 +126,9 @@ func TestMux_RPCxRPC(t *testing.T) { copyOpts := make([]connect.ClientOption, len(opts)) copy(copyOpts, opts) testOpts = append(testOpts, testOpt{ - name: fmt.Sprintf("%s%s_%s_%s/%s", protocol, suffix, codec, compression, server.name), - svr: server.svr, - opts: copyOpts, + name: fmt.Sprintf("%s%s_%s_%s/%s", protocol, suffix, codec, compression, server.name), + server: server.server, + opts: copyOpts, }) } } @@ -389,10 +388,10 @@ func TestMux_RPCxRPC(t *testing.T) { opts := opts clients := testClients{ libClient: testv1connect.NewLibraryServiceClient( - opts.svr.Client(), opts.svr.URL, opts.opts..., + opts.server.Client(), opts.server.URL, opts.opts..., ), contentClient: testv1connect.NewContentServiceClient( - opts.svr.Client(), opts.svr.URL, opts.opts..., + opts.server.Client(), opts.server.URL, opts.opts..., ), } t.Run(opts.name, func(t *testing.T) { diff --git a/vanguard_test.go b/vanguard_test.go index 93b26dc..0749e9b 100644 --- a/vanguard_test.go +++ b/vanguard_test.go @@ -42,7 +42,6 @@ import ( "google.golang.org/protobuf/types/known/emptypb" ) -//nolint:dupl // some of these testStream literals are the same as in other cases, but we don't need to share func TestMux_BufferTooLargeFails(t *testing.T) { t.Parallel() @@ -590,16 +589,16 @@ func TestMux_ConnectGetUsesPostIfRequestTooLarge(t *testing.T) { t.Cleanup(serverWithSvcOption.Close) testCases := []struct { - name string - svr *httptest.Server + name string + server *httptest.Server }{ { - name: "with_mux_setting", - svr: serverWithSetting, + name: "with_mux_setting", + server: serverWithSetting, }, { - name: "with_svc_option", - svr: serverWithSvcOption, + name: "with_svc_option", + server: serverWithSvcOption, }, } @@ -623,8 +622,8 @@ func TestMux_ConnectGetUsesPostIfRequestTooLarge(t *testing.T) { defer interceptor.del(t) client := testv1connect.NewLibraryServiceClient( - testCase.svr.Client(), - testCase.svr.URL, + testCase.server.Client(), + testCase.server.URL, connect.WithHTTPGet(), connect.WithHTTPGetMaxURLSize(512, false), connect.WithSendGzip(), @@ -653,7 +652,6 @@ func TestMux_ConnectGetUsesPostIfRequestTooLarge(t *testing.T) { } } -//nolint:dupl // some of these testStream literals are the same as in other cases, but we don't need to share func TestMux_MessageHooks(t *testing.T) { t.Parallel() // NB: These cases are identical to the pass-through cases, but should @@ -743,11 +741,11 @@ func TestMux_MessageHooks(t *testing.T) { } } - svrCases := []struct { + serverCases := []struct { name string reqHook bool respHook bool - svr *httptest.Server + server *httptest.Server }{ { name: "request_hook", @@ -763,10 +761,10 @@ func TestMux_MessageHooks(t *testing.T) { respHook: true, }, } - for i := range svrCases { - svrCase := &svrCases[i] + for i := range serverCases { + serverCase := &serverCases[i] mux := &Mux{ - HooksCallback: makeHooks(svrCase.reqHook, svrCase.respHook), + HooksCallback: makeHooks(serverCase.reqHook, serverCase.respHook), } require.NoError(t, mux.RegisterServiceByName(contentHandler, testv1connect.ContentServiceName)) handler := mux.AsHandler() @@ -782,7 +780,7 @@ func TestMux_MessageHooks(t *testing.T) { server.StartTLS() t.Cleanup(server.Close) - svrCase.svr = server + serverCase.server = server } ctx := context.Background() @@ -1068,18 +1066,18 @@ func TestMux_MessageHooks(t *testing.T) { testReq := testReq t.Run(testReq.name, func(t *testing.T) { t.Parallel() - for _, svrCase := range svrCases { - svrCase := svrCase - t.Run(svrCase.name, func(t *testing.T) { + for _, serverCase := range serverCases { + serverCase := serverCase + t.Run(serverCase.name, func(t *testing.T) { clientOptions := make([]connect.ClientOption, 0, 4) clientOptions = append(clientOptions, protocolCase.opts...) clientOptions = append(clientOptions, encodingCase.opts...) clientOptions = append(clientOptions, compressionCase.opts...) - client := testv1connect.NewContentServiceClient(svrCase.svr.Client(), svrCase.svr.URL, clientOptions...) + client := testv1connect.NewContentServiceClient(serverCase.server.Client(), serverCase.server.URL, clientOptions...) runRPCTestCase(t, &interceptor, client, testReq.invoke, testReq.stream) - if svrCase.reqHook { + if serverCase.reqHook { var reqs []proto.Message for _, msg := range testReq.stream.msgs { if msg.in != nil { @@ -1088,7 +1086,7 @@ func TestMux_MessageHooks(t *testing.T) { } checkHookResults(t, reqs, &reqMsgs) } - if svrCase.respHook { + if serverCase.respHook { var resps []proto.Message for _, msg := range testReq.stream.msgs { if msg.out != nil && msg.out.msg != nil { @@ -1109,7 +1107,6 @@ func TestMux_MessageHooks(t *testing.T) { } } -//nolint:dupl // some of these testStream literals are the same as in other cases, but we don't need to share func TestMux_HookOrder(t *testing.T) { t.Parallel() @@ -1123,7 +1120,7 @@ func TestMux_HookOrder(t *testing.T) { errorCases := []struct { name string failure hookKind - svr *httptest.Server + server *httptest.Server }{ { name: "normal", @@ -1214,7 +1211,7 @@ func TestMux_HookOrder(t *testing.T) { server.StartTLS() t.Cleanup(server.Close) - errorCases[i].svr = server + errorCases[i].server = server } ctx := context.Background() @@ -1510,7 +1507,7 @@ func TestMux_HookOrder(t *testing.T) { testReq := testReq t.Run(testReq.name, func(t *testing.T) { t.Parallel() - client := testv1connect.NewContentServiceClient(errorCase.svr.Client(), errorCase.svr.URL) + client := testv1connect.NewContentServiceClient(errorCase.server.Client(), errorCase.server.URL) awaitServer := interceptor.set(t, testReq.stream) defer interceptor.del(t) @@ -1723,38 +1720,38 @@ type ttStream struct { done chan struct{} } -func (str *ttStream) start() { +func (s *ttStream) start() { // Called from the interceptor when it starts handling the stream - str.started.Store(true) + s.started.Store(true) } -func (str *ttStream) finish(result error) { +func (s *ttStream) finish(result error) { // Called from the interceptor when it finishes handling the stream - str.result = result - close(str.done) + s.result = result + close(s.done) } -func (str *ttStream) await(t *testing.T, expectServerDone bool) (svrInvoked bool, svrErr error) { +func (s *ttStream) await(t *testing.T, expectServerDone bool) (serverInvoked bool, serverErr error) { t.Helper() // Called from test code to make sure server handler has completed. // Returns any error that the interceptor finished with. // Should only be called after the RPC appears to have completed in // the test client. - if !str.started.Load() { + if !s.started.Load() { // Interceptor never started, so nothing to wait for. return false, nil } if expectServerDone { select { - case <-str.done: - return true, str.result + case <-s.done: + return true, s.result default: t.Fatal("expecting server to already be done but it's not") } } select { - case <-str.done: - return true, str.result + case <-s.done: + return true, s.result case <-time.After(3 * time.Second): return true, fmt.Errorf("timeout: interceptor still did not finish after 3 seconds") } @@ -1764,8 +1761,8 @@ type testInterceptor struct { sync.Map } -func (ti *testInterceptor) get(testName string) (*ttStream, bool) { - val, ok := ti.Load(testName) +func (i *testInterceptor) get(testName string) (*ttStream, bool) { + val, ok := i.Load(testName) if !ok { return nil, false } @@ -1773,26 +1770,26 @@ func (ti *testInterceptor) get(testName string) (*ttStream, bool) { return stream, ok } -func (ti *testInterceptor) set(t *testing.T, stream testStream) func(*testing.T, bool) (bool, error) { +func (i *testInterceptor) set(t *testing.T, stream testStream) func(*testing.T, bool) (bool, error) { t.Helper() str := &ttStream{ T: t, testStream: stream, done: make(chan struct{}), } - ti.Store(t.Name(), str) + i.Store(t.Name(), str) // The returned function can be used by test code to await server completion. // (Useful in the event that middleware cancels the operation early, so client // could see a completed response while server still running concurrently.) return str.await } -func (ti *testInterceptor) del(t *testing.T) { +func (i *testInterceptor) del(t *testing.T) { t.Helper() - ti.Delete(t.Name()) + i.Delete(t.Name()) } -func (ti *testInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { +func (i *testInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return func( ctx context.Context, req connect.AnyRequest, @@ -1801,7 +1798,7 @@ func (ti *testInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { if val == "" { return next(ctx, req) } - stream, ok := ti.get(val) + stream, ok := i.get(val) if !ok { return nil, fmt.Errorf("invalid testCase header: %s", val) } @@ -1870,11 +1867,11 @@ func (ti *testInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { } } -func (ti *testInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { +func (i *testInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { return next } -func (ti *testInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { +func (i *testInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { return func( ctx context.Context, conn connect.StreamingHandlerConn, @@ -1883,7 +1880,7 @@ func (ti *testInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFun if val == "" { return next(ctx, conn) } - stream, ok := ti.get(val) + stream, ok := i.get(val) if !ok { return fmt.Errorf("invalid testCase header: %s", val) } @@ -1948,7 +1945,7 @@ func (ti *testInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFun } } -func (ti *testInterceptor) restUnaryHandler( +func (i *testInterceptor) restUnaryHandler( codec Codec, comp *compressionPool, ) http.HandlerFunc { codecNames := map[string]string{ @@ -2014,8 +2011,7 @@ func (ti *testInterceptor) restUnaryHandler( // Write error, if any. if out.err != nil { httpWriteError(rsp, out.err) - //nolint:nilerr - return nil + return nil // ignore } // Write body. @@ -2058,7 +2054,7 @@ func (ti *testInterceptor) restUnaryHandler( http.Error(rsp, "missing test header", http.StatusInternalServerError) return } - stream, ok := ti.get(val) + stream, ok := i.get(val) if !ok { http.Error(rsp, "invalid test header", http.StatusInternalServerError) return @@ -2105,8 +2101,8 @@ func getDecompressor(t *testing.T, name string) connect.Decompressor { } type testServer struct { - name string - svr *httptest.Server + name string + server *httptest.Server } func appendClientProtocolOptions(t *testing.T, opts []connect.ClientOption, protocol Protocol) []connect.ClientOption { @@ -2197,7 +2193,6 @@ func outputFromUnary[Req, Resp any]( return headers, nil, nil, err } msg := any(resp.Msg) - //nolint:forcetypeassert return resp.Header(), []proto.Message{msg.(proto.Message)}, resp.Trailer(), nil } @@ -2222,7 +2217,6 @@ func outputFromServerStream[Req, Resp any]( var msgs []proto.Message for str.Receive() { msg := any(str.Msg()) - //nolint:forcetypeassert msgs = append(msgs, msg.(proto.Message)) } return str.ResponseHeader(), msgs, str.ResponseTrailer(), str.Err() @@ -2239,7 +2233,6 @@ func outputFromClientStream[Req, Resp any]( str.RequestHeader()[k] = v } for _, msg := range reqs { - //nolint:forcetypeassert if str.Send(any(msg).(*Req)) != nil { // we don't need this error; we'll get the error below // since str.CloseAndReceive returns the actual RPC errors @@ -2255,7 +2248,6 @@ func outputFromClientStream[Req, Resp any]( return headers, nil, nil, err } msg := any(resp.Msg) - //nolint:forcetypeassert return resp.Header(), []proto.Message{msg.(proto.Message)}, resp.Trailer(), nil } @@ -2284,7 +2276,6 @@ func outputFromBidiStream[Req, Resp any]( return } msg := any(resp) - //nolint:forcetypeassert msgs = append(msgs, msg.(proto.Message)) } }() @@ -2293,7 +2284,6 @@ func outputFromBidiStream[Req, Resp any]( str.RequestHeader()[k] = v } for _, msg := range reqs { - //nolint:forcetypeassert if str.Send(any(msg).(*Req)) != nil { // we don't need this error; we'll get the error from above // goroutine since str.Receive returns the actual RPC errors @@ -2403,7 +2393,7 @@ func runRPCTestCase[Client any]( break } } - svrInvoked, svrErr := awaitServer(t, expectServerDone) + serverInvoked, serverErr := awaitServer(t, expectServerDone) // Verify the error received by the client. receivedErr := expectedErr if stream.err != nil { @@ -2417,19 +2407,19 @@ func runRPCTestCase[Client any]( // Also check the error observed by the server. switch { case expectedErr == nil: - assert.NoError(t, svrErr) + assert.NoError(t, serverErr) case expectServerCancel: - if svrInvoked && svrErr != nil { + if serverInvoked && serverErr != nil { // We expect the server to either have seen the same error or it later // observed a cancel error (since the middleware cancels the request // after it aborts the operation). - if connect.CodeOf(svrErr) != connect.CodeOf(expectedErr) && !errors.Is(svrErr, context.Canceled) { - assert.Equal(t, connect.CodeCanceled, connect.CodeOf(svrErr)) + if connect.CodeOf(serverErr) != connect.CodeOf(expectedErr) && !errors.Is(serverErr, context.Canceled) { + assert.Equal(t, connect.CodeCanceled, connect.CodeOf(serverErr)) } } default: - assert.Error(t, svrErr) - assert.Equal(t, expectedErr.Code(), connect.CodeOf(svrErr)) + assert.Error(t, serverErr) + assert.Equal(t, expectedErr.Code(), connect.CodeOf(serverErr)) } assert.Subset(t, headers, stream.rspHeader) if stream.err == nil { @@ -2454,8 +2444,8 @@ func runRPCTestCase[Client any]( } } -func disableCompression(svr *httptest.Server) { - transport := svr.Client().Transport.(*http.Transport) //nolint:errcheck,forcetypeassert +func disableCompression(server *httptest.Server) { + transport, _ := server.Client().Transport.(*http.Transport) transport.DisableCompression = true } @@ -2564,7 +2554,8 @@ func (h *testHooks) getEvents(t *testing.T) (Operation, []hookKind) { for op, kinds := range ops { return op, kinds } - panic("should not be able to get here") //nolint:forbidigo + t.Fatal("unreachable") + return nil, nil } func newConnectError(code connect.Code, msg string) *connect.Error {