Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop AsHandler method from Mux #77

Merged
merged 1 commit into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,12 @@ Finally, you can register the middleware handler with an `http.Server` or

```go
// The Mux can be used as the sole handler for an HTTP server.
err := http.Serve(listener, vanguardMux.AsHandler())
err := http.Serve(listener, vanguardMux)

// Or it can be used alongside other handlers, all registered with
// the same http.ServeMux.
mux := http.NewServeMux()
mux.Handle("/", vanguardMux.AsHandler())
mux.Handle("/", vanguardMux)
err := http.Serve(listener, mux)
```
The above example registers the handler for the root path. This is useful
Expand Down
4 changes: 0 additions & 4 deletions buffer_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,6 @@ type bufferPool struct {
sync.Pool
}

func newBufferPool() *bufferPool {
return &bufferPool{}
}

func (b *bufferPool) Get() *bytes.Buffer {
if buffer, ok := b.Pool.Get().(*bytes.Buffer); ok {
buffer.Reset()
Expand Down
60 changes: 40 additions & 20 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ type RESTCodec interface {
// The returned codec implements StableCodec, in addition to
// Codec.
func DefaultProtoCodec(res TypeResolver) Codec {
return &protoCodec{Resolver: res}
return protoCodec{
UnmarshalOptions: proto.UnmarshalOptions{Resolver: res},
}
}

// DefaultJSONCodec is the default codec factory used for the codec named
Expand All @@ -110,30 +112,30 @@ type JSONCodec struct {
UnmarshalOptions protojson.UnmarshalOptions
}

var _ StableCodec = (*JSONCodec)(nil)
var _ RESTCodec = (*JSONCodec)(nil)
var _ StableCodec = JSONCodec{}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks unrelated and not sure we need to do it, but not blocking for now since we're reviewing the Codec interfaces separately.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It removes having to have the cache to convert codecs from a map[string] func(TypeResolver) Codec to map[codecKey]Codec where the codecKey is a type of struct{string, TypeResolver} which means we can drop the constructor for the handler.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a little less efficient, could move the cache to the methodConfig but if we are changing the codec layout easier to call the factory func.

var _ RESTCodec = JSONCodec{}

func (j *JSONCodec) Name() string {
func (j JSONCodec) Name() string {
return CodecJSON
}

func (j *JSONCodec) IsBinary() bool {
func (j JSONCodec) IsBinary() bool {
return false
}

func (j *JSONCodec) MarshalAppend(base []byte, msg proto.Message) ([]byte, error) {
func (j JSONCodec) MarshalAppend(base []byte, msg proto.Message) ([]byte, error) {
return j.MarshalOptions.MarshalAppend(base, msg)
}

func (j *JSONCodec) MarshalAppendStable(base []byte, msg proto.Message) ([]byte, error) {
func (j JSONCodec) MarshalAppendStable(base []byte, msg proto.Message) ([]byte, error) {
data, err := j.MarshalOptions.MarshalAppend(base, msg)
if err != nil {
return nil, err
}
return jsonStabilize(data)
}

func (j *JSONCodec) MarshalAppendField(base []byte, msg proto.Message, field protoreflect.FieldDescriptor) ([]byte, error) {
func (j JSONCodec) MarshalAppendField(base []byte, msg proto.Message, field protoreflect.FieldDescriptor) ([]byte, error) {
if field.Message() != nil && field.Cardinality() != protoreflect.Repeated {
return j.MarshalAppend(base, msg.ProtoReflect().Get(field).Message().Interface())
}
Expand Down Expand Up @@ -191,7 +193,7 @@ func (j *JSONCodec) MarshalAppendField(base []byte, msg proto.Message, field pro
return nil, fmt.Errorf("JSON does not contain key %s", fieldName)
}

func (j *JSONCodec) UnmarshalField(data []byte, msg proto.Message, field protoreflect.FieldDescriptor) error {
func (j JSONCodec) UnmarshalField(data []byte, msg proto.Message, field protoreflect.FieldDescriptor) error {
if field.Message() != nil && field.Cardinality() != protoreflect.Repeated {
return j.Unmarshal(data, msg.ProtoReflect().Mutable(field).Message().Interface())
}
Expand All @@ -211,11 +213,11 @@ func (j *JSONCodec) UnmarshalField(data []byte, msg proto.Message, field protore
return j.Unmarshal(buf.Bytes(), msg)
}

func (j *JSONCodec) Unmarshal(bytes []byte, msg proto.Message) error {
func (j JSONCodec) Unmarshal(bytes []byte, msg proto.Message) error {
return j.UnmarshalOptions.Unmarshal(bytes, msg)
}

func (j *JSONCodec) fieldName(field protoreflect.FieldDescriptor) string {
func (j JSONCodec) fieldName(field protoreflect.FieldDescriptor) string {
if !j.MarshalOptions.UseProtoNames {
return field.JSONName()
}
Expand All @@ -226,28 +228,33 @@ func (j *JSONCodec) fieldName(field protoreflect.FieldDescriptor) string {
return string(field.Name())
}

type protoCodec proto.UnmarshalOptions
type protoCodec struct {
proto.MarshalOptions
proto.UnmarshalOptions
}

var _ StableCodec = (*protoCodec)(nil)
var _ StableCodec = protoCodec{}

func (p *protoCodec) Name() string {
func (p protoCodec) Name() string {
return CodecProto
}

func (p *protoCodec) IsBinary() bool {
func (p protoCodec) IsBinary() bool {
return true
}

func (p *protoCodec) MarshalAppend(base []byte, msg proto.Message) ([]byte, error) {
func (p protoCodec) MarshalAppend(base []byte, msg proto.Message) ([]byte, error) {
return proto.MarshalOptions{}.MarshalAppend(base, msg)
}

func (p *protoCodec) MarshalAppendStable(base []byte, msg proto.Message) ([]byte, error) {
return proto.MarshalOptions{Deterministic: true}.MarshalAppend(base, msg)
func (p protoCodec) MarshalAppendStable(base []byte, msg proto.Message) ([]byte, error) {
opts := p.MarshalOptions
opts.Deterministic = true
return opts.MarshalAppend(base, msg)
}

func (p *protoCodec) Unmarshal(bytes []byte, msg proto.Message) error {
return (*proto.UnmarshalOptions)(p).Unmarshal(bytes, msg)
func (p protoCodec) Unmarshal(bytes []byte, msg proto.Message) error {
return p.UnmarshalOptions.Unmarshal(bytes, msg)
}

func jsonStabilize(data []byte) ([]byte, error) {
Expand All @@ -260,3 +267,16 @@ func jsonStabilize(data []byte) ([]byte, error) {
}
return buf.Bytes(), nil
}

type codecMap map[string]func(TypeResolver) Codec

func (m codecMap) get(name string, resolver TypeResolver) Codec {
if m == nil {
return nil
}
codecFn, ok := m[name]
if !ok {
return nil
}
return codecFn(resolver)
}
21 changes: 21 additions & 0 deletions compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,27 @@ func DefaultGzipDecompressor() connect.Decompressor {
return &gzip.Reader{}
}

type compressionMap map[string]*compressionPool

func (m compressionMap) intersection(names []string) []string {
length := len(names)
if len(m) < length {
length = len(m)
}
if length == 0 {
// If either set is empty, the intersection is empty.
// We don't use nil since it is used in places as a sentinel.
return make([]string, 0)
}
intersection := make([]string, 0, length)
for _, name := range names {
if _, ok := m[name]; ok {
intersection = append(intersection, name)
}
}
return intersection
}

type compressionPool struct {
name string
decompressors sync.Pool
Expand Down
100 changes: 26 additions & 74 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,17 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"
)

type handler struct {
mux *Mux
bufferPool *bufferPool
codecs map[codecKey]Codec
canDecompress []string
}

func (h *handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
op := h.newOperation(writer, request)
err := op.validate(h.mux, h.codecs)
// ServeHTTP implements http.Handler.
func (m *Mux) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
op := m.newOperation(writer, request)
err := op.validate(m, m.codecs)

useUnknownHandler := h.mux.UnknownHandler != nil && errors.Is(err, errNotFound)
useUnknownHandler := m.UnknownHandler != nil && errors.Is(err, errNotFound)
var callback func(context.Context, Operation) (Hooks, error)
if op.methodConf != nil {
callback = op.methodConf.hooksCallback
} else {
callback = h.mux.HooksCallback
callback = m.HooksCallback
}
if callback != nil {
var hookErr error
Expand All @@ -58,7 +52,7 @@ func (h *handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
}
if useUnknownHandler {
request.Header = op.originalHeaders // restore headers, just in case initialization removed keys
h.mux.UnknownHandler.ServeHTTP(writer, request)
m.UnknownHandler.ServeHTTP(writer, request)
return
}

Expand All @@ -81,16 +75,15 @@ func (h *handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
op.handle()
}

func (h *handler) newOperation(writer http.ResponseWriter, request *http.Request) *operation {
func (m *Mux) newOperation(writer http.ResponseWriter, request *http.Request) *operation {
ctx, cancel := context.WithCancel(request.Context())
request = request.WithContext(ctx)
op := &operation{
writer: writer,
request: request,
cancel: cancel,
bufferPool: h.bufferPool,
canDecompress: h.canDecompress,
compressionPools: h.mux.compressionPools,
writer: writer,
request: request,
cancel: cancel,
bufferPool: &m.bufferPool,
compressors: m.compressors,
}
op.requestLine.fromRequest(request)
return op
Expand Down Expand Up @@ -235,33 +228,14 @@ func classifyRequest(req *http.Request) (clientProtocolHandler, url.Values) {
}
}

type codecKey struct {
res TypeResolver
name string
}

func newCodecMap(methodConfigs map[string]*methodConfig, codecs map[string]func(TypeResolver) Codec) map[codecKey]Codec {
result := make(map[codecKey]Codec, len(codecs))
for _, conf := range methodConfigs {
for codecName, codecFactory := range codecs {
key := codecKey{res: conf.resolver, name: codecName}
if _, exists := result[key]; !exists {
result[key] = codecFactory(conf.resolver)
}
}
}
return result
}

// operation represents a single HTTP operation, which maps to an incoming HTTP request.
// It tracks properties needed to implement protocol transformation.
type operation struct {
writer http.ResponseWriter
request *http.Request
cancel context.CancelFunc
bufferPool *bufferPool
canDecompress []string
compressionPools map[string]*compressionPool
writer http.ResponseWriter
request *http.Request
cancel context.CancelFunc
bufferPool *bufferPool
compressors compressionMap

queryVars url.Values
originalHeaders http.Header
Expand Down Expand Up @@ -325,7 +299,7 @@ func (o *operation) HandlerInfo() PeerInfo {

func (o *operation) doNotImplement() {}

func (o *operation) validate(mux *Mux, codecs map[codecKey]Codec) error {
func (o *operation) validate(mux *Mux, codecs codecMap) error {
// Identify the protocol.
clientProtoHandler, queryVars := classifyRequest(o.request)
if clientProtoHandler == nil {
Expand Down Expand Up @@ -381,12 +355,12 @@ func (o *operation) validate(mux *Mux, codecs map[codecKey]Codec) error {
}
if reqMeta.compression != "" {
var ok bool
o.client.reqCompression, ok = o.compressionPools[reqMeta.compression]
o.client.reqCompression, ok = o.compressors[reqMeta.compression]
if !ok {
return newHTTPError(http.StatusUnsupportedMediaType, "%q compression not supported", reqMeta.compression)
}
}
o.client.codec = codecs[codecKey{res: o.methodConf.resolver, name: reqMeta.codec}]
o.client.codec = codecs.get(reqMeta.codec, o.methodConf.resolver)
if o.client.codec == nil {
return newHTTPError(http.StatusUnsupportedMediaType, "%q sub-format not supported", reqMeta.codec)
}
Expand Down Expand Up @@ -418,11 +392,11 @@ func (o *operation) validate(mux *Mux, codecs map[codecKey]Codec) error {
// NB: This is fine to set even if a custom content-type is used via
// the use of google.api.HttpBody. The actual content-type and body
// data will be written via serverBodyPreparer implementation.
o.server.codec = codecs[codecKey{res: o.methodConf.resolver, name: CodecJSON}]
o.server.codec = codecs.get(CodecJSON, o.methodConf.resolver)
} else if _, supportsCodec := o.methodConf.codecNames[reqMeta.codec]; supportsCodec {
o.server.codec = o.client.codec
} else {
o.server.codec = codecs[codecKey{res: o.methodConf.resolver, name: o.methodConf.preferredCodec}]
o.server.codec = codecs.get(o.methodConf.preferredCodec, o.methodConf.resolver)
}

if reqMeta.compression != "" {
Expand Down Expand Up @@ -542,7 +516,7 @@ func (o *operation) handle() { //nolint:gocyclo
serverReqMeta := o.reqMeta
serverReqMeta.codec = o.server.codec.Name()
serverReqMeta.compression = o.server.reqCompression.Name()
serverReqMeta.acceptCompression = intersect(o.reqMeta.acceptCompression, o.canDecompress)
serverReqMeta.acceptCompression = o.compressors.intersection(o.reqMeta.acceptCompression)
o.server.protocol.addProtocolRequestHeaders(serverReqMeta, o.request.Header)

// Now we can define the transformed response writer (which delays
Expand Down Expand Up @@ -1062,7 +1036,7 @@ func (w *responseWriter) WriteHeader(statusCode int) {
respMeta.compression = "" // normalize to empty string
}
if respMeta.compression != "" {
respCompression, ok := w.op.compressionPools[respMeta.compression]
respCompression, ok := w.op.compressors[respMeta.compression]
if !ok {
w.reportError(fmt.Errorf("response indicates unsupported compression encoding %q", respMeta.compression))
return
Expand Down Expand Up @@ -1222,7 +1196,7 @@ func (w *responseWriter) flushHeaders() {
cliRespMeta := *w.respMeta
cliRespMeta.codec = w.op.client.codec.Name()
cliRespMeta.compression = w.op.client.respCompression.Name()
cliRespMeta.acceptCompression = intersect(w.respMeta.acceptCompression, w.op.canDecompress)
cliRespMeta.acceptCompression = w.op.compressors.intersection(w.respMeta.acceptCompression)
statusCode := w.op.client.protocol.addProtocolResponseHeaders(cliRespMeta, w.Header())
hasErr := w.respMeta.end != nil && w.respMeta.end.err != nil
// We only buffer full response for unary operations, so if we have an error,
Expand Down Expand Up @@ -2193,25 +2167,3 @@ func (l *requestLine) fromRequest(req *http.Request) {
l.queryString = req.URL.RawQuery
l.httpVersion = req.Proto
}

func intersect(setA, setB []string) []string {
length := len(setA)
if len(setB) < length {
length = len(setB)
}
if length == 0 {
// If either set is empty, the intersection is empty.
// We don't use nil since it is used in places as a sentinel.
return make([]string, 0)
}
result := make([]string, 0, length)
for _, item := range setA {
for _, other := range setB {
if other == item {
result = append(result, item)
break
}
}
}
return result
}
Loading
Loading