diff --git a/api/internal/consumer/consumer.go b/api/internal/consumer/consumer.go index 3173553d..5736b26a 100644 --- a/api/internal/consumer/consumer.go +++ b/api/internal/consumer/consumer.go @@ -45,10 +45,12 @@ type Consumer struct { FilterConfigs map[string]*fmModel.ParsedFilterConfig // fields that generated from the configuration - CanSkipMethod map[string]bool FilterNames []string InitOnce sync.Once + CanSkipMethod map[string]bool CanSkipMethodOnce sync.Once + CanSyncRunMethod map[string]bool + // CanSyncRunMethod share the same sync.Once with CanSkipMethodOnce } func (c *Consumer) Unmarshal(s string) error { @@ -92,9 +94,10 @@ func (c *Consumer) InitConfigs() error { } c.FilterConfigs[name] = &fmModel.ParsedFilterConfig{ - Name: name, - ParsedConfig: conf, - Factory: p.Factory, + Name: name, + ParsedConfig: conf, + Factory: p.Factory, + SyncRunPhases: p.ConfigParser.NonBlockingPhases(), } } diff --git a/api/pkg/filtermanager/api/phase.go b/api/pkg/filtermanager/api/phase.go new file mode 100644 index 00000000..2a9accd7 --- /dev/null +++ b/api/pkg/filtermanager/api/phase.go @@ -0,0 +1,63 @@ +// Copyright The HTNN Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +type Phase int + +const ( + PhaseDecodeHeaders Phase = 0x01 + PhaseDecodeData Phase = 0x02 + PhaseDecodeTrailers Phase = 0x04 + PhaseDecodeRequest Phase = 0x08 + PhaseEncodeHeaders Phase = 0x10 + PhaseEncodeData Phase = 0x20 + PhaseEncodeTrailers Phase = 0x40 + PhaseEncodeResponse Phase = 0x80 + PhaseOnLog Phase = 0x100 +) + +var ( + AllPhases = PhaseDecodeHeaders | PhaseDecodeData | PhaseDecodeTrailers | PhaseDecodeRequest | + PhaseEncodeHeaders | PhaseEncodeData | PhaseEncodeTrailers | PhaseEncodeResponse | PhaseOnLog +) + +func (p Phase) Contains(phases Phase) bool { + return p&phases == phases +} + +func MethodToPhase(meth string) Phase { + switch meth { + case "DecodeHeaders": + return PhaseDecodeHeaders + case "DecodeData": + return PhaseDecodeData + case "DecodeTrailers": + return PhaseDecodeTrailers + case "DecodeRequest": + return PhaseDecodeRequest + case "EncodeHeaders": + return PhaseEncodeHeaders + case "EncodeData": + return PhaseEncodeData + case "EncodeTrailers": + return PhaseEncodeTrailers + case "EncodeResponse": + return PhaseEncodeResponse + case "OnLog": + return PhaseOnLog + default: + return 0 + } +} diff --git a/api/pkg/filtermanager/config.go b/api/pkg/filtermanager/config.go index 2ad5dbe2..6cd411fe 100644 --- a/api/pkg/filtermanager/config.go +++ b/api/pkg/filtermanager/config.go @@ -232,9 +232,10 @@ func (p *FilterManagerConfigParser) Parse(any *anypb.Any, callbacks capi.ConfigC }) } else { conf.parsed = append(conf.parsed, &model.ParsedFilterConfig{ - Name: proto.Name, - ParsedConfig: config, - Factory: plugin.Factory, + Name: proto.Name, + ParsedConfig: config, + Factory: plugin.Factory, + SyncRunPhases: plugin.ConfigParser.NonBlockingPhases(), }) _, ok := pkgPlugins.LoadPlugin(name).(pkgPlugins.ConsumerPlugin) diff --git 
a/api/pkg/filtermanager/filtermanager.go b/api/pkg/filtermanager/filtermanager.go index 2dca20f8..f89abcab 100644 --- a/api/pkg/filtermanager/filtermanager.go +++ b/api/pkg/filtermanager/filtermanager.go @@ -58,6 +58,14 @@ type filterManager struct { canSkipOnLog bool canSkipMethod map[string]bool + canSyncRunDecodeHeaders bool + canSyncRunDecodeData bool + canSyncRunDecodeTrailers bool + canSyncRunEncodeHeaders bool + canSyncRunEncodeData bool + canSyncRunEncodeTrailers bool + canSyncRunMethod map[string]bool + callbacks *filterManagerCallbackHandler config *filterManagerConfig @@ -87,6 +95,14 @@ func (m *filterManager) Reset() { m.canSkipEncodeTrailers = false m.canSkipOnLog = false + m.canSyncRunDecodeHeaders = false + m.canSyncRunDecodeData = false + m.canSyncRunDecodeTrailers = false + m.canSyncRunEncodeHeaders = false + m.canSyncRunEncodeData = false + m.canSyncRunEncodeTrailers = false + // m.canSyncRunMethod is reused across filters in the same config + m.callbacks.Reset() } @@ -106,19 +122,6 @@ func (m *filterManager) DebugModeEnabled() bool { return m.config.enableDebugMode } -type phase int - -const ( - phaseDecodeHeaders phase = iota - phaseDecodeData - phaseDecodeTrailers - phaseDecodeRequest - phaseEncodeHeaders - phaseEncodeData - phaseEncodeTrailers - phaseEncodeResponse -) - func newSkipMethodsMap() map[string]bool { return map[string]bool{ "DecodeHeaders": true, @@ -133,6 +136,19 @@ func newSkipMethodsMap() map[string]bool { } } +func newSyncRunMethodMap() map[string]bool { + return map[string]bool{ + "DecodeHeaders": true, + "DecodeData": true, + "DecodeRequest": true, + "DecodeTrailers": true, + "EncodeHeaders": true, + "EncodeData": true, + "EncodeResponse": true, + "EncodeTrailers": true, + } +} + func needLogExecution() bool { return api.GetLogLevel() <= api.LogLevelDebug } @@ -164,8 +180,10 @@ func FilterManagerFactory(c interface{}, cb capi.FilterCallbackHandler) (streamF fm.callbacks.FilterCallbackHandler = cb canSkipMethod := fm.canSkipMethod + canSyncRunMethod := fm.canSyncRunMethod if canSkipMethod == nil { canSkipMethod = newSkipMethodsMap() + canSyncRunMethod = newSyncRunMethodMap() } filters := make([]*model.FilterWrapper, len(parsedConfig)) @@ -188,6 +206,11 @@ func FilterManagerFactory(c interface{}, cb capi.FilterCallbackHandler) (streamF } canSkipMethod[meth] = canSkipMethod[meth] && !overridden definedMethod[meth] = overridden + + if overridden { + // canSkipMethod contains canSyncRunMethod so we can safely check it in the same loop + canSyncRunMethod[meth] = canSyncRunMethod[meth] && fc.SyncRunPhases.Contains(api.MethodToPhase(meth)) + } } if definedMethod["DecodeRequest"] { @@ -223,6 +246,7 @@ func FilterManagerFactory(c interface{}, cb capi.FilterCallbackHandler) (streamF if fm.canSkipMethod == nil { fm.canSkipMethod = canSkipMethod + fm.canSyncRunMethod = canSyncRunMethod } // We can't cache the slice of filters as it may be changed by consumer @@ -238,6 +262,14 @@ func FilterManagerFactory(c interface{}, cb capi.FilterCallbackHandler) (streamF fm.canSkipEncodeTrailers = fm.canSkipMethod["EncodeTrailers"] && fm.canSkipMethod["EncodeResponse"] fm.canSkipOnLog = fm.canSkipMethod["OnLog"] + // Similar to the skip check + fm.canSyncRunDecodeHeaders = fm.canSyncRunMethod["DecodeHeaders"] && fm.canSyncRunMethod["DecodeRequest"] && fm.config.initOnce == nil + fm.canSyncRunDecodeData = fm.canSyncRunMethod["DecodeData"] && fm.canSyncRunMethod["DecodeRequest"] + fm.canSyncRunDecodeTrailers = fm.canSyncRunMethod["DecodeTrailers"] && 
fm.canSyncRunMethod["DecodeRequest"] + fm.canSyncRunEncodeHeaders = fm.canSyncRunMethod["EncodeHeaders"] + fm.canSyncRunEncodeData = fm.canSyncRunMethod["EncodeData"] && fm.canSyncRunMethod["EncodeResponse"] + fm.canSyncRunEncodeTrailers = fm.canSyncRunMethod["EncodeTrailers"] && fm.canSyncRunMethod["EncodeResponse"] + return wrapFilterManager(fm) } @@ -250,14 +282,14 @@ func (m *filterManager) recordLocalReplyPluginName(name string) { // off a goroutine and the goroutine panics. } -func (m *filterManager) handleAction(res api.ResultAction, phase phase, filter *model.FilterWrapper) (needReturn bool) { +func (m *filterManager) handleAction(res api.ResultAction, phase api.Phase, filter *model.FilterWrapper) (needReturn bool) { if res == api.Continue { return false } if res == api.WaitAllData { - if phase == phaseDecodeHeaders { + if phase == api.PhaseDecodeHeaders { m.decodeRequestNeeded = true - } else if phase == phaseEncodeHeaders { + } else if phase == api.PhaseEncodeHeaders { m.encodeResponseNeeded = true } else { api.LogErrorf("WaitAllData only allowed when processing headers, phase: %v. "+ @@ -269,7 +301,7 @@ func (m *filterManager) handleAction(res api.ResultAction, phase phase, filter * switch v := res.(type) { case *api.LocalResponse: m.recordLocalReplyPluginName(filter.Name) - m.localReply(v, phase < phaseEncodeHeaders) + m.localReply(v, phase < api.PhaseEncodeHeaders) return true default: api.LogErrorf("unknown result action: %+v", v) @@ -349,171 +381,199 @@ func (m *filterManager) DecodeHeaders(headers capi.RequestHeaderMap, endStream b return capi.Continue } + if m.canSyncRunDecodeHeaders { + return m.decodeHeaders(headers, endStream) + } + + // We don't exact the repeated async pattern in a new method as it will require a closure to + // wrap `m.decodeHeaders`, which makes this method 25% slower. 
m.MarkRunningInGoThread(true) go func() { defer m.MarkRunningInGoThread(false) defer m.callbacks.DecoderFilterCallbacks().RecoverPanic() - var res api.ResultAction - - m.config.InitOnce() - if m.config.initFailed { - api.LogErrorf("error in plugin %s: %s", m.config.initFailedPluginName, m.config.initFailure) - m.recordLocalReplyPluginName(m.config.initFailedPluginName) - m.localReply(&api.LocalResponse{ - Code: 500, - }, true) - return - } - - m.hdrLock.Lock() - if m.reqHdr == nil { - m.reqHdr = &filterManagerRequestHeaderMap{ - RequestHeaderMap: headers, - } - } - m.hdrLock.Unlock() - if m.config.consumerFiltersEndAt != 0 { - for i := 0; i < m.config.consumerFiltersEndAt; i++ { - f := m.filters[i] - // We don't support DecodeRequest for now - res = f.DecodeHeaders(m.reqHdr, endStream) - if m.handleAction(res, phaseDecodeHeaders, f) { - return - } + + res := m.decodeHeaders(headers, endStream) + if res != capi.LocalReply { + m.callbacks.Continue(res, true) + } + }() + + return capi.Running +} + +func (m *filterManager) decodeHeaders(headers capi.RequestHeaderMap, endStream bool) capi.StatusType { + var res api.ResultAction + + m.config.InitOnce() + if m.config.initFailed { + api.LogErrorf("error in plugin %s: %s", m.config.initFailedPluginName, m.config.initFailure) + m.recordLocalReplyPluginName(m.config.initFailedPluginName) + m.localReply(&api.LocalResponse{ + Code: 500, + }, true) + return capi.LocalReply + } + + m.hdrLock.Lock() + if m.reqHdr == nil { + m.reqHdr = &filterManagerRequestHeaderMap{ + RequestHeaderMap: headers, + } + } + m.hdrLock.Unlock() + if m.config.consumerFiltersEndAt != 0 { + for i := 0; i < m.config.consumerFiltersEndAt; i++ { + f := m.filters[i] + // We don't support DecodeRequest for now + res = f.DecodeHeaders(m.reqHdr, endStream) + if m.handleAction(res, api.PhaseDecodeHeaders, f) { + return capi.LocalReply } + } - // we check consumer at the end of authn filters, so we can have multiple authn filters - // configured and the consumer will be set by any of them - c, ok := m.callbacks.consumer.(*consumer.Consumer) - if ok && len(c.FilterConfigs) > 0 { - api.LogDebugf("merge filters from consumer: %s", c.Name()) + // we check consumer at the end of authn filters, so we can have multiple authn filters + // configured and the consumer will be set by any of them + c, ok := m.callbacks.consumer.(*consumer.Consumer) + if ok && len(c.FilterConfigs) > 0 { + api.LogDebugf("merge filters from consumer: %s", c.Name()) - c.InitOnce.Do(func() { - names := make([]string, 0, len(c.FilterConfigs)) - for name, fc := range c.FilterConfigs { - names = append(names, name) + c.InitOnce.Do(func() { + names := make([]string, 0, len(c.FilterConfigs)) + for name, fc := range c.FilterConfigs { + names = append(names, name) - config := fc.ParsedConfig - if initer, ok := config.(pkgPlugins.Initer); ok { - // For now, we have nothing to provide as config callbacks - err := initer.Init(nil) - if err != nil { - fc.Factory = NewInternalErrorFactory(fc.Name, err) - } + config := fc.ParsedConfig + if initer, ok := config.(pkgPlugins.Initer); ok { + // For now, we have nothing to provide as config callbacks + err := initer.Init(nil) + if err != nil { + fc.Factory = NewInternalErrorFactory(fc.Name, err) } } + } - c.FilterNames = names - }) + c.FilterNames = names + }) - filterWrappers := make([]*model.FilterWrapper, len(c.FilterConfigs)) - for i, name := range c.FilterNames { - fc := c.FilterConfigs[name] - factory := fc.Factory - config := fc.ParsedConfig - f := factory(config, m.callbacks) - 
filterWrappers[i] = model.NewFilterWrapper(name, f) - } + filterWrappers := make([]*model.FilterWrapper, len(c.FilterConfigs)) + for i, name := range c.FilterNames { + fc := c.FilterConfigs[name] + factory := fc.Factory + config := fc.ParsedConfig + f := factory(config, m.callbacks) + filterWrappers[i] = model.NewFilterWrapper(name, f) + } - c.CanSkipMethodOnce.Do(func() { - canSkipMethod := newSkipMethodsMap() - for _, fw := range filterWrappers { - f := fw.Filter - for meth := range canSkipMethod { - overridden, err := reflectx.IsMethodOverridden(f, meth) - if err != nil { - api.LogErrorf("failed to check method %s in filter: %v", meth, err) - // canSkipMethod[meth] will be false - } - canSkipMethod[meth] = canSkipMethod[meth] && !overridden + c.CanSkipMethodOnce.Do(func() { + canSkipMethod := newSkipMethodsMap() + canSyncRunMethod := newSyncRunMethodMap() + for _, fw := range filterWrappers { + f := fw.Filter + fc := c.FilterConfigs[fw.Name] + for meth := range canSkipMethod { + overridden, err := reflectx.IsMethodOverridden(f, meth) + if err != nil { + api.LogErrorf("failed to check method %s in filter: %v", meth, err) + // canSkipMethod[meth] will be false } - } - c.CanSkipMethod = canSkipMethod - }) + canSkipMethod[meth] = canSkipMethod[meth] && !overridden - if needLogExecution() { - for _, fw := range filterWrappers { - f := fw.Filter - fw.Filter = NewLogExecutionFilter(fw.Name, f, m.callbacks) + if overridden { + canSyncRunMethod[meth] = canSyncRunMethod[meth] && fc.SyncRunPhases.Contains(api.MethodToPhase(meth)) + } } } + c.CanSkipMethod = canSkipMethod + c.CanSyncRunMethod = canSyncRunMethod + }) + + if needLogExecution() { + for _, fw := range filterWrappers { + f := fw.Filter + fw.Filter = NewLogExecutionFilter(fw.Name, f, m.callbacks) + } + } - if m.DebugModeEnabled() { - for _, fw := range filterWrappers { - f := fw.Filter - fw.Filter = NewDebugFilter(fw.Name, f, m.callbacks) - } + if m.DebugModeEnabled() { + for _, fw := range filterWrappers { + f := fw.Filter + fw.Filter = NewDebugFilter(fw.Name, f, m.callbacks) } + } - canSkipMethod := c.CanSkipMethod - m.canSkipDecodeData = m.canSkipDecodeData && canSkipMethod["DecodeData"] && canSkipMethod["DecodeRequest"] - m.canSkipDecodeTrailers = m.canSkipDecodeTrailers && canSkipMethod["DecodeTrailers"] && canSkipMethod["DecodeRequest"] - m.canSkipEncodeHeaders = m.canSkipEncodeData && canSkipMethod["EncodeHeaders"] - m.canSkipEncodeData = m.canSkipEncodeData && canSkipMethod["EncodeData"] && canSkipMethod["EncodeResponse"] - m.canSkipEncodeTrailers = m.canSkipEncodeTrailers && canSkipMethod["EncodeTrailers"] && canSkipMethod["EncodeResponse"] - m.canSkipOnLog = m.canSkipOnLog && canSkipMethod["OnLog"] + canSkipMethod := c.CanSkipMethod + m.canSkipDecodeData = m.canSkipDecodeData && canSkipMethod["DecodeData"] && canSkipMethod["DecodeRequest"] + m.canSkipDecodeTrailers = m.canSkipDecodeTrailers && canSkipMethod["DecodeTrailers"] && canSkipMethod["DecodeRequest"] + m.canSkipEncodeHeaders = m.canSkipEncodeData && canSkipMethod["EncodeHeaders"] + m.canSkipEncodeData = m.canSkipEncodeData && canSkipMethod["EncodeData"] && canSkipMethod["EncodeResponse"] + m.canSkipEncodeTrailers = m.canSkipEncodeTrailers && canSkipMethod["EncodeTrailers"] && canSkipMethod["EncodeResponse"] + m.canSkipOnLog = m.canSkipOnLog && canSkipMethod["OnLog"] + + // Similar to the skip check + canSyncRunMethod := c.CanSyncRunMethod + m.canSyncRunDecodeData = m.canSyncRunDecodeData && canSyncRunMethod["DecodeData"] && canSyncRunMethod["DecodeRequest"] + 
m.canSyncRunDecodeTrailers = m.canSyncRunDecodeTrailers && canSyncRunMethod["DecodeTrailers"] && canSyncRunMethod["DecodeRequest"] + m.canSyncRunEncodeHeaders = m.canSyncRunEncodeData && canSyncRunMethod["EncodeHeaders"] + m.canSyncRunEncodeData = m.canSyncRunEncodeData && canSyncRunMethod["EncodeData"] && canSyncRunMethod["EncodeResponse"] + m.canSyncRunEncodeTrailers = m.canSyncRunEncodeTrailers && canSyncRunMethod["EncodeTrailers"] && canSyncRunMethod["EncodeResponse"] + + // TODO: add field to control if merging is allowed + i := 0 + for _, f := range m.filters { + if c.FilterConfigs[f.Name] == nil { + m.filters[i] = f + i++ + } + } + m.filters = append(m.filters[:i], filterWrappers...) + sort.Slice(m.filters, func(i, j int) bool { + return pkgPlugins.ComparePluginOrder(m.filters[i].Name, m.filters[j].Name) + }) - // TODO: add field to control if merging is allowed - i := 0 + if api.GetLogLevel() <= api.LogLevelDebug { for _, f := range m.filters { - if c.FilterConfigs[f.Name] == nil { - m.filters[i] = f - i++ - } - } - m.filters = append(m.filters[:i], filterWrappers...) - sort.Slice(m.filters, func(i, j int) bool { - return pkgPlugins.ComparePluginOrder(m.filters[i].Name, m.filters[j].Name) - }) - - if api.GetLogLevel() <= api.LogLevelDebug { - for _, f := range m.filters { - fc := c.FilterConfigs[f.Name] - if fc == nil { - // the plugin is not from consumer - for _, cfg := range m.config.parsed { - if cfg.Name == f.Name { - fc = cfg - break - } + fc := c.FilterConfigs[f.Name] + if fc == nil { + // the plugin is not from consumer + for _, cfg := range m.config.parsed { + if cfg.Name == f.Name { + fc = cfg + break } } - api.LogDebugf("after merged consumer, plugin: %s, config: %+v", f.Name, fc.ParsedConfig) } + api.LogDebugf("after merged consumer, plugin: %s, config: %+v", f.Name, fc.ParsedConfig) } } } + } - for i := m.config.consumerFiltersEndAt; i < len(m.filters); i++ { - f := m.filters[i] - res = f.DecodeHeaders(m.reqHdr, endStream) - if m.handleAction(res, phaseDecodeHeaders, f) { - return - } + for i := m.config.consumerFiltersEndAt; i < len(m.filters); i++ { + f := m.filters[i] + res = f.DecodeHeaders(m.reqHdr, endStream) + if m.handleAction(res, api.PhaseDecodeHeaders, f) { + return capi.LocalReply + } - if m.decodeRequestNeeded { - m.decodeRequestNeeded = false - if !endStream { - m.decodeIdx = i - // some filters, like authorization with request body, need to - // have a whole body before passing to the next filter - m.callbacks.Continue(capi.StopAndBuffer, true) - return - } + if m.decodeRequestNeeded { + m.decodeRequestNeeded = false + if !endStream { + m.decodeIdx = i + // some filters, like authorization with request body, need to + // have a whole body before passing to the next filter + return capi.StopAndBuffer + } - // no body and no trailers - res = f.DecodeRequest(m.reqHdr, nil, nil) - if m.handleAction(res, phaseDecodeRequest, f) { - return - } + // no body and no trailers + res = f.DecodeRequest(m.reqHdr, nil, nil) + if m.handleAction(res, api.PhaseDecodeRequest, f) { + return capi.LocalReply } } + } - m.callbacks.Continue(capi.Continue, true) - }() - - return capi.Running + return capi.Continue } func (m *filterManager) DecodeRequest(headers api.RequestHeaderMap, buf capi.BufferInstance, trailers capi.RequestTrailerMap) bool { @@ -527,7 +587,7 @@ func (m *filterManager) DecodeRequest(headers api.RequestHeaderMap, buf capi.Buf for i := 0; i < m.decodeIdx; i++ { f := m.filters[i] res = f.DecodeData(buf, endStreamInBody) - if m.handleAction(res, phaseDecodeData, 
f) { + if m.handleAction(res, api.PhaseDecodeData, f) { return false } } @@ -538,7 +598,7 @@ func (m *filterManager) DecodeRequest(headers api.RequestHeaderMap, buf capi.Buf for i := 0; i < m.decodeIdx; i++ { f := m.filters[i] res = f.DecodeTrailers(trailers) - if m.handleAction(res, phaseDecodeTrailers, f) { + if m.handleAction(res, api.PhaseDecodeTrailers, f) { return false } } @@ -546,7 +606,7 @@ func (m *filterManager) DecodeRequest(headers api.RequestHeaderMap, buf capi.Buf f := m.filters[m.decodeIdx] res = f.DecodeRequest(headers, buf, trailers) - if m.handleAction(res, phaseDecodeRequest, f) { + if m.handleAction(res, api.PhaseDecodeRequest, f) { return false } @@ -558,7 +618,7 @@ func (m *filterManager) DecodeRequest(headers api.RequestHeaderMap, buf capi.Buf // The endStream in DecodeHeaders indicates whether there is a body. // The body always exists when we hit this path. res = f.DecodeHeaders(headers, false) - if m.handleAction(res, phaseDecodeHeaders, f) { + if m.handleAction(res, api.PhaseDecodeHeaders, f) { return false } if m.decodeRequestNeeded { @@ -573,7 +633,7 @@ func (m *filterManager) DecodeRequest(headers api.RequestHeaderMap, buf capi.Buf for j := m.decodeIdx + 1; j < i; j++ { f := m.filters[j] res = f.DecodeData(buf, endStreamInBody) - if m.handleAction(res, phaseDecodeData, f) { + if m.handleAction(res, api.PhaseDecodeData, f) { return false } } @@ -583,7 +643,7 @@ func (m *filterManager) DecodeRequest(headers api.RequestHeaderMap, buf capi.Buf for j := m.decodeIdx + 1; j < i; j++ { f := m.filters[j] res = f.DecodeTrailers(trailers) - if m.handleAction(res, phaseDecodeTrailers, f) { + if m.handleAction(res, api.PhaseDecodeTrailers, f) { return false } } @@ -594,7 +654,7 @@ func (m *filterManager) DecodeRequest(headers api.RequestHeaderMap, buf capi.Buf m.decodeIdx = i f := m.filters[m.decodeIdx] res = f.DecodeRequest(headers, buf, trailers) - if m.handleAction(res, phaseDecodeRequest, f) { + if m.handleAction(res, api.PhaseDecodeRequest, f) { return false } i++ @@ -609,85 +669,109 @@ func (m *filterManager) DecodeData(buf capi.BufferInstance, endStream bool) capi return capi.Continue } + if m.canSyncRunDecodeData { + return m.decodeData(buf, endStream) + } + m.MarkRunningInGoThread(true) go func() { defer m.MarkRunningInGoThread(false) defer m.callbacks.DecoderFilterCallbacks().RecoverPanic() - var res api.ResultAction - - // We have discussed a lot about how to support processing data both streamingly and - // as a whole body. Here are some solutions we have considered: - // 1. let Envoy process data streamingly, and do buffering in Go. This solution is costly - // and may be broken if the buffered data at Go side is rewritten by later C++ filter. - // 2. separate the filters which need a whole body in a separate C++ filter. It can't - // be done without a special control plane. - // 3. add multiple virtual C++ filters to Envoy when init the Envoy Golang filter. It - // is complex because we need to share and move the state between multiple Envoy C++ - // filter. - // 4. when a filter requires a whole body, all the filters will use a whole body. - // Otherwise, streaming processing is used. It's simple and already satisfies our - // most demand, so we choose this way for now. 
- - status := capi.Continue - n := len(m.filters) - if m.decodeIdx == -1 { - // every filter doesn't need buffered body - for i := 0; i < n; i++ { - f := m.filters[i] - res = f.DecodeData(buf, endStream) - if m.handleAction(res, phaseDecodeData, f) { - return - } - } - } else if endStream { - conti := m.DecodeRequest(m.reqHdr, buf, nil) - if !conti { - return - } - } else { - m.reqBuf = buf - status = capi.StopAndBuffer - } - m.callbacks.Continue(status, true) + res := m.decodeData(buf, endStream) + if res != capi.LocalReply { + m.callbacks.Continue(res, true) + } }() return capi.Running } +func (m *filterManager) decodeData(buf capi.BufferInstance, endStream bool) capi.StatusType { + var res api.ResultAction + + // We have discussed a lot about how to support processing data both streamingly and + // as a whole body. Here are some solutions we have considered: + // 1. let Envoy process data streamingly, and do buffering in Go. This solution is costly + // and may be broken if the buffered data at Go side is rewritten by later C++ filter. + // 2. separate the filters which need a whole body in a separate C++ filter. It can't + // be done without a special control plane. + // 3. add multiple virtual C++ filters to Envoy when init the Envoy Golang filter. It + // is complex because we need to share and move the state between multiple Envoy C++ + // filter. + // 4. when a filter requires a whole body, all the filters will use a whole body. + // Otherwise, streaming processing is used. It's simple and already satisfies our + // most demand, so we choose this way for now. + + status := capi.Continue + n := len(m.filters) + if m.decodeIdx == -1 { + // every filter doesn't need buffered body + for i := 0; i < n; i++ { + f := m.filters[i] + res = f.DecodeData(buf, endStream) + if m.handleAction(res, api.PhaseDecodeData, f) { + return capi.LocalReply + } + } + } else if endStream { + conti := m.DecodeRequest(m.reqHdr, buf, nil) + if !conti { + return capi.LocalReply + } + } else { + m.reqBuf = buf + status = capi.StopAndBuffer + } + + return status +} + func (m *filterManager) DecodeTrailers(trailers capi.RequestTrailerMap) capi.StatusType { if m.canSkipDecodeTrailers { return capi.Continue } + if m.canSyncRunDecodeTrailers { + return m.decodeTrailers(trailers) + } + m.MarkRunningInGoThread(true) go func() { defer m.MarkRunningInGoThread(false) defer m.callbacks.DecoderFilterCallbacks().RecoverPanic() - var res api.ResultAction - if m.decodeIdx == -1 { - for _, f := range m.filters { - res = f.DecodeTrailers(trailers) - if m.handleAction(res, phaseDecodeTrailers, f) { - return - } - } - } else { - conti := m.DecodeRequest(m.reqHdr, m.reqBuf, trailers) - if !conti { - return - } + res := m.decodeTrailers(trailers) + if res != capi.LocalReply { + m.callbacks.Continue(res, true) } - - m.callbacks.Continue(capi.Continue, true) }() return capi.Running } +func (m *filterManager) decodeTrailers(trailers capi.RequestTrailerMap) capi.StatusType { + var res api.ResultAction + + if m.decodeIdx == -1 { + for _, f := range m.filters { + res = f.DecodeTrailers(trailers) + if m.handleAction(res, api.PhaseDecodeTrailers, f) { + return capi.LocalReply + } + } + } else { + conti := m.DecodeRequest(m.reqHdr, m.reqBuf, trailers) + if !conti { + return capi.LocalReply + } + } + + return capi.Continue +} + func (m *filterManager) EncodeHeaders(headers capi.ResponseHeaderMap, endStream bool) capi.StatusType { if !supportGettingHeadersOnLog { // Ensure the headers are cached on the Go side. 
@@ -699,44 +783,55 @@ func (m *filterManager) EncodeHeaders(headers capi.ResponseHeaderMap, endStream return capi.Continue } + if m.canSyncRunEncodeHeaders { + return m.encodeHeaders(headers, endStream) + } + m.MarkRunningInGoThread(true) go func() { defer m.MarkRunningInGoThread(false) defer m.callbacks.EncoderFilterCallbacks().RecoverPanic() - var res api.ResultAction - m.hdrLock.Lock() - m.rspHdr = headers - m.hdrLock.Unlock() - n := len(m.filters) - for i := n - 1; i >= 0; i-- { - f := m.filters[i] - res = f.EncodeHeaders(headers, endStream) - if m.handleAction(res, phaseEncodeHeaders, f) { - return - } + res := m.encodeHeaders(headers, endStream) + if res != capi.LocalReply { + m.callbacks.Continue(res, false) + } + }() - if m.encodeResponseNeeded { - m.encodeResponseNeeded = false - if !endStream { - m.encodeIdx = i - m.callbacks.Continue(capi.StopAndBuffer, false) - return - } + return capi.Running +} - // no body - res = f.EncodeResponse(headers, nil, nil) - if m.handleAction(res, phaseEncodeResponse, f) { - return - } - } +func (m *filterManager) encodeHeaders(headers capi.ResponseHeaderMap, endStream bool) capi.StatusType { + var res api.ResultAction + + m.hdrLock.Lock() + m.rspHdr = headers + m.hdrLock.Unlock() + n := len(m.filters) + for i := n - 1; i >= 0; i-- { + f := m.filters[i] + res = f.EncodeHeaders(headers, endStream) + if m.handleAction(res, api.PhaseEncodeHeaders, f) { + return capi.LocalReply } - m.callbacks.Continue(capi.Continue, false) - }() + if m.encodeResponseNeeded { + m.encodeResponseNeeded = false + if !endStream { + m.encodeIdx = i + return capi.StopAndBuffer + } - return capi.Running + // no body + res = f.EncodeResponse(headers, nil, nil) + if m.handleAction(res, api.PhaseEncodeResponse, f) { + return capi.LocalReply + } + } + } + + return capi.Continue } func (m *filterManager) EncodeResponse(headers api.ResponseHeaderMap, buf capi.BufferInstance, trailers capi.ResponseTrailerMap) bool { @@ -750,7 +845,7 @@ func (m *filterManager) EncodeResponse(headers api.ResponseHeaderMap, buf capi.B for i := n - 1; i > m.encodeIdx; i-- { f := m.filters[i] res = f.EncodeData(buf, endStreamInBody) - if m.handleAction(res, phaseEncodeData, f) { + if m.handleAction(res, api.PhaseEncodeData, f) { return false } } @@ -760,7 +855,7 @@ func (m *filterManager) EncodeResponse(headers api.ResponseHeaderMap, buf capi.B for i := n - 1; i > m.encodeIdx; i-- { f := m.filters[i] res = f.EncodeTrailers(trailers) - if m.handleAction(res, phaseEncodeTrailers, f) { + if m.handleAction(res, api.PhaseEncodeTrailers, f) { return false } } @@ -768,7 +863,7 @@ func (m *filterManager) EncodeResponse(headers api.ResponseHeaderMap, buf capi.B f := m.filters[m.encodeIdx] res = f.EncodeResponse(m.rspHdr, buf, nil) - if m.handleAction(res, phaseEncodeResponse, f) { + if m.handleAction(res, api.PhaseEncodeResponse, f) { return false } @@ -777,7 +872,7 @@ func (m *filterManager) EncodeResponse(headers api.ResponseHeaderMap, buf capi.B for ; i >= 0; i-- { f := m.filters[i] res = f.EncodeHeaders(m.rspHdr, false) - if m.handleAction(res, phaseEncodeHeaders, f) { + if m.handleAction(res, api.PhaseEncodeHeaders, f) { return false } if m.encodeResponseNeeded { @@ -790,7 +885,7 @@ func (m *filterManager) EncodeResponse(headers api.ResponseHeaderMap, buf capi.B for j := m.encodeIdx - 1; j > i; j-- { f := m.filters[j] res = f.EncodeData(buf, endStreamInBody) - if m.handleAction(res, phaseEncodeData, f) { + if m.handleAction(res, api.PhaseEncodeData, f) { return false } } @@ -800,7 +895,7 @@ func (m 
*filterManager) EncodeResponse(headers api.ResponseHeaderMap, buf capi.B for j := m.encodeIdx - 1; j > i; j-- { f := m.filters[j] res = f.EncodeTrailers(trailers) - if m.handleAction(res, phaseEncodeTrailers, f) { + if m.handleAction(res, api.PhaseEncodeTrailers, f) { return false } } @@ -811,7 +906,7 @@ func (m *filterManager) EncodeResponse(headers api.ResponseHeaderMap, buf capi.B m.encodeIdx = i f := m.filters[m.encodeIdx] res = f.EncodeResponse(m.rspHdr, buf, nil) - if m.handleAction(res, phaseEncodeResponse, f) { + if m.handleAction(res, api.PhaseEncodeResponse, f) { return false } i-- @@ -826,65 +921,90 @@ func (m *filterManager) EncodeData(buf capi.BufferInstance, endStream bool) capi return capi.Continue } + if m.canSyncRunEncodeData { + return m.encodeData(buf, endStream) + } + m.MarkRunningInGoThread(true) go func() { defer m.MarkRunningInGoThread(false) defer m.callbacks.EncoderFilterCallbacks().RecoverPanic() - var res api.ResultAction - - status := capi.Continue - n := len(m.filters) - if m.encodeIdx == -1 { - // every filter doesn't need buffered body - for i := n - 1; i >= 0; i-- { - f := m.filters[i] - res = f.EncodeData(buf, endStream) - if m.handleAction(res, phaseEncodeData, f) { - return - } - } - } else { - // FIXME: we should implement like the decode part here, but it will cause server closed the stream without sending trailers - conti := m.EncodeResponse(m.rspHdr, buf, nil) - if !conti { - return - } - } - m.callbacks.Continue(status, false) + res := m.encodeData(buf, endStream) + if res != capi.LocalReply { + m.callbacks.Continue(res, false) + } }() return capi.Running } +func (m *filterManager) encodeData(buf capi.BufferInstance, endStream bool) capi.StatusType { + var res api.ResultAction + + status := capi.Continue + n := len(m.filters) + if m.encodeIdx == -1 { + // every filter doesn't need buffered body + for i := n - 1; i >= 0; i-- { + f := m.filters[i] + res = f.EncodeData(buf, endStream) + if m.handleAction(res, api.PhaseEncodeData, f) { + return capi.LocalReply + } + } + } else { + // FIXME: we should implement like the decode part here, but it will cause server closed the stream without sending trailers. + // As a result, we don't process the trailers in EncodeResponse for now. 
+ conti := m.EncodeResponse(m.rspHdr, buf, nil) + if !conti { + return capi.LocalReply + } + } + + return status +} + func (m *filterManager) EncodeTrailers(trailers capi.ResponseTrailerMap) capi.StatusType { if m.canSkipEncodeTrailers { return capi.Continue } + if m.canSyncRunEncodeTrailers { + return m.encodeTrailers(trailers) + } + m.MarkRunningInGoThread(true) go func() { defer m.MarkRunningInGoThread(false) defer m.callbacks.EncoderFilterCallbacks().RecoverPanic() - var res api.ResultAction - if m.encodeIdx == -1 { - for _, f := range m.filters { - res = f.EncodeTrailers(trailers) - if m.handleAction(res, phaseEncodeTrailers, f) { - return - } - } + res := m.encodeTrailers(trailers) + if res != capi.LocalReply { + m.callbacks.Continue(res, false) } - - m.callbacks.Continue(capi.Continue, false) }() return capi.Running } +func (m *filterManager) encodeTrailers(trailers capi.ResponseTrailerMap) capi.StatusType { + var res api.ResultAction + + if m.encodeIdx == -1 { + for _, f := range m.filters { + res = f.EncodeTrailers(trailers) + if m.handleAction(res, api.PhaseEncodeTrailers, f) { + return capi.LocalReply + } + } + } + + return capi.Continue +} + func (m *filterManager) runOnLogPhase(reqHdr api.RequestHeaderMap, reqTrailer api.RequestTrailerMap, rspHdr api.ResponseHeaderMap, rspTrailer api.ResponseTrailerMap) { diff --git a/api/pkg/filtermanager/filtermanager_benchmark_test.go b/api/pkg/filtermanager/filtermanager_benchmark_test.go index 345c5b1d..5c0c2057 100644 --- a/api/pkg/filtermanager/filtermanager_benchmark_test.go +++ b/api/pkg/filtermanager/filtermanager_benchmark_test.go @@ -58,6 +58,32 @@ func BenchmarkFilterManagerAllPhase(b *testing.B) { } } +func BenchmarkFilterManagerAllPhaseCanSyncRun(b *testing.B) { + envoy.DisableLogInTest() // otherwise, there is too much output + cb := envoy.NewCAPIFilterCallbackHandler() + config := initFilterManagerConfig("ns") + config.parsed = []*model.ParsedFilterConfig{ + { + Name: "allPhase", + Factory: PassThroughFactory, + SyncRunPhases: api.AllPhases, + }, + } + reqHdr := envoy.NewRequestHeaderMap(http.Header{}) + respHdr := envoy.NewResponseHeaderMap(http.Header{}) + reqBuf := envoy.NewBufferInstance([]byte{}) + respBuf := envoy.NewBufferInstance([]byte{}) + + for n := 0; n < b.N; n++ { + m := unwrapFilterManager(FilterManagerFactory(config, cb)) + m.DecodeHeaders(reqHdr, false) + m.DecodeData(reqBuf, true) + m.EncodeHeaders(respHdr, false) + m.EncodeData(respBuf, true) + m.OnLog(reqHdr, nil, respHdr, nil) + } +} + func regularFactory(c interface{}, callbacks api.FilterCallbackHandler) api.Filter { return ®ularFilter{} } diff --git a/api/pkg/filtermanager/filtermanager_test.go b/api/pkg/filtermanager/filtermanager_test.go index 07450341..2ba2ea94 100644 --- a/api/pkg/filtermanager/filtermanager_test.go +++ b/api/pkg/filtermanager/filtermanager_test.go @@ -392,6 +392,27 @@ func TestSkipMethodWhenThereAreMultiFilters(t *testing.T) { } } +type addRespConf struct { + hdrName string +} + +func addRespFactory(c interface{}, _ api.FilterCallbackHandler) api.Filter { + return &addRespFilter{ + conf: c.(addRespConf), + } +} + +type addRespFilter struct { + api.PassThroughFilter + + conf addRespConf +} + +func (f *addRespFilter) EncodeHeaders(headers api.ResponseHeaderMap, endStream bool) api.ResultAction { + headers.Set(f.conf.hdrName, "htnn") + return api.Continue +} + type setConsumerConf struct { Consumers map[string]*internalConsumer.Consumer } @@ -432,6 +453,14 @@ func TestFiltersFromConsumer(t *testing.T) { hdrName: 
fmt.Sprintf("x-htnn-consumer-%d", i), }, }, + "4_add_resp": { + Name: "4_add_resp", + Factory: addRespFactory, + ParsedConfig: addRespConf{ + hdrName: fmt.Sprintf("x-htnn-resp-%d", i), + }, + SyncRunPhases: api.PhaseEncodeHeaders, + }, }, } if i%2 == 0 { @@ -477,11 +506,12 @@ func TestFiltersFromConsumer(t *testing.T) { cb.WaitContinued() if idx%2 == 0 { assert.Equal(t, false, m.canSkipOnLog) - assert.Equal(t, 3, len(m.filters)) + assert.Equal(t, 4, len(m.filters)) } else { assert.Equal(t, true, m.canSkipOnLog) - assert.Equal(t, 2, len(m.filters)) + assert.Equal(t, 3, len(m.filters)) } + assert.Equal(t, true, m.canSyncRunEncodeHeaders) _, ok := hdr.Get("x-htnn-route") assert.False(t, ok) @@ -667,3 +697,33 @@ func TestDoNotRecycleInUsedFilterManager(t *testing.T) { } wg.Wait() } + +func TestSyncRunWhenThereAreMultiFilters(t *testing.T) { + cb := envoy.NewCAPIFilterCallbackHandler() + config := initFilterManagerConfig("ns") + config.parsed = []*model.ParsedFilterConfig{ + { + Name: "add_req", + Factory: addReqFactory, + ParsedConfig: addReqConf{ + hdrName: "x-htnn-route", + }, + SyncRunPhases: api.PhaseDecodeTrailers, + }, + { + Name: "access_field_on_log", + Factory: accessFieldOnLogFactory, + SyncRunPhases: api.AllPhases, + }, + } + + for i := 0; i < 2; i++ { + m := unwrapFilterManager(FilterManagerFactory(config, cb)) + assert.Equal(t, false, m.canSyncRunDecodeHeaders) + assert.Equal(t, true, m.canSyncRunDecodeData) + assert.Equal(t, true, m.canSyncRunDecodeTrailers) + assert.Equal(t, true, m.canSyncRunEncodeHeaders) + assert.Equal(t, true, m.canSyncRunEncodeData) + assert.Equal(t, true, m.canSyncRunEncodeTrailers) + } +} diff --git a/api/pkg/filtermanager/model/model.go b/api/pkg/filtermanager/model/model.go index b11aefbe..9ca6d087 100644 --- a/api/pkg/filtermanager/model/model.go +++ b/api/pkg/filtermanager/model/model.go @@ -30,11 +30,12 @@ type FilterConfig struct { } type ParsedFilterConfig struct { - Name string - ParsedConfig interface{} - InitOnce sync.Once - InitFailure error - Factory api.FilterFactory + Name string + ParsedConfig interface{} + InitOnce sync.Once + InitFailure error + Factory api.FilterFactory + SyncRunPhases api.Phase } type FilterWrapper struct { diff --git a/api/pkg/plugins/plugins.go b/api/pkg/plugins/plugins.go index 4e8df0d9..26b8de71 100644 --- a/api/pkg/plugins/plugins.go +++ b/api/pkg/plugins/plugins.go @@ -37,6 +37,7 @@ var ( type FilterConfigParser interface { Parse(input interface{}) (interface{}, error) Merge(parentConfig interface{}, childConfig interface{}) interface{} + NonBlockingPhases() api.Phase } type FilterFactoryAndParser struct { @@ -197,6 +198,10 @@ func (p *PluginMethodDefaultImpl) Merge(parent interface{}, child interface{}) i return child } +func (p *PluginMethodDefaultImpl) NonBlockingPhases() api.Phase { + return 0 +} + func ComparePluginOrder(a, b string) bool { return ComparePluginOrderInt(a, b) < 0 } diff --git a/api/pkg/plugins/type.go b/api/pkg/plugins/type.go index 27c16741..1c258586 100644 --- a/api/pkg/plugins/type.go +++ b/api/pkg/plugins/type.go @@ -142,6 +142,7 @@ type Plugin interface { Type() PluginType Order() PluginOrder Merge(parent interface{}, child interface{}) interface{} + NonBlockingPhases() api.Phase } type Initer interface { diff --git a/types/plugins/demo/config.go b/types/plugins/demo/config.go index 2685c426..4d179172 100644 --- a/types/plugins/demo/config.go +++ b/types/plugins/demo/config.go @@ -51,6 +51,23 @@ func (p *Plugin) Order() plugins.PluginOrder { } } +// NonBlockingPhases returns the phases of 
the plugin that can be run without blocking, defaulting to 0. +// If the plugin's filter doesn't contain any blocking operation at all, it should return api.AllPhases. +// A blocking operation can be: +// 1. I/O operation +// 2. Sleep +// 3. Blocking syscall +// 4. Context switch, like waiting on a channel +// and more. +// +// If a phase only contains non-blocking plugins, it will be executed synchronously, which is +// more efficient. +// +// Phase OnLog is always executed synchronously so we don't need to specify it here. +func (p *Plugin) NonBlockingPhases() api.Phase { + return api.PhaseDecodeHeaders | api.PhaseEncodeHeaders +} + // Config returns api.PluginConfig's implementation used during configuration processing func (p *Plugin) Config() api.PluginConfig { return &Config{}
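For readers of this diff, here is a minimal, hypothetical sketch (not part of the change) of how a plugin opts into synchronous execution, mirroring the demo plugin above. The package name `headertag`, the filter type, the header name, and the import paths are assumptions based on the repo layout; only the `api.Phase` constants, `plugins.PluginMethodDefaultImpl`, and `api.PassThroughFilter` come from the diff itself, and plugin registration, `Config()`, `Factory()`, `Type()`, and `Order()` are omitted for brevity.

```go
package headertag

import (
	"mosn.io/htnn/api/pkg/filtermanager/api" // assumed import path
	"mosn.io/htnn/api/pkg/plugins"           // assumed import path
)

// Plugin is a hypothetical plugin whose filter only mutates headers in memory,
// so its header phases contain no blocking operations.
type Plugin struct {
	plugins.PluginMethodDefaultImpl
}

// NonBlockingPhases reports which phases can run directly on the Envoy worker
// thread. The filter manager runs a phase synchronously only when every filter
// that overrides the corresponding method includes that phase here.
func (p *Plugin) NonBlockingPhases() api.Phase {
	return api.PhaseDecodeHeaders | api.PhaseEncodeHeaders
}

type filter struct {
	api.PassThroughFilter
}

// DecodeHeaders does pure in-memory work, so declaring PhaseDecodeHeaders as
// non-blocking above is safe for this filter.
func (f *filter) DecodeHeaders(headers api.RequestHeaderMap, endStream bool) api.ResultAction {
	headers.Set("x-example-tag", "htnn")
	return api.Continue
}
```

With such a plugin configured, the factory loop in filtermanager.go ANDs `fc.SyncRunPhases.Contains(api.MethodToPhase(meth))` across all filters that override a method, so a phase like DecodeHeaders can return its status directly instead of spawning a goroutine and calling `m.callbacks.Continue`, subject to the extra conditions in the diff (for example, DecodeRequest also being non-blocking and the config init having already run).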