Skip to content

Commit

Permalink
check it per phase
Browse files Browse the repository at this point in the history
Signed-off-by: spacewander <[email protected]>
  • Loading branch information
spacewander committed Nov 18, 2024
1 parent bf085f0 commit 0233a39
Show file tree
Hide file tree
Showing 10 changed files with 122 additions and 71 deletions.
8 changes: 4 additions & 4 deletions api/internal/consumer/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ func (c *Consumer) InitConfigs() error {
}

c.FilterConfigs[name] = &fmModel.ParsedFilterConfig{
Name: name,
ParsedConfig: conf,
Factory: p.Factory,
CanSyncRun: p.ConfigParser.IsNonBlocking(),
Name: name,
ParsedConfig: conf,
Factory: p.Factory,
SyncRunPhases: p.ConfigParser.NonBlockingPhases(),
}
}

Expand Down
63 changes: 63 additions & 0 deletions api/pkg/filtermanager/api/phase.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright The HTNN Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package api

type Phase int

const (
PhaseDecodeHeaders Phase = 0x01
PhaseDecodeData = 0x02
PhaseDecodeTrailers = 0x04
PhaseDecodeRequest = 0x08
PhaseEncodeHeaders = 0x10
PhaseEncodeData = 0x20
PhaseEncodeTrailers = 0x40
PhaseEncodeResponse = 0x80
PhaseOnLog = 0x100
)

var (
AllPhases = PhaseDecodeHeaders | PhaseDecodeData | PhaseDecodeTrailers | PhaseDecodeRequest |
PhaseEncodeHeaders | PhaseEncodeData | PhaseEncodeTrailers | PhaseEncodeResponse | PhaseOnLog
)

func (p Phase) Contains(phases Phase) bool {
return p&phases == phases
}

func MethodToPhase(meth string) Phase {
switch meth {
case "DecodeHeaders":
return PhaseDecodeHeaders
case "DecodeData":
return PhaseDecodeData
case "DecodeTrailers":
return PhaseDecodeTrailers
case "DecodeRequest":
return PhaseDecodeRequest
case "EncodeHeaders":
return PhaseEncodeHeaders
case "EncodeData":
return PhaseEncodeData
case "EncodeTrailers":
return PhaseEncodeTrailers
case "EncodeResponse":
return PhaseEncodeResponse
case "OnLog":
return PhaseOnLog
default:
return 0

Check warning on line 61 in api/pkg/filtermanager/api/phase.go

View check run for this annotation

Codecov / codecov/patch

api/pkg/filtermanager/api/phase.go#L58-L61

Added lines #L58 - L61 were not covered by tests
}
}
8 changes: 4 additions & 4 deletions api/pkg/filtermanager/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,10 @@ func (p *FilterManagerConfigParser) Parse(any *anypb.Any, callbacks capi.ConfigC
})
} else {
conf.parsed = append(conf.parsed, &model.ParsedFilterConfig{
Name: proto.Name,
ParsedConfig: config,
Factory: plugin.Factory,
CanSyncRun: plugin.ConfigParser.IsNonBlocking(),
Name: proto.Name,
ParsedConfig: config,
Factory: plugin.Factory,
SyncRunPhases: plugin.ConfigParser.NonBlockingPhases(),
})

_, ok := pkgPlugins.LoadPlugin(name).(pkgPlugins.ConsumerPlugin)
Expand Down
71 changes: 29 additions & 42 deletions api/pkg/filtermanager/filtermanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,19 +122,6 @@ func (m *filterManager) DebugModeEnabled() bool {
return m.config.enableDebugMode
}

type phase int

const (
phaseDecodeHeaders phase = iota
phaseDecodeData
phaseDecodeTrailers
phaseDecodeRequest
phaseEncodeHeaders
phaseEncodeData
phaseEncodeTrailers
phaseEncodeResponse
)

func newSkipMethodsMap() map[string]bool {
return map[string]bool{
"DecodeHeaders": true,
Expand Down Expand Up @@ -222,7 +209,7 @@ func FilterManagerFactory(c interface{}, cb capi.FilterCallbackHandler) (streamF

if overridden {
// canSkipMethod contains canSyncRunMethod so we can safely check it in the same loop
canSyncRunMethod[meth] = canSyncRunMethod[meth] && fc.CanSyncRun
canSyncRunMethod[meth] = canSyncRunMethod[meth] && fc.SyncRunPhases.Contains(api.MethodToPhase(meth))
}
}

Expand Down Expand Up @@ -295,14 +282,14 @@ func (m *filterManager) recordLocalReplyPluginName(name string) {
// off a goroutine and the goroutine panics.
}

func (m *filterManager) handleAction(res api.ResultAction, phase phase, filter *model.FilterWrapper) (needReturn bool) {
func (m *filterManager) handleAction(res api.ResultAction, phase api.Phase, filter *model.FilterWrapper) (needReturn bool) {
if res == api.Continue {
return false
}
if res == api.WaitAllData {
if phase == phaseDecodeHeaders {
if phase == api.PhaseDecodeHeaders {
m.decodeRequestNeeded = true
} else if phase == phaseEncodeHeaders {
} else if phase == api.PhaseEncodeHeaders {
m.encodeResponseNeeded = true
} else {
api.LogErrorf("WaitAllData only allowed when processing headers, phase: %v. "+
Expand All @@ -314,7 +301,7 @@ func (m *filterManager) handleAction(res api.ResultAction, phase phase, filter *
switch v := res.(type) {
case *api.LocalResponse:
m.recordLocalReplyPluginName(filter.Name)
m.localReply(v, phase < phaseEncodeHeaders)
m.localReply(v, phase < api.PhaseEncodeHeaders)
return true
default:
api.LogErrorf("unknown result action: %+v", v)
Expand Down Expand Up @@ -440,7 +427,7 @@ func (m *filterManager) decodeHeaders(headers capi.RequestHeaderMap, endStream b
f := m.filters[i]
// We don't support DecodeRequest for now
res = f.DecodeHeaders(m.reqHdr, endStream)
if m.handleAction(res, phaseDecodeHeaders, f) {
if m.handleAction(res, api.PhaseDecodeHeaders, f) {
return capi.LocalReply
}
}
Expand Down Expand Up @@ -493,7 +480,7 @@ func (m *filterManager) decodeHeaders(headers capi.RequestHeaderMap, endStream b
canSkipMethod[meth] = canSkipMethod[meth] && !overridden

if overridden {
canSyncRunMethod[meth] = canSyncRunMethod[meth] && fc.CanSyncRun
canSyncRunMethod[meth] = canSyncRunMethod[meth] && fc.SyncRunPhases.Contains(api.MethodToPhase(meth))
}
}
}
Expand Down Expand Up @@ -565,7 +552,7 @@ func (m *filterManager) decodeHeaders(headers capi.RequestHeaderMap, endStream b
for i := m.config.consumerFiltersEndAt; i < len(m.filters); i++ {
f := m.filters[i]
res = f.DecodeHeaders(m.reqHdr, endStream)
if m.handleAction(res, phaseDecodeHeaders, f) {
if m.handleAction(res, api.PhaseDecodeHeaders, f) {
return capi.LocalReply
}

Expand All @@ -580,7 +567,7 @@ func (m *filterManager) decodeHeaders(headers capi.RequestHeaderMap, endStream b

// no body and no trailers
res = f.DecodeRequest(m.reqHdr, nil, nil)
if m.handleAction(res, phaseDecodeRequest, f) {
if m.handleAction(res, api.PhaseDecodeRequest, f) {
return capi.LocalReply
}
}
Expand All @@ -600,7 +587,7 @@ func (m *filterManager) DecodeRequest(headers api.RequestHeaderMap, buf capi.Buf
for i := 0; i < m.decodeIdx; i++ {
f := m.filters[i]
res = f.DecodeData(buf, endStreamInBody)
if m.handleAction(res, phaseDecodeData, f) {
if m.handleAction(res, api.PhaseDecodeData, f) {
return false
}
}
Expand All @@ -611,15 +598,15 @@ func (m *filterManager) DecodeRequest(headers api.RequestHeaderMap, buf capi.Buf
for i := 0; i < m.decodeIdx; i++ {
f := m.filters[i]
res = f.DecodeTrailers(trailers)
if m.handleAction(res, phaseDecodeTrailers, f) {
if m.handleAction(res, api.PhaseDecodeTrailers, f) {
return false
}
}
}

f := m.filters[m.decodeIdx]
res = f.DecodeRequest(headers, buf, trailers)
if m.handleAction(res, phaseDecodeRequest, f) {
if m.handleAction(res, api.PhaseDecodeRequest, f) {
return false
}

Expand All @@ -631,7 +618,7 @@ func (m *filterManager) DecodeRequest(headers api.RequestHeaderMap, buf capi.Buf
// The endStream in DecodeHeaders indicates whether there is a body.
// The body always exists when we hit this path.
res = f.DecodeHeaders(headers, false)
if m.handleAction(res, phaseDecodeHeaders, f) {
if m.handleAction(res, api.PhaseDecodeHeaders, f) {
return false
}
if m.decodeRequestNeeded {
Expand All @@ -646,7 +633,7 @@ func (m *filterManager) DecodeRequest(headers api.RequestHeaderMap, buf capi.Buf
for j := m.decodeIdx + 1; j < i; j++ {
f := m.filters[j]
res = f.DecodeData(buf, endStreamInBody)
if m.handleAction(res, phaseDecodeData, f) {
if m.handleAction(res, api.PhaseDecodeData, f) {
return false
}
}
Expand All @@ -656,7 +643,7 @@ func (m *filterManager) DecodeRequest(headers api.RequestHeaderMap, buf capi.Buf
for j := m.decodeIdx + 1; j < i; j++ {
f := m.filters[j]
res = f.DecodeTrailers(trailers)
if m.handleAction(res, phaseDecodeTrailers, f) {
if m.handleAction(res, api.PhaseDecodeTrailers, f) {
return false
}
}
Expand All @@ -667,7 +654,7 @@ func (m *filterManager) DecodeRequest(headers api.RequestHeaderMap, buf capi.Buf
m.decodeIdx = i
f := m.filters[m.decodeIdx]
res = f.DecodeRequest(headers, buf, trailers)
if m.handleAction(res, phaseDecodeRequest, f) {
if m.handleAction(res, api.PhaseDecodeRequest, f) {
return false
}
i++
Expand Down Expand Up @@ -724,7 +711,7 @@ func (m *filterManager) decodeData(buf capi.BufferInstance, endStream bool) capi
for i := 0; i < n; i++ {
f := m.filters[i]
res = f.DecodeData(buf, endStream)
if m.handleAction(res, phaseDecodeData, f) {
if m.handleAction(res, api.PhaseDecodeData, f) {
return capi.LocalReply
}
}
Expand Down Expand Up @@ -771,7 +758,7 @@ func (m *filterManager) decodeTrailers(trailers capi.RequestTrailerMap) capi.Sta
if m.decodeIdx == -1 {
for _, f := range m.filters {
res = f.DecodeTrailers(trailers)
if m.handleAction(res, phaseDecodeTrailers, f) {
if m.handleAction(res, api.PhaseDecodeTrailers, f) {
return capi.LocalReply
}
}
Expand Down Expand Up @@ -825,7 +812,7 @@ func (m *filterManager) encodeHeaders(headers capi.ResponseHeaderMap, endStream
for i := n - 1; i >= 0; i-- {
f := m.filters[i]
res = f.EncodeHeaders(headers, endStream)
if m.handleAction(res, phaseEncodeHeaders, f) {
if m.handleAction(res, api.PhaseEncodeHeaders, f) {
return capi.LocalReply
}

Expand All @@ -838,7 +825,7 @@ func (m *filterManager) encodeHeaders(headers capi.ResponseHeaderMap, endStream

// no body
res = f.EncodeResponse(headers, nil, nil)
if m.handleAction(res, phaseEncodeResponse, f) {
if m.handleAction(res, api.PhaseEncodeResponse, f) {
return capi.LocalReply
}

Check warning on line 830 in api/pkg/filtermanager/filtermanager.go

View check run for this annotation

Codecov / codecov/patch

api/pkg/filtermanager/filtermanager.go#L829-L830

Added lines #L829 - L830 were not covered by tests
}
Expand All @@ -858,7 +845,7 @@ func (m *filterManager) EncodeResponse(headers api.ResponseHeaderMap, buf capi.B
for i := n - 1; i > m.encodeIdx; i-- {
f := m.filters[i]
res = f.EncodeData(buf, endStreamInBody)
if m.handleAction(res, phaseEncodeData, f) {
if m.handleAction(res, api.PhaseEncodeData, f) {
return false
}
}
Expand All @@ -868,15 +855,15 @@ func (m *filterManager) EncodeResponse(headers api.ResponseHeaderMap, buf capi.B
for i := n - 1; i > m.encodeIdx; i-- {
f := m.filters[i]
res = f.EncodeTrailers(trailers)
if m.handleAction(res, phaseEncodeTrailers, f) {
if m.handleAction(res, api.PhaseEncodeTrailers, f) {

Check warning on line 858 in api/pkg/filtermanager/filtermanager.go

View check run for this annotation

Codecov / codecov/patch

api/pkg/filtermanager/filtermanager.go#L858

Added line #L858 was not covered by tests
return false
}
}
}

f := m.filters[m.encodeIdx]
res = f.EncodeResponse(m.rspHdr, buf, nil)
if m.handleAction(res, phaseEncodeResponse, f) {
if m.handleAction(res, api.PhaseEncodeResponse, f) {
return false
}

Expand All @@ -885,7 +872,7 @@ func (m *filterManager) EncodeResponse(headers api.ResponseHeaderMap, buf capi.B
for ; i >= 0; i-- {
f := m.filters[i]
res = f.EncodeHeaders(m.rspHdr, false)
if m.handleAction(res, phaseEncodeHeaders, f) {
if m.handleAction(res, api.PhaseEncodeHeaders, f) {
return false
}
if m.encodeResponseNeeded {
Expand All @@ -898,7 +885,7 @@ func (m *filterManager) EncodeResponse(headers api.ResponseHeaderMap, buf capi.B
for j := m.encodeIdx - 1; j > i; j-- {
f := m.filters[j]
res = f.EncodeData(buf, endStreamInBody)
if m.handleAction(res, phaseEncodeData, f) {
if m.handleAction(res, api.PhaseEncodeData, f) {
return false
}
}
Expand All @@ -908,7 +895,7 @@ func (m *filterManager) EncodeResponse(headers api.ResponseHeaderMap, buf capi.B
for j := m.encodeIdx - 1; j > i; j-- {
f := m.filters[j]
res = f.EncodeTrailers(trailers)
if m.handleAction(res, phaseEncodeTrailers, f) {
if m.handleAction(res, api.PhaseEncodeTrailers, f) {

Check warning on line 898 in api/pkg/filtermanager/filtermanager.go

View check run for this annotation

Codecov / codecov/patch

api/pkg/filtermanager/filtermanager.go#L898

Added line #L898 was not covered by tests
return false
}
}
Expand All @@ -919,7 +906,7 @@ func (m *filterManager) EncodeResponse(headers api.ResponseHeaderMap, buf capi.B
m.encodeIdx = i
f := m.filters[m.encodeIdx]
res = f.EncodeResponse(m.rspHdr, buf, nil)
if m.handleAction(res, phaseEncodeResponse, f) {
if m.handleAction(res, api.PhaseEncodeResponse, f) {
return false
}
i--
Expand Down Expand Up @@ -963,7 +950,7 @@ func (m *filterManager) encodeData(buf capi.BufferInstance, endStream bool) capi
for i := n - 1; i >= 0; i-- {
f := m.filters[i]
res = f.EncodeData(buf, endStream)
if m.handleAction(res, phaseEncodeData, f) {
if m.handleAction(res, api.PhaseEncodeData, f) {
return capi.LocalReply
}
}
Expand Down Expand Up @@ -1009,7 +996,7 @@ func (m *filterManager) encodeTrailers(trailers capi.ResponseTrailerMap) capi.St
if m.encodeIdx == -1 {
for _, f := range m.filters {
res = f.EncodeTrailers(trailers)
if m.handleAction(res, phaseEncodeTrailers, f) {
if m.handleAction(res, api.PhaseEncodeTrailers, f) {
return capi.LocalReply
}

Check warning on line 1001 in api/pkg/filtermanager/filtermanager.go

View check run for this annotation

Codecov / codecov/patch

api/pkg/filtermanager/filtermanager.go#L1000-L1001

Added lines #L1000 - L1001 were not covered by tests
}
Expand Down
6 changes: 3 additions & 3 deletions api/pkg/filtermanager/filtermanager_benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ func BenchmarkFilterManagerAllPhaseCanSyncRun(b *testing.B) {
config := initFilterManagerConfig("ns")
config.parsed = []*model.ParsedFilterConfig{
{
Name: "allPhase",
Factory: PassThroughFactory,
CanSyncRun: true,
Name: "allPhase",
Factory: PassThroughFactory,
SyncRunPhases: api.AllPhases,
},
}
reqHdr := envoy.NewRequestHeaderMap(http.Header{})
Expand Down
Loading

0 comments on commit 0233a39

Please sign in to comment.