diff --git a/README.md b/README.md index 2fd791c..02d1b0f 100644 --- a/README.md +++ b/README.md @@ -165,12 +165,12 @@ Finally, you can register the middleware handler with an `http.Server` or ```go // The Mux can be used as the sole handler for an HTTP server. -err := http.Serve(listener, vanguardMux.AsHandler()) +err := http.Serve(listener, vanguardMux) // Or it can be used alongside other handlers, all registered with // the same http.ServeMux. mux := http.NewServeMux() -mux.Handle("/", vanguardMux.AsHandler()) +mux.Handle("/", vanguardMux) err := http.Serve(listener, mux) ``` The above example registers the handler for the root path. This is useful diff --git a/buffer_pool.go b/buffer_pool.go index abd8e19..d9374ee 100644 --- a/buffer_pool.go +++ b/buffer_pool.go @@ -28,10 +28,6 @@ type bufferPool struct { sync.Pool } -func newBufferPool() *bufferPool { - return &bufferPool{} -} - func (b *bufferPool) Get() *bytes.Buffer { if buffer, ok := b.Pool.Get().(*bytes.Buffer); ok { buffer.Reset() diff --git a/codec.go b/codec.go index 17e5feb..52e23a0 100644 --- a/codec.go +++ b/codec.go @@ -85,7 +85,9 @@ type RESTCodec interface { // The returned codec implements StableCodec, in addition to // Codec. func DefaultProtoCodec(res TypeResolver) Codec { - return &protoCodec{Resolver: res} + return protoCodec{ + UnmarshalOptions: proto.UnmarshalOptions{Resolver: res}, + } } // DefaultJSONCodec is the default codec factory used for the codec named @@ -110,22 +112,22 @@ type JSONCodec struct { UnmarshalOptions protojson.UnmarshalOptions } -var _ StableCodec = (*JSONCodec)(nil) -var _ RESTCodec = (*JSONCodec)(nil) +var _ StableCodec = JSONCodec{} +var _ RESTCodec = JSONCodec{} -func (j *JSONCodec) Name() string { +func (j JSONCodec) Name() string { return CodecJSON } -func (j *JSONCodec) IsBinary() bool { +func (j JSONCodec) IsBinary() bool { return false } -func (j *JSONCodec) MarshalAppend(base []byte, msg proto.Message) ([]byte, error) { +func (j JSONCodec) MarshalAppend(base []byte, msg proto.Message) ([]byte, error) { return j.MarshalOptions.MarshalAppend(base, msg) } -func (j *JSONCodec) MarshalAppendStable(base []byte, msg proto.Message) ([]byte, error) { +func (j JSONCodec) MarshalAppendStable(base []byte, msg proto.Message) ([]byte, error) { data, err := j.MarshalOptions.MarshalAppend(base, msg) if err != nil { return nil, err @@ -133,7 +135,7 @@ func (j *JSONCodec) MarshalAppendStable(base []byte, msg proto.Message) ([]byte, return jsonStabilize(data) } -func (j *JSONCodec) MarshalAppendField(base []byte, msg proto.Message, field protoreflect.FieldDescriptor) ([]byte, error) { +func (j JSONCodec) MarshalAppendField(base []byte, msg proto.Message, field protoreflect.FieldDescriptor) ([]byte, error) { if field.Message() != nil && field.Cardinality() != protoreflect.Repeated { return j.MarshalAppend(base, msg.ProtoReflect().Get(field).Message().Interface()) } @@ -191,7 +193,7 @@ func (j *JSONCodec) MarshalAppendField(base []byte, msg proto.Message, field pro return nil, fmt.Errorf("JSON does not contain key %s", fieldName) } -func (j *JSONCodec) UnmarshalField(data []byte, msg proto.Message, field protoreflect.FieldDescriptor) error { +func (j JSONCodec) UnmarshalField(data []byte, msg proto.Message, field protoreflect.FieldDescriptor) error { if field.Message() != nil && field.Cardinality() != protoreflect.Repeated { return j.Unmarshal(data, msg.ProtoReflect().Mutable(field).Message().Interface()) } @@ -211,11 +213,11 @@ func (j *JSONCodec) UnmarshalField(data []byte, msg proto.Message, field protore return j.Unmarshal(buf.Bytes(), msg) } -func (j *JSONCodec) Unmarshal(bytes []byte, msg proto.Message) error { +func (j JSONCodec) Unmarshal(bytes []byte, msg proto.Message) error { return j.UnmarshalOptions.Unmarshal(bytes, msg) } -func (j *JSONCodec) fieldName(field protoreflect.FieldDescriptor) string { +func (j JSONCodec) fieldName(field protoreflect.FieldDescriptor) string { if !j.MarshalOptions.UseProtoNames { return field.JSONName() } @@ -226,28 +228,33 @@ func (j *JSONCodec) fieldName(field protoreflect.FieldDescriptor) string { return string(field.Name()) } -type protoCodec proto.UnmarshalOptions +type protoCodec struct { + proto.MarshalOptions + proto.UnmarshalOptions +} -var _ StableCodec = (*protoCodec)(nil) +var _ StableCodec = protoCodec{} -func (p *protoCodec) Name() string { +func (p protoCodec) Name() string { return CodecProto } -func (p *protoCodec) IsBinary() bool { +func (p protoCodec) IsBinary() bool { return true } -func (p *protoCodec) MarshalAppend(base []byte, msg proto.Message) ([]byte, error) { +func (p protoCodec) MarshalAppend(base []byte, msg proto.Message) ([]byte, error) { return proto.MarshalOptions{}.MarshalAppend(base, msg) } -func (p *protoCodec) MarshalAppendStable(base []byte, msg proto.Message) ([]byte, error) { - return proto.MarshalOptions{Deterministic: true}.MarshalAppend(base, msg) +func (p protoCodec) MarshalAppendStable(base []byte, msg proto.Message) ([]byte, error) { + opts := p.MarshalOptions + opts.Deterministic = true + return opts.MarshalAppend(base, msg) } -func (p *protoCodec) Unmarshal(bytes []byte, msg proto.Message) error { - return (*proto.UnmarshalOptions)(p).Unmarshal(bytes, msg) +func (p protoCodec) Unmarshal(bytes []byte, msg proto.Message) error { + return p.UnmarshalOptions.Unmarshal(bytes, msg) } func jsonStabilize(data []byte) ([]byte, error) { @@ -260,3 +267,16 @@ func jsonStabilize(data []byte) ([]byte, error) { } return buf.Bytes(), nil } + +type codecMap map[string]func(TypeResolver) Codec + +func (m codecMap) get(name string, resolver TypeResolver) Codec { + if m == nil { + return nil + } + codecFn, ok := m[name] + if !ok { + return nil + } + return codecFn(resolver) +} diff --git a/compression.go b/compression.go index 74b6a62..9eb5c34 100644 --- a/compression.go +++ b/compression.go @@ -35,6 +35,27 @@ func DefaultGzipDecompressor() connect.Decompressor { return &gzip.Reader{} } +type compressionMap map[string]*compressionPool + +func (m compressionMap) intersection(names []string) []string { + length := len(names) + if len(m) < length { + length = len(m) + } + if length == 0 { + // If either set is empty, the intersection is empty. + // We don't use nil since it is used in places as a sentinel. + return make([]string, 0) + } + intersection := make([]string, 0, length) + for _, name := range names { + if _, ok := m[name]; ok { + intersection = append(intersection, name) + } + } + return intersection +} + type compressionPool struct { name string decompressors sync.Pool diff --git a/handler.go b/handler.go index a75efe2..5ac63f3 100644 --- a/handler.go +++ b/handler.go @@ -31,23 +31,17 @@ import ( "google.golang.org/protobuf/reflect/protoreflect" ) -type handler struct { - mux *Mux - bufferPool *bufferPool - codecs map[codecKey]Codec - canDecompress []string -} - -func (h *handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - op := h.newOperation(writer, request) - err := op.validate(h.mux, h.codecs) +// ServeHTTP implements http.Handler. +func (m *Mux) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + op := m.newOperation(writer, request) + err := op.validate(m, m.codecs) - useUnknownHandler := h.mux.UnknownHandler != nil && errors.Is(err, errNotFound) + useUnknownHandler := m.UnknownHandler != nil && errors.Is(err, errNotFound) var callback func(context.Context, Operation) (Hooks, error) if op.methodConf != nil { callback = op.methodConf.hooksCallback } else { - callback = h.mux.HooksCallback + callback = m.HooksCallback } if callback != nil { var hookErr error @@ -58,7 +52,7 @@ func (h *handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { } if useUnknownHandler { request.Header = op.originalHeaders // restore headers, just in case initialization removed keys - h.mux.UnknownHandler.ServeHTTP(writer, request) + m.UnknownHandler.ServeHTTP(writer, request) return } @@ -81,16 +75,15 @@ func (h *handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { op.handle() } -func (h *handler) newOperation(writer http.ResponseWriter, request *http.Request) *operation { +func (m *Mux) newOperation(writer http.ResponseWriter, request *http.Request) *operation { ctx, cancel := context.WithCancel(request.Context()) request = request.WithContext(ctx) op := &operation{ - writer: writer, - request: request, - cancel: cancel, - bufferPool: h.bufferPool, - canDecompress: h.canDecompress, - compressionPools: h.mux.compressionPools, + writer: writer, + request: request, + cancel: cancel, + bufferPool: &m.bufferPool, + compressors: m.compressors, } op.requestLine.fromRequest(request) return op @@ -235,33 +228,14 @@ func classifyRequest(req *http.Request) (clientProtocolHandler, url.Values) { } } -type codecKey struct { - res TypeResolver - name string -} - -func newCodecMap(methodConfigs map[string]*methodConfig, codecs map[string]func(TypeResolver) Codec) map[codecKey]Codec { - result := make(map[codecKey]Codec, len(codecs)) - for _, conf := range methodConfigs { - for codecName, codecFactory := range codecs { - key := codecKey{res: conf.resolver, name: codecName} - if _, exists := result[key]; !exists { - result[key] = codecFactory(conf.resolver) - } - } - } - return result -} - // operation represents a single HTTP operation, which maps to an incoming HTTP request. // It tracks properties needed to implement protocol transformation. type operation struct { - writer http.ResponseWriter - request *http.Request - cancel context.CancelFunc - bufferPool *bufferPool - canDecompress []string - compressionPools map[string]*compressionPool + writer http.ResponseWriter + request *http.Request + cancel context.CancelFunc + bufferPool *bufferPool + compressors compressionMap queryVars url.Values originalHeaders http.Header @@ -325,7 +299,7 @@ func (o *operation) HandlerInfo() PeerInfo { func (o *operation) doNotImplement() {} -func (o *operation) validate(mux *Mux, codecs map[codecKey]Codec) error { +func (o *operation) validate(mux *Mux, codecs codecMap) error { // Identify the protocol. clientProtoHandler, queryVars := classifyRequest(o.request) if clientProtoHandler == nil { @@ -381,12 +355,12 @@ func (o *operation) validate(mux *Mux, codecs map[codecKey]Codec) error { } if reqMeta.compression != "" { var ok bool - o.client.reqCompression, ok = o.compressionPools[reqMeta.compression] + o.client.reqCompression, ok = o.compressors[reqMeta.compression] if !ok { return newHTTPError(http.StatusUnsupportedMediaType, "%q compression not supported", reqMeta.compression) } } - o.client.codec = codecs[codecKey{res: o.methodConf.resolver, name: reqMeta.codec}] + o.client.codec = codecs.get(reqMeta.codec, o.methodConf.resolver) if o.client.codec == nil { return newHTTPError(http.StatusUnsupportedMediaType, "%q sub-format not supported", reqMeta.codec) } @@ -418,11 +392,11 @@ func (o *operation) validate(mux *Mux, codecs map[codecKey]Codec) error { // NB: This is fine to set even if a custom content-type is used via // the use of google.api.HttpBody. The actual content-type and body // data will be written via serverBodyPreparer implementation. - o.server.codec = codecs[codecKey{res: o.methodConf.resolver, name: CodecJSON}] + o.server.codec = codecs.get(CodecJSON, o.methodConf.resolver) } else if _, supportsCodec := o.methodConf.codecNames[reqMeta.codec]; supportsCodec { o.server.codec = o.client.codec } else { - o.server.codec = codecs[codecKey{res: o.methodConf.resolver, name: o.methodConf.preferredCodec}] + o.server.codec = codecs.get(o.methodConf.preferredCodec, o.methodConf.resolver) } if reqMeta.compression != "" { @@ -542,7 +516,7 @@ func (o *operation) handle() { //nolint:gocyclo serverReqMeta := o.reqMeta serverReqMeta.codec = o.server.codec.Name() serverReqMeta.compression = o.server.reqCompression.Name() - serverReqMeta.acceptCompression = intersect(o.reqMeta.acceptCompression, o.canDecompress) + serverReqMeta.acceptCompression = o.compressors.intersection(o.reqMeta.acceptCompression) o.server.protocol.addProtocolRequestHeaders(serverReqMeta, o.request.Header) // Now we can define the transformed response writer (which delays @@ -1062,7 +1036,7 @@ func (w *responseWriter) WriteHeader(statusCode int) { respMeta.compression = "" // normalize to empty string } if respMeta.compression != "" { - respCompression, ok := w.op.compressionPools[respMeta.compression] + respCompression, ok := w.op.compressors[respMeta.compression] if !ok { w.reportError(fmt.Errorf("response indicates unsupported compression encoding %q", respMeta.compression)) return @@ -1222,7 +1196,7 @@ func (w *responseWriter) flushHeaders() { cliRespMeta := *w.respMeta cliRespMeta.codec = w.op.client.codec.Name() cliRespMeta.compression = w.op.client.respCompression.Name() - cliRespMeta.acceptCompression = intersect(w.respMeta.acceptCompression, w.op.canDecompress) + cliRespMeta.acceptCompression = w.op.compressors.intersection(w.respMeta.acceptCompression) statusCode := w.op.client.protocol.addProtocolResponseHeaders(cliRespMeta, w.Header()) hasErr := w.respMeta.end != nil && w.respMeta.end.err != nil // We only buffer full response for unary operations, so if we have an error, @@ -2193,25 +2167,3 @@ func (l *requestLine) fromRequest(req *http.Request) { l.queryString = req.URL.RawQuery l.httpVersion = req.Proto } - -func intersect(setA, setB []string) []string { - length := len(setA) - if len(setB) < length { - length = len(setB) - } - if length == 0 { - // If either set is empty, the intersection is empty. - // We don't use nil since it is used in places as a sentinel. - return make([]string, 0) - } - result := make([]string, 0, length) - for _, item := range setA { - for _, other := range setB { - if other == item { - result = append(result, item) - break - } - } - } - return result -} diff --git a/handler_bench_test.go b/handler_bench_test.go index bd63480..b38caed 100644 --- a/handler_bench_test.go +++ b/handler_bench_test.go @@ -148,8 +148,6 @@ func BenchmarkServeHTTP(b *testing.B) { req.Header.Set("Grpc-Timeout", "1S") req.Header.Set("Grpc-Accept-Encoding", "gzip") - hdlr := mux.AsHandler() - b.StartTimer() b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { @@ -158,7 +156,7 @@ func BenchmarkServeHTTP(b *testing.B) { req.Body = io.NopCloser(bytes.NewReader(reqGRPCBody)) rsp := httptest.NewRecorder() - hdlr.ServeHTTP(rsp, req) + mux.ServeHTTP(rsp, req) assert.Equal(b, http.StatusOK, rsp.Code, "response code") assert.Equal(b, "0", rsp.Header().Get("Grpc-Status"), "response status") assert.Equal(b, rspGRPCBody, rsp.Body.Bytes(), "response body") @@ -192,8 +190,6 @@ func BenchmarkServeHTTP(b *testing.B) { req.Header.Set("Content-Type", "application/json") req.Header.Set("X-Server-Timeout", "1000") - hdlr := mux.AsHandler() - b.StartTimer() b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { @@ -202,7 +198,7 @@ func BenchmarkServeHTTP(b *testing.B) { req.Body = io.NopCloser(bytes.NewReader(reqMsgBookJSON)) rsp := httptest.NewRecorder() - hdlr.ServeHTTP(rsp, req) + mux.ServeHTTP(rsp, req) assert.Equal(b, http.StatusOK, rsp.Code, "response code") assert.Equal(b, "application/json", rsp.Header().Get("Content-Type"), "response content type") data := rsp.Body.Bytes() @@ -238,8 +234,6 @@ func BenchmarkServeHTTP(b *testing.B) { req.Header.Set("Grpc-Timeout", "1S") req.Header.Set("Grpc-Accept-Encoding", "gzip") - hdlr := mux.AsHandler() - b.StartTimer() b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { @@ -248,7 +242,7 @@ func BenchmarkServeHTTP(b *testing.B) { req.Body = io.NopCloser(bytes.NewReader(reqGRPCBody)) rsp := httptest.NewRecorder() - hdlr.ServeHTTP(rsp, req) + mux.ServeHTTP(rsp, req) assert.Equal(b, http.StatusOK, rsp.Code, "response code") assert.Equal(b, "application/grpc+proto", rsp.Header().Get("Content-Type"), "response content type") assert.Equal(b, rspGRPCBody, rsp.Body.Bytes(), "response body") @@ -283,8 +277,6 @@ func BenchmarkServeHTTP(b *testing.B) { req.Header.Set("Connect-Timeout-Ms", "1000") req.ContentLength = int64(len(reqMsgJSON)) - hdlr := mux.AsHandler() - b.StartTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { @@ -292,7 +284,7 @@ func BenchmarkServeHTTP(b *testing.B) { req.Body = io.NopCloser(bytes.NewReader(reqMsgJSON)) rsp := httptest.NewRecorder() - hdlr.ServeHTTP(rsp, req) + mux.ServeHTTP(rsp, req) assert.Equal(b, http.StatusOK, rsp.Code, "response code") assert.Equal(b, "application/json", rsp.Header().Get("Content-Type"), "response content type") data := rsp.Body.Bytes() @@ -331,8 +323,6 @@ func BenchmarkServeHTTP(b *testing.B) { req.Header.Set("Connect-Timeout-Ms", "1000") req.ContentLength = int64(len(reqMsgProtoComp)) - hdlr := mux.AsHandler() - b.StartTimer() b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { @@ -341,7 +331,7 @@ func BenchmarkServeHTTP(b *testing.B) { req.Body = io.NopCloser(bytes.NewReader(reqMsgProtoComp)) rsp := httptest.NewRecorder() - hdlr.ServeHTTP(rsp, req) + mux.ServeHTTP(rsp, req) assert.Equal(b, http.StatusOK, rsp.Code, "response code") assert.Equal(b, "application/proto", rsp.Header().Get("Content-Type"), "response content type") assert.Equal(b, rspMsgProtoComp, rsp.Body.Bytes(), "response body") @@ -370,13 +360,11 @@ func BenchmarkServeHTTP(b *testing.B) { ); err != nil { b.Fatal(err) } - req := httptest.NewRequest(http.MethodPost, "/raw/file.bin", nil) + req := httptest.NewRequest(http.MethodPost, "/file.bin:upload", nil) req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("X-Server-Timeout", "1000") req.ContentLength = int64(len(largePayload)) - hdlr := mux.AsHandler() - b.StartTimer() b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { @@ -385,7 +373,7 @@ func BenchmarkServeHTTP(b *testing.B) { req.Body = io.NopCloser(bytes.NewReader(largePayload)) rsp := httptest.NewRecorder() - hdlr.ServeHTTP(rsp, req) + mux.ServeHTTP(rsp, req) assert.Equal(b, http.StatusOK, rsp.Code, "response code") assert.Equal(b, "application/json", rsp.Header().Get("Content-Type"), "response content type") assert.Equal(b, "{}", rsp.Body.String(), "response body") diff --git a/handler_test.go b/handler_test.go index 1f074b7..5a1ae97 100644 --- a/handler_test.go +++ b/handler_test.go @@ -396,7 +396,7 @@ func TestHandler_Errors(t *testing.T) { } } respWriter := httptest.NewRecorder() - targetMux.AsHandler().ServeHTTP(respWriter, req) + targetMux.ServeHTTP(respWriter, req) resp := respWriter.Result() err := resp.Body.Close() require.NoError(t, err) @@ -447,7 +447,7 @@ func TestHandler_PassThrough(t *testing.T) { require.NoError(t, mux.RegisterServiceByName(checkPassThrough(contentHandler), testv1connect.ContentServiceName)) // Use HTTP/2 so we can test a bidi stream. - server := httptest.NewUnstartedServer(mux.AsHandler()) + server := httptest.NewUnstartedServer(&mux) server.EnableHTTP2 = true server.StartTLS() t.Cleanup(server.Close) @@ -767,7 +767,7 @@ func TestMessage_AdvanceStage(t *testing.T) { respComp = abcCompression.newPool() } op := &operation{ - bufferPool: newBufferPool(), + bufferPool: &bufferPool{}, client: clientProtocolDetails{ codec: clientCodec, reqCompression: clientReqComp, @@ -959,65 +959,6 @@ func TestMessage_AdvanceStage(t *testing.T) { } } -func TestIntersection(t *testing.T) { - t.Parallel() - testCases := []struct { - name string - a, b, result []string - resultCap int - }{ - { - name: "b is superset", - a: []string{"a", "b", "c"}, - b: []string{"a", "b", "c", "d", "e", "f"}, - result: []string{"a", "b", "c"}, - resultCap: 3, - }, - { - name: "a is superset", - a: []string{"a", "b", "c", "d", "e", "f"}, - b: []string{"a", "b", "c"}, - result: []string{"a", "b", "c"}, - resultCap: 3, - }, - { - name: "a is empty", - a: nil, - b: []string{"a", "b", "c", "d", "e", "f"}, - result: []string{}, - }, - { - name: "b is empty", - a: []string{"a", "b", "c"}, - b: nil, - result: []string{}, - }, - { - name: "result is empty", - a: []string{"a", "b", "c"}, - b: []string{"d", "e", "f"}, - result: []string{}, // only nil when one of the inputs is empty - resultCap: 3, - }, - { - name: "result is subset of both", - a: []string{"x", "y", "z", "a", "b", "c"}, - b: []string{"a", "b", "c", "d", "e", "f"}, - result: []string{"a", "b", "c"}, - resultCap: 6, - }, - } - for _, testCase := range testCases { - testCase := testCase - t.Run(testCase.name, func(t *testing.T) { - t.Parallel() - result := intersect(testCase.a, testCase.b) - require.Equal(t, testCase.result, result) - require.Equal(t, testCase.resultCap, cap(result)) - }) - } -} - func checkStageEmpty(t *testing.T, msg *message, compressed bool) { t.Helper() require.Equal(t, stageEmpty, msg.stage) diff --git a/internal/examples/connect+grpc/cmd/server/main.go b/internal/examples/connect+grpc/cmd/server/main.go index 6601625..caf0e97 100644 --- a/internal/examples/connect+grpc/cmd/server/main.go +++ b/internal/examples/connect+grpc/cmd/server/main.go @@ -35,7 +35,7 @@ import ( func main() { server := grpc.NewServer() elizav1grpc.RegisterElizaServiceServer(server, elizaImpl{}) - mux := vanguard.Mux{ + mux := &vanguard.Mux{ Protocols: []vanguard.Protocol{vanguard.ProtocolGRPC}, Codecs: []string{vanguard.CodecProto}, } @@ -49,7 +49,7 @@ func main() { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) } - err = http.Serve(l, mux.AsHandler()) + err = http.Serve(l, mux) if err != http.ErrServerClosed { _, _ = fmt.Fprintln(os.Stderr, err) os.Exit(1) diff --git a/internal/examples/fileserver/main.go b/internal/examples/fileserver/main.go index 34655d2..4e1cc93 100644 --- a/internal/examples/fileserver/main.go +++ b/internal/examples/fileserver/main.go @@ -54,7 +54,7 @@ func main() { log.Fatal(err) } log.Printf("Serving %s on HTTP port: %s\n", *directory, *port) - log.Fatal(http.ListenAndServe(":"+*port, mux.AsHandler())) + log.Fatal(http.ListenAndServe(":"+*port, mux)) } var indexHTMLTemplate = template.Must(template.New("http").Parse(` diff --git a/internal/examples/pets/cmd/pets-be/main.go b/internal/examples/pets/cmd/pets-be/main.go index fca1b43..71da997 100644 --- a/internal/examples/pets/cmd/pets-be/main.go +++ b/internal/examples/pets/cmd/pets-be/main.go @@ -54,7 +54,7 @@ func main() { } serveMux := http.NewServeMux() - serveMux.Handle("/", internal.TraceHandler(mux.AsHandler())) + serveMux.Handle("/", internal.TraceHandler(mux)) serveMux.Handle(grpcreflect.NewHandlerV1(grpcreflect.NewStaticReflector(petstorev2connect.PetServiceName))) listener, err := net.Listen("tcp", "127.0.0.1:30304") diff --git a/internal/examples/pets/cmd/pets-fe/main.go b/internal/examples/pets/cmd/pets-fe/main.go index 652e316..547bee4 100644 --- a/internal/examples/pets/cmd/pets-fe/main.go +++ b/internal/examples/pets/cmd/pets-fe/main.go @@ -85,7 +85,7 @@ func main() { log.Fatal(err) } serveMux := http.NewServeMux() - serveMux.Handle("/", internal.TraceHandler(mux.AsHandler())) + serveMux.Handle("/", internal.TraceHandler(mux)) serveMux.Handle(grpcreflect.NewHandlerV1(grpcreflect.NewStaticReflector(petstorev2connect.PetServiceName))) svrs[i] = &http.Server{ Addr: ":http", diff --git a/vanguard.go b/vanguard.go index 021bab9..d31cd89 100644 --- a/vanguard.go +++ b/vanguard.go @@ -19,9 +19,7 @@ import ( "fmt" "math" "net/http" - "sort" "strings" - "sync" "time" "connectrpc.com/connect" @@ -53,9 +51,8 @@ const ( // REST). // // All services should be registered (via Register* methods) from a single -// thread during initialization. The handler returned from the AsHandler -// method is only thread-safe among concurrently executing HTTP requests. It -// is not safe to mutate the Mux once its handler is being used by a server. +// thread during initialization. It is not safe to mutate the Mux once its +// handler is being used by a server. type Mux struct { // The protocols that are supported by the wrapped handler, by default. // This can be overridden on a per-service level via options when calling @@ -154,30 +151,11 @@ type Mux struct { // a nil error. HooksCallback func(context.Context, Operation) (Hooks, error) - init sync.Once - codecImpls map[string]func(TypeResolver) Codec - compressionPools map[string]*compressionPool - methods map[string]*methodConfig - restRoutes routeTrie -} - -// AsHandler returns HTTP middleware that applies the given configuration -// to handlers. -// -// This should only be called after the configuration is finalized. -func (m *Mux) AsHandler() http.Handler { - m.maybeInit() - canDecompress := make([]string, 0, len(m.compressionPools)) - for compression := range m.compressionPools { - canDecompress = append(canDecompress, compression) - } - sort.Strings(canDecompress) - return &handler{ - mux: m, - bufferPool: newBufferPool(), - codecs: newCodecMap(m.methods, m.codecImpls), - canDecompress: canDecompress, - } + bufferPool bufferPool + codecs codecMap + compressors compressionMap + methods map[string]*methodConfig + restRoutes routeTrie } // RegisterServiceByName registers the given handler for the named service. @@ -231,7 +209,7 @@ func (m *Mux) RegisterService(handler http.Handler, serviceDesc protoreflect.Ser CodecJSON: {}, }, false) for codecName := range svcOpts.codecNames { - if _, known := m.codecImpls[codecName]; !known { + if _, known := m.codecs[codecName]; !known { return fmt.Errorf("codec %s is not known; use mux.AddCodec to add known codecs first", codecName) } } @@ -247,7 +225,7 @@ func (m *Mux) RegisterService(handler http.Handler, serviceDesc protoreflect.Ser CompressionGzip: {}, }, true) for compressorName := range svcOpts.compressorNames { - if _, known := m.compressionPools[compressorName]; !known { + if _, known := m.compressors[compressorName]; !known { return fmt.Errorf("compression algorithm %s is not known; use mux.AddCompression to add known algorithms first", compressorName) } } @@ -349,7 +327,7 @@ func (m *Mux) RegisterRules(rules ...*annotations.HttpRule) error { // different configuration. func (m *Mux) AddCodec(name string, newCodec func(TypeResolver) Codec) { m.maybeInit() - m.codecImpls[name] = newCodec + m.codecs[name] = newCodec } // AddCompression adds the given compression algorithm implementation. @@ -363,7 +341,7 @@ func (m *Mux) AddCodec(name string, newCodec func(TypeResolver) Codec) { // or compression level. func (m *Mux) AddCompression(name string, newCompressor func() connect.Compressor, newDecompressor func() connect.Decompressor) { m.maybeInit() - m.compressionPools[name] = newCompressionPool(name, newCompressor, newDecompressor) + m.compressors[name] = newCompressionPool(name, newCompressor, newDecompressor) } func (m *Mux) registerMethod(handler http.Handler, methodDesc protoreflect.MethodDescriptor, opts serviceOptions) error { @@ -384,6 +362,9 @@ func (m *Mux) registerMethod(handler http.Handler, methodDesc protoreflect.Metho maxGetURLBytes: opts.maxGetURLBytes, hooksCallback: opts.hooksCallback, } + if m.methods == nil { + m.methods = make(map[string]*methodConfig, 1) + } m.methods[methodPath] = methodConf switch { @@ -424,19 +405,19 @@ func (m *Mux) addRule(httpRule *annotations.HttpRule, methodConf *methodConfig) } func (m *Mux) maybeInit() { - m.init.Do(func() { - // initialize default codecs and compressors - m.codecImpls = map[string]func(res TypeResolver) Codec{ - CodecProto: DefaultProtoCodec, - CodecJSON: func(res TypeResolver) Codec { - return DefaultJSONCodec(res) - }, - } - m.compressionPools = map[string]*compressionPool{ - CompressionGzip: newCompressionPool(CompressionGzip, DefaultGzipCompressor, DefaultGzipDecompressor), - } - m.methods = map[string]*methodConfig{} - }) + if m.codecs != nil { + return // already initialized + } + // initialize default codecs and compressors + m.codecs = map[string]func(res TypeResolver) Codec{ + CodecProto: DefaultProtoCodec, + CodecJSON: func(res TypeResolver) Codec { + return DefaultJSONCodec(res) + }, + } + m.compressors = map[string]*compressionPool{ + CompressionGzip: newCompressionPool(CompressionGzip, DefaultGzipCompressor, DefaultGzipDecompressor), + } } // ServiceOption is an option for configuring how the middleware will handle diff --git a/vanguard_examples_test.go b/vanguard_examples_test.go index 828f270..5d30e2f 100644 --- a/vanguard_examples_test.go +++ b/vanguard_examples_test.go @@ -59,7 +59,7 @@ func ExampleMux_connectToGRPC() { // Create the server. // (NB: This is a httptest.Server, but it could be any http.Server) - server := httptest.NewUnstartedServer(mux.AsHandler()) + server := httptest.NewUnstartedServer(mux) server.EnableHTTP2 = true server.StartTLS() defer server.Close() @@ -103,7 +103,7 @@ func ExampleMux_restToGRPC() { // Create the server. // (NB: This is a httptest.Server, but it could be any http.Server) - server := httptest.NewServer(mux.AsHandler()) + server := httptest.NewServer(mux) defer server.Close() client := server.Client() @@ -163,7 +163,7 @@ func ExampleMux_connectToREST() { // Create the server. // (NB: This is a httptest.Server, but it could be any http.Server) - server := httptest.NewServer(mux.AsHandler()) + server := httptest.NewServer(mux) defer server.Close() // Create a connect client and call the service. diff --git a/vanguard_restxrpc_test.go b/vanguard_restxrpc_test.go index 4c20374..e2e63d4 100644 --- a/vanguard_restxrpc_test.go +++ b/vanguard_restxrpc_test.go @@ -516,7 +516,7 @@ func TestMux_RESTxRPC(t *testing.T) { t.Log("req:", string(debug)) rsp := httptest.NewRecorder() - opts.mux.mux.AsHandler().ServeHTTP(rsp, req) + opts.mux.mux.ServeHTTP(rsp, req) result := rsp.Result() defer result.Body.Close() diff --git a/vanguard_rpcxrest_test.go b/vanguard_rpcxrest_test.go index 5214210..1b59432 100644 --- a/vanguard_rpcxrest_test.go +++ b/vanguard_rpcxrest_test.go @@ -93,7 +93,7 @@ func TestMux_RPCxREST(t *testing.T) { t.Fatal(err) } } - server := httptest.NewUnstartedServer(mux.AsHandler()) + server := httptest.NewUnstartedServer(mux) server.EnableHTTP2 = true server.StartTLS() disableCompression(server) diff --git a/vanguard_rpcxrpc_test.go b/vanguard_rpcxrpc_test.go index ea8ce3a..d52607e 100644 --- a/vanguard_rpcxrpc_test.go +++ b/vanguard_rpcxrpc_test.go @@ -82,7 +82,7 @@ func TestMux_RPCxRPC(t *testing.T) { err := mux.RegisterServiceByName(hdlr, service, opts...) require.NoError(t, err) } - server := httptest.NewUnstartedServer(mux.AsHandler()) + server := httptest.NewUnstartedServer(mux) server.EnableHTTP2 = true server.StartTLS() disableCompression(server) diff --git a/vanguard_test.go b/vanguard_test.go index 0749e9b..1f7f317 100644 --- a/vanguard_test.go +++ b/vanguard_test.go @@ -525,7 +525,7 @@ func TestMux_BufferTooLargeFails(t *testing.T) { require.NoError(t, err) err = mux.RegisterServiceByName(hdlr, testv1connect.ContentServiceName, svcOpts...) require.NoError(t, err) - server := httptest.NewUnstartedServer(mux.AsHandler()) + server := httptest.NewUnstartedServer(mux) server.EnableHTTP2 = true server.StartTLS() disableCompression(server) @@ -572,7 +572,7 @@ func TestMux_ConnectGetUsesPostIfRequestTooLarge(t *testing.T) { muxWithSetting := &Mux{MaxGetURLBytes: 512, Compressors: []string{}} err := muxWithSetting.RegisterServiceByName(hdlr, testv1connect.LibraryServiceName) require.NoError(t, err) - serverWithSetting := httptest.NewServer(muxWithSetting.AsHandler()) + serverWithSetting := httptest.NewServer(muxWithSetting) disableCompression(serverWithSetting) t.Cleanup(serverWithSetting.Close) @@ -584,7 +584,7 @@ func TestMux_ConnectGetUsesPostIfRequestTooLarge(t *testing.T) { WithNoCompression(), ) require.NoError(t, err) - serverWithSvcOption := httptest.NewServer(muxWithSvcOption.AsHandler()) + serverWithSvcOption := httptest.NewServer(muxWithSvcOption) disableCompression(serverWithSvcOption) t.Cleanup(serverWithSvcOption.Close) @@ -767,12 +767,11 @@ func TestMux_MessageHooks(t *testing.T) { HooksCallback: makeHooks(serverCase.reqHook, serverCase.respHook), } require.NoError(t, mux.RegisterServiceByName(contentHandler, testv1connect.ContentServiceName)) - handler := mux.AsHandler() // propagate test name into context so that request and response hooks can access it setContextHandler := http.HandlerFunc(func(respWriter http.ResponseWriter, request *http.Request) { testName := request.Header.Get("Test") ctx := context.WithValue(request.Context(), testCaseNameContextKey{}, testName) - handler.ServeHTTP(respWriter, request.WithContext(ctx)) + mux.ServeHTTP(respWriter, request.WithContext(ctx)) }) // Use HTTP/2 so we can test a bidi stream. server := httptest.NewUnstartedServer(setContextHandler) @@ -1198,12 +1197,11 @@ func TestMux_HookOrder(t *testing.T) { } mux := &Mux{HooksCallback: callback} require.NoError(t, mux.RegisterServiceByName(contentHandler, testv1connect.ContentServiceName)) - handler := mux.AsHandler() // propagate test name into context so that hooks can access it setContextHandler := http.HandlerFunc(func(respWriter http.ResponseWriter, request *http.Request) { testName := request.Header.Get("Test") ctx := context.WithValue(request.Context(), testCaseNameContextKey{}, testName) - handler.ServeHTTP(respWriter, request.WithContext(ctx)) + mux.ServeHTTP(respWriter, request.WithContext(ctx)) }) // Use HTTP/2 so we can test a bidi stream. server := httptest.NewUnstartedServer(setContextHandler) @@ -1639,7 +1637,7 @@ func TestRuleSelector(t *testing.T) { }) defer interceptor.del(t) - mux.AsHandler().ServeHTTP(rsp, req) + mux.ServeHTTP(rsp, req) result := rsp.Result() defer result.Body.Close()