Skip to content

Commit

Permalink
Drop AsHandler method from Mux
Browse files Browse the repository at this point in the history
  • Loading branch information
emcfarlane committed Sep 27, 2023
1 parent 72857a0 commit c139bdc
Show file tree
Hide file tree
Showing 17 changed files with 143 additions and 246 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,12 @@ Finally, you can register the middleware handler with an `http.Server` or

```go
// The Mux can be used as the sole handler for an HTTP server.
err := http.Serve(listener, vanguardMux.AsHandler())
err := http.Serve(listener, vanguardMux)

// Or it can be used alongside other handlers, all registered with
// the same http.ServeMux.
mux := http.NewServeMux()
mux.Handle("/", vanguardMux.AsHandler())
mux.Handle("/", vanguardMux)
err := http.Serve(listener, mux)
```
The above example registers the handler for the root path. This is useful
Expand Down
4 changes: 0 additions & 4 deletions buffer_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@ type bufferPool struct {
sync.Pool
}

// newBufferPool returns a ready-to-use bufferPool; the zero value of
// bufferPool is already valid, so no field initialization is required.
func newBufferPool() *bufferPool {
	return new(bufferPool)
}

func (b *bufferPool) Get() *bytes.Buffer {
if buffer, ok := b.Pool.Get().(*bytes.Buffer); ok {
buffer.Reset()
Expand Down
60 changes: 40 additions & 20 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ type RESTCodec interface {
// The returned codec implements StableCodec, in addition to
// Codec.
func DefaultProtoCodec(res TypeResolver) Codec {
return &protoCodec{Resolver: res}
return protoCodec{
UnmarshalOptions: proto.UnmarshalOptions{Resolver: res},
}
}

// DefaultJSONCodec is the default codec factory used for the codec named
Expand All @@ -110,30 +112,30 @@ type JSONCodec struct {
UnmarshalOptions protojson.UnmarshalOptions
}

var _ StableCodec = (*JSONCodec)(nil)
var _ RESTCodec = (*JSONCodec)(nil)
var _ StableCodec = JSONCodec{}
var _ RESTCodec = JSONCodec{}

func (j *JSONCodec) Name() string {
func (j JSONCodec) Name() string {
return CodecJSON
}

func (j *JSONCodec) IsBinary() bool {
func (j JSONCodec) IsBinary() bool {
return false
}

func (j *JSONCodec) MarshalAppend(base []byte, msg proto.Message) ([]byte, error) {
func (j JSONCodec) MarshalAppend(base []byte, msg proto.Message) ([]byte, error) {
return j.MarshalOptions.MarshalAppend(base, msg)
}

func (j *JSONCodec) MarshalAppendStable(base []byte, msg proto.Message) ([]byte, error) {
func (j JSONCodec) MarshalAppendStable(base []byte, msg proto.Message) ([]byte, error) {
data, err := j.MarshalOptions.MarshalAppend(base, msg)
if err != nil {
return nil, err
}
return jsonStabilize(data)
}

func (j *JSONCodec) MarshalAppendField(base []byte, msg proto.Message, field protoreflect.FieldDescriptor) ([]byte, error) {
func (j JSONCodec) MarshalAppendField(base []byte, msg proto.Message, field protoreflect.FieldDescriptor) ([]byte, error) {
if field.Message() != nil && field.Cardinality() != protoreflect.Repeated {
return j.MarshalAppend(base, msg.ProtoReflect().Get(field).Message().Interface())
}
Expand Down Expand Up @@ -191,7 +193,7 @@ func (j *JSONCodec) MarshalAppendField(base []byte, msg proto.Message, field pro
return nil, fmt.Errorf("JSON does not contain key %s", fieldName)
}

func (j *JSONCodec) UnmarshalField(data []byte, msg proto.Message, field protoreflect.FieldDescriptor) error {
func (j JSONCodec) UnmarshalField(data []byte, msg proto.Message, field protoreflect.FieldDescriptor) error {
if field.Message() != nil && field.Cardinality() != protoreflect.Repeated {
return j.Unmarshal(data, msg.ProtoReflect().Mutable(field).Message().Interface())
}
Expand All @@ -211,11 +213,11 @@ func (j *JSONCodec) UnmarshalField(data []byte, msg proto.Message, field protore
return j.Unmarshal(buf.Bytes(), msg)
}

func (j *JSONCodec) Unmarshal(bytes []byte, msg proto.Message) error {
func (j JSONCodec) Unmarshal(bytes []byte, msg proto.Message) error {
return j.UnmarshalOptions.Unmarshal(bytes, msg)
}

func (j *JSONCodec) fieldName(field protoreflect.FieldDescriptor) string {
func (j JSONCodec) fieldName(field protoreflect.FieldDescriptor) string {
if !j.MarshalOptions.UseProtoNames {
return field.JSONName()
}
Expand All @@ -226,28 +228,33 @@ func (j *JSONCodec) fieldName(field protoreflect.FieldDescriptor) string {
return string(field.Name())
}

type protoCodec proto.UnmarshalOptions
type protoCodec struct {
proto.MarshalOptions
proto.UnmarshalOptions
}

var _ StableCodec = (*protoCodec)(nil)
var _ StableCodec = protoCodec{}

func (p *protoCodec) Name() string {
func (p protoCodec) Name() string {
return CodecProto
}

func (p *protoCodec) IsBinary() bool {
func (p protoCodec) IsBinary() bool {
return true
}

func (p *protoCodec) MarshalAppend(base []byte, msg proto.Message) ([]byte, error) {
func (p protoCodec) MarshalAppend(base []byte, msg proto.Message) ([]byte, error) {
return proto.MarshalOptions{}.MarshalAppend(base, msg)
}

func (p *protoCodec) MarshalAppendStable(base []byte, msg proto.Message) ([]byte, error) {
return proto.MarshalOptions{Deterministic: true}.MarshalAppend(base, msg)
func (p protoCodec) MarshalAppendStable(base []byte, msg proto.Message) ([]byte, error) {
opts := p.MarshalOptions
opts.Deterministic = true
return opts.MarshalAppend(base, msg)
}

func (p *protoCodec) Unmarshal(bytes []byte, msg proto.Message) error {
return (*proto.UnmarshalOptions)(p).Unmarshal(bytes, msg)
func (p protoCodec) Unmarshal(bytes []byte, msg proto.Message) error {
return p.UnmarshalOptions.Unmarshal(bytes, msg)
}

func jsonStabilize(data []byte) ([]byte, error) {
Expand All @@ -260,3 +267,16 @@ func jsonStabilize(data []byte) ([]byte, error) {
}
return buf.Bytes(), nil
}

// codecMap maps a codec name to a factory function that constructs a
// Codec for a given TypeResolver.
type codecMap map[string]func(TypeResolver) Codec

// get returns the codec registered under name, constructed with the given
// resolver, or nil if no such codec is registered. A nil codecMap is safe:
// indexing a nil map yields the zero value with ok == false, so the
// not-found path below already covers it (no separate nil check needed).
func (m codecMap) get(name string, resolver TypeResolver) Codec {
	codecFn, ok := m[name]
	if !ok {
		return nil
	}
	return codecFn(resolver)
}
21 changes: 21 additions & 0 deletions compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,27 @@ func DefaultGzipDecompressor() connect.Decompressor {
return &gzip.Reader{}
}

// compressionMap maps a compression-encoding name to its pool.
type compressionMap map[string]*compressionPool

// intersection returns the entries of names that are also keys of m,
// preserving the order in which they appear in names. The result is
// always non-nil, because nil is used elsewhere as a sentinel.
func (m compressionMap) intersection(names []string) []string {
	capacity := len(names)
	if len(m) < capacity {
		capacity = len(m)
	}
	if capacity == 0 {
		// Either set being empty makes the intersection empty; return a
		// non-nil empty slice since nil serves as a sentinel elsewhere.
		return make([]string, 0)
	}
	result := make([]string, 0, capacity)
	for _, candidate := range names {
		if _, present := m[candidate]; present {
			result = append(result, candidate)
		}
	}
	return result
}

type compressionPool struct {
name string
decompressors sync.Pool
Expand Down
100 changes: 26 additions & 74 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,17 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"
)

type handler struct {
mux *Mux
bufferPool *bufferPool
codecs map[codecKey]Codec
canDecompress []string
}

func (h *handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
op := h.newOperation(writer, request)
err := op.validate(h.mux, h.codecs)
// ServeHTTP implements http.Handler.
func (m *Mux) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
op := m.newOperation(writer, request)
err := op.validate(m, m.codecs)

useUnknownHandler := h.mux.UnknownHandler != nil && errors.Is(err, errNotFound)
useUnknownHandler := m.UnknownHandler != nil && errors.Is(err, errNotFound)
var callback func(context.Context, Operation) (Hooks, error)
if op.methodConf != nil {
callback = op.methodConf.hooksCallback
} else {
callback = h.mux.HooksCallback
callback = m.HooksCallback
}
if callback != nil {
var hookErr error
Expand All @@ -58,7 +52,7 @@ func (h *handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
}
if useUnknownHandler {
request.Header = op.originalHeaders // restore headers, just in case initialization removed keys
h.mux.UnknownHandler.ServeHTTP(writer, request)
m.UnknownHandler.ServeHTTP(writer, request)
return
}

Expand All @@ -81,16 +75,15 @@ func (h *handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
op.handle()
}

func (h *handler) newOperation(writer http.ResponseWriter, request *http.Request) *operation {
func (m *Mux) newOperation(writer http.ResponseWriter, request *http.Request) *operation {
ctx, cancel := context.WithCancel(request.Context())
request = request.WithContext(ctx)
op := &operation{
writer: writer,
request: request,
cancel: cancel,
bufferPool: h.bufferPool,
canDecompress: h.canDecompress,
compressionPools: h.mux.compressionPools,
writer: writer,
request: request,
cancel: cancel,
bufferPool: &m.bufferPool,
compressors: m.compressors,
}
op.requestLine.fromRequest(request)
return op
Expand Down Expand Up @@ -235,33 +228,14 @@ func classifyRequest(req *http.Request) (clientProtocolHandler, url.Values) {
}
}

type codecKey struct {
res TypeResolver
name string
}

func newCodecMap(methodConfigs map[string]*methodConfig, codecs map[string]func(TypeResolver) Codec) map[codecKey]Codec {
result := make(map[codecKey]Codec, len(codecs))
for _, conf := range methodConfigs {
for codecName, codecFactory := range codecs {
key := codecKey{res: conf.resolver, name: codecName}
if _, exists := result[key]; !exists {
result[key] = codecFactory(conf.resolver)
}
}
}
return result
}

// operation represents a single HTTP operation, which maps to an incoming HTTP request.
// It tracks properties needed to implement protocol transformation.
type operation struct {
writer http.ResponseWriter
request *http.Request
cancel context.CancelFunc
bufferPool *bufferPool
canDecompress []string
compressionPools map[string]*compressionPool
writer http.ResponseWriter
request *http.Request
cancel context.CancelFunc
bufferPool *bufferPool
compressors compressionMap

queryVars url.Values
originalHeaders http.Header
Expand Down Expand Up @@ -325,7 +299,7 @@ func (o *operation) HandlerInfo() PeerInfo {

func (o *operation) doNotImplement() {}

func (o *operation) validate(mux *Mux, codecs map[codecKey]Codec) error {
func (o *operation) validate(mux *Mux, codecs codecMap) error {
// Identify the protocol.
clientProtoHandler, queryVars := classifyRequest(o.request)
if clientProtoHandler == nil {
Expand Down Expand Up @@ -381,12 +355,12 @@ func (o *operation) validate(mux *Mux, codecs map[codecKey]Codec) error {
}
if reqMeta.compression != "" {
var ok bool
o.client.reqCompression, ok = o.compressionPools[reqMeta.compression]
o.client.reqCompression, ok = o.compressors[reqMeta.compression]
if !ok {
return newHTTPError(http.StatusUnsupportedMediaType, "%q compression not supported", reqMeta.compression)
}
}
o.client.codec = codecs[codecKey{res: o.methodConf.resolver, name: reqMeta.codec}]
o.client.codec = codecs.get(reqMeta.codec, o.methodConf.resolver)
if o.client.codec == nil {
return newHTTPError(http.StatusUnsupportedMediaType, "%q sub-format not supported", reqMeta.codec)
}
Expand Down Expand Up @@ -418,11 +392,11 @@ func (o *operation) validate(mux *Mux, codecs map[codecKey]Codec) error {
// NB: This is fine to set even if a custom content-type is used via
// the use of google.api.HttpBody. The actual content-type and body
// data will be written via serverBodyPreparer implementation.
o.server.codec = codecs[codecKey{res: o.methodConf.resolver, name: CodecJSON}]
o.server.codec = codecs.get(CodecJSON, o.methodConf.resolver)
} else if _, supportsCodec := o.methodConf.codecNames[reqMeta.codec]; supportsCodec {
o.server.codec = o.client.codec
} else {
o.server.codec = codecs[codecKey{res: o.methodConf.resolver, name: o.methodConf.preferredCodec}]
o.server.codec = codecs.get(o.methodConf.preferredCodec, o.methodConf.resolver)
}

if reqMeta.compression != "" {
Expand Down Expand Up @@ -542,7 +516,7 @@ func (o *operation) handle() { //nolint:gocyclo
serverReqMeta := o.reqMeta
serverReqMeta.codec = o.server.codec.Name()
serverReqMeta.compression = o.server.reqCompression.Name()
serverReqMeta.acceptCompression = intersect(o.reqMeta.acceptCompression, o.canDecompress)
serverReqMeta.acceptCompression = o.compressors.intersection(o.reqMeta.acceptCompression)
o.server.protocol.addProtocolRequestHeaders(serverReqMeta, o.request.Header)

// Now we can define the transformed response writer (which delays
Expand Down Expand Up @@ -1062,7 +1036,7 @@ func (w *responseWriter) WriteHeader(statusCode int) {
respMeta.compression = "" // normalize to empty string
}
if respMeta.compression != "" {
respCompression, ok := w.op.compressionPools[respMeta.compression]
respCompression, ok := w.op.compressors[respMeta.compression]
if !ok {
w.reportError(fmt.Errorf("response indicates unsupported compression encoding %q", respMeta.compression))
return
Expand Down Expand Up @@ -1222,7 +1196,7 @@ func (w *responseWriter) flushHeaders() {
cliRespMeta := *w.respMeta
cliRespMeta.codec = w.op.client.codec.Name()
cliRespMeta.compression = w.op.client.respCompression.Name()
cliRespMeta.acceptCompression = intersect(w.respMeta.acceptCompression, w.op.canDecompress)
cliRespMeta.acceptCompression = w.op.compressors.intersection(w.respMeta.acceptCompression)
statusCode := w.op.client.protocol.addProtocolResponseHeaders(cliRespMeta, w.Header())
hasErr := w.respMeta.end != nil && w.respMeta.end.err != nil
// We only buffer full response for unary operations, so if we have an error,
Expand Down Expand Up @@ -2193,25 +2167,3 @@ func (l *requestLine) fromRequest(req *http.Request) {
l.queryString = req.URL.RawQuery
l.httpVersion = req.Proto
}

// intersect returns the elements of setA that also appear in setB,
// preserving setA's order (duplicates in setA that match are kept).
// The result is never nil, because nil is used elsewhere as a sentinel.
func intersect(setA, setB []string) []string {
	length := len(setA)
	if len(setB) < length {
		length = len(setB)
	}
	if length == 0 {
		// If either set is empty, the intersection is empty.
		// We don't use nil since it is used in places as a sentinel.
		return make([]string, 0)
	}
	// Build a membership set from setB so scanning setA is O(n+m)
	// rather than the quadratic nested-loop comparison.
	inB := make(map[string]struct{}, len(setB))
	for _, item := range setB {
		inB[item] = struct{}{}
	}
	result := make([]string, 0, length)
	for _, item := range setA {
		if _, ok := inB[item]; ok {
			result = append(result, item)
		}
	}
	return result
}
Loading

0 comments on commit c139bdc

Please sign in to comment.