Skip to content

Commit

Permalink
Query-Frontend: Add middleware to drop headers (#4298)
Browse files Browse the repository at this point in the history
* header strip ware

Signed-off-by: Joe Elliott <[email protected]>

* comment

Signed-off-by: Joe Elliott <[email protected]>

* changelog

Signed-off-by: Joe Elliott <[email protected]>

* remove header strip wear from metrics summary

Signed-off-by: Joe Elliott <[email protected]>

---------

Signed-off-by: Joe Elliott <[email protected]>
  • Loading branch information
joe-elliott authored Nov 12, 2024
1 parent 0ede155 commit 2bc0b62
Show file tree
Hide file tree
Showing 10 changed files with 122 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
* [ENHANCEMENT] Added `insecure-skip-verify` option in tempo-cli to skip SSL certificate validation when connecting to the S3 backend. [#4259](https://github.com/grafana/tempo/pull/4259) (@faridtmammadov)
* [ENHANCEMENT] Chore: delete spanlogger. [#4312](https://github.com/grafana/tempo/pull/4312) (@javiermolinar)
* [ENHANCEMENT] Add `invalid_utf8` to reasons spanmetrics will discard spans. [#4293](https://github.com/grafana/tempo/pull/4293) (@zalegrala)
* [ENHANCEMENT] Reduce frontend and querier allocations by dropping HTTP headers early in the pipeline. [#4298](https://github.com/grafana/tempo/pull/4298) (@joe-elliott)
* [BUGFIX] Replace hedged requests roundtrips total with a counter. [#4063](https://github.com/grafana/tempo/pull/4063) [#4078](https://github.com/grafana/tempo/pull/4078) (@galalen)
* [BUGFIX] Metrics generators: Correctly drop from the ring before stopping ingestion to reduce drops during a rollout. [#4101](https://github.com/grafana/tempo/pull/4101) (@joe-elliott)
* [BUGFIX] Correctly handle 400 Bad Request and 404 Not Found in gRPC streaming [#4144](https://github.com/grafana/tempo/pull/4144) (@mapno)
Expand Down
4 changes: 2 additions & 2 deletions modules/frontend/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ type Config struct {
// A list of regexes for black listing requests, these will apply for every request regardless the endpoint
URLDenyList []string `yaml:"url_deny_list,omitempty"`

RequestWithWeights bool `yaml:"request_with_weights,omitempty"`
RetryWithWeights bool `yaml:"retry_with_weights,omitempty"`
// A list of headers allowed through the HTTP pipeline. Everything else will be stripped.
AllowedHeaders []string `yaml:"-"`
}

type SearchConfig struct {
Expand Down
7 changes: 7 additions & 0 deletions modules/frontend/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,11 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t
traceIDStatusCodeWare := pipeline.NewStatusCodeAdjustWareWithAllowedCode(http.StatusNotFound)
urlDenyListWare := pipeline.NewURLDenyListWare(cfg.URLDenyList)
queryValidatorWare := pipeline.NewQueryValidatorWare()
headerStripWare := pipeline.NewStripHeadersWare(cfg.AllowedHeaders)

tracePipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
pipeline.NewWeightRequestWare(pipeline.TraceByID, cfg.Weights),
multiTenantMiddleware(cfg, logger),
Expand All @@ -109,6 +111,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t

searchPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
queryValidatorWare,
pipeline.NewWeightRequestWare(pipeline.TraceQLSearch, cfg.Weights),
Expand All @@ -120,6 +123,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t

searchTagsPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
pipeline.NewWeightRequestWare(pipeline.Default, cfg.Weights),
multiTenantMiddleware(cfg, logger),
Expand All @@ -130,6 +134,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t

searchTagValuesPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
pipeline.NewWeightRequestWare(pipeline.Default, cfg.Weights),
multiTenantMiddleware(cfg, logger),
Expand All @@ -152,6 +157,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t
// traceql metrics
queryRangePipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
queryValidatorWare,
pipeline.NewWeightRequestWare(pipeline.TraceQLMetrics, cfg.Weights),
Expand All @@ -163,6 +169,7 @@ func New(cfg Config, next pipeline.RoundTripper, o overrides.Interface, reader t

queryInstantPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
headerStripWare,
urlDenyListWare,
queryValidatorWare,
pipeline.NewWeightRequestWare(pipeline.TraceQLMetrics, cfg.Weights),
Expand Down
46 changes: 46 additions & 0 deletions modules/frontend/pipeline/async_strip_headers_middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package pipeline

import (
"github.com/grafana/tempo/modules/frontend/combiner"
)

// stripHeadersWare is a pipeline middleware that removes all HTTP headers from
// a request except those on an allow list. See NewStripHeadersWare for why.
type stripHeadersWare struct {
	// allowed is the set of header names permitted to pass through. An empty
	// set means every header is dropped.
	allowed map[string]struct{}
	// next is the downstream round tripper the stripped request is handed to.
	next AsyncRoundTripper[combiner.PipelineResponse]
}

// NewStripHeadersWare creates a middleware that removes every HTTP header whose
// name is not present in allowList. This exists to reduce allocations further
// down the pipeline: once a request is inside the pipeline only the
// Combiner/Collector layers need headers, so dropping the rest up front avoids
// copying, marshalling and unmarshalling them across what can be 100s of
// thousands of subrequests.
//
// NOTE(review): matching is exact (case-sensitive). Entries in allowList must
// be spelled exactly as the keys appear in http.Header — canonical MIME form
// (e.g. "Content-Type") for headers set via Header.Set/Add — TODO confirm
// callers do this.
func NewStripHeadersWare(allowList []string) AsyncMiddleware[combiner.PipelineResponse] {
	// turn the slice into a set for O(1) membership checks per header
	allow := make(map[string]struct{}, len(allowList))
	for _, name := range allowList {
		allow[name] = struct{}{}
	}

	return AsyncMiddlewareFunc[combiner.PipelineResponse](func(next AsyncRoundTripper[combiner.PipelineResponse]) AsyncRoundTripper[combiner.PipelineResponse] {
		return &stripHeadersWare{
			allowed: allow,
			next:    next,
		}
	})
}

// RoundTrip strips every header not on the allow list from the request and
// forwards it to the next round tripper. With an empty allow list all headers
// are dropped. The underlying http.Request's header map is mutated in place.
func (c stripHeadersWare) RoundTrip(req Request) (Responses[combiner.PipelineResponse], error) {
	httpReq := req.HTTPRequest()

	switch {
	case len(c.allowed) == 0:
		// nothing is allowed: drop every header
		clear(httpReq.Header)
	default:
		// remove any header that is not explicitly allowed
		for name := range httpReq.Header {
			if _, keep := c.allowed[name]; !keep {
				delete(httpReq.Header, name)
			}
		}
	}

	return c.next.RoundTrip(req.CloneFromHTTPRequest(httpReq))
}
55 changes: 55 additions & 0 deletions modules/frontend/pipeline/async_strip_headers_middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package pipeline

import (
"bytes"
"io"
"net/http"
"testing"

"github.com/grafana/tempo/modules/frontend/combiner"
"github.com/stretchr/testify/require"
)

// TestStripHeaders confirms the strip-headers middleware removes every request
// header that is not present in the allow list before the request reaches the
// next round tripper. An empty allow list must drop all headers.
func TestStripHeaders(t *testing.T) {
	tcs := []struct {
		name     string
		allow    []string
		headers  map[string][]string
		expected http.Header
	}{
		{
			name:     "empty allow list",
			allow:    []string{},
			headers:  map[string][]string{"header1": {"value1"}, "header2": {"value2"}},
			expected: map[string][]string{},
		},
		{
			name:     "allow list with one header",
			allow:    []string{"header1"},
			headers:  map[string][]string{"header1": {"value1"}, "header2": {"value2"}},
			expected: map[string][]string{"header1": {"value1"}},
		},
		{
			name:     "allow list with multiple headers",
			allow:    []string{"header1", "header3"},
			headers:  map[string][]string{"header1": {"value1"}, "header2": {"value2"}, "header3": {"value3"}},
			expected: map[string][]string{"header1": {"value1"}, "header3": {"value3"}},
		},
	}

	for _, tc := range tcs {
		t.Run(tc.name, func(t *testing.T) {
			// the next round tripper observes whatever headers survived the middleware
			next := AsyncRoundTripperFunc[combiner.PipelineResponse](func(req Request) (Responses[combiner.PipelineResponse], error) {
				require.Equal(t, tc.expected, req.HTTPRequest().Header)

				return NewHTTPToAsyncResponse(&http.Response{
					StatusCode: 200,
					Body:       io.NopCloser(bytes.NewReader([]byte{})),
				}), nil
			})

			stripHeaders := NewStripHeadersWare(tc.allow).Wrap(next)

			req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
			require.NoError(t, err) // previously discarded with _
			req.Header = tc.headers

			_, err = stripHeaders.RoundTrip(NewHTTPRequest(req))
			require.NoError(t, err)
		})
	}
}
10 changes: 5 additions & 5 deletions modules/frontend/pipeline/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ var tracer = otel.Tracer("modules/frontend/pipeline")

type Request interface {
HTTPRequest() *http.Request
Context() context.Context

WithContext(context.Context)
CloneFromHTTPRequest(request *http.Request) Request

Weight() int
SetContext(context.Context)
Context() context.Context

SetWeight(int)
Weight() int

SetCacheKey(string)
CacheKey() string
Expand Down Expand Up @@ -51,7 +51,7 @@ func (r HTTPRequest) Context() context.Context {
return r.req.Context()
}

func (r *HTTPRequest) WithContext(ctx context.Context) {
func (r *HTTPRequest) SetContext(ctx context.Context) {
r.req = r.req.WithContext(ctx)
}

Expand Down
2 changes: 1 addition & 1 deletion modules/frontend/pipeline/sync_handler_retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (r retryWare) RoundTrip(req Request) (*http.Response, error) {
defer span.End()

// context propagation
req.WithContext(ctx)
req.SetContext(ctx)

tries := 0
defer func() { r.retriesCount.Observe(float64(tries)) }()
Expand Down
2 changes: 1 addition & 1 deletion modules/frontend/tag_sharder.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ func (s searchTagSharder) RoundTrip(pipelineRequest pipeline.Request) (pipeline.
}
ctx, span := tracer.Start(ctx, "frontend.ShardSearchTags")
defer span.End()
pipelineRequest.WithContext(ctx)
pipelineRequest.SetContext(ctx)

// calculate and enforce max search duration
maxDuration := s.maxDuration(tenantID)
Expand Down
2 changes: 1 addition & 1 deletion modules/frontend/traceid_sharder.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func newAsyncTraceIDSharder(cfg *TraceByIDConfig, logger log.Logger) pipeline.As
func (s asyncTraceSharder) RoundTrip(pipelineRequest pipeline.Request) (pipeline.Responses[combiner.PipelineResponse], error) {
ctx, span := tracer.Start(pipelineRequest.Context(), "frontend.ShardQuery")
defer span.End()
pipelineRequest.WithContext(ctx)
pipelineRequest.SetContext(ctx)

reqs, err := s.buildShardedRequests(pipelineRequest)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions modules/frontend/v1/request_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestRequestBatchContextError(t *testing.T) {

req := httptest.NewRequest("GET", "http://example.com", nil)
prequest := pipeline.NewHTTPRequest(req)
prequest.WithContext(ctx)
prequest.SetContext(ctx)

for i := 0; i < totalRequests-1; i++ {
_ = rb.add(&request{request: prequest})
Expand All @@ -61,7 +61,7 @@ func TestRequestBatchContextError(t *testing.T) {
// add a cancel context
cancelCtx, cancel := context.WithCancel(ctx)
prequest = pipeline.NewHTTPRequest(req)
prequest.WithContext(cancelCtx)
prequest.SetContext(cancelCtx)

_ = rb.add(&request{request: prequest})

Expand All @@ -83,7 +83,7 @@ func TestDoneChanCloses(_ *testing.T) {

req := httptest.NewRequest("GET", "http://example.com", nil)
prequest := pipeline.NewHTTPRequest(req)
prequest.WithContext(cancelCtx)
prequest.SetContext(cancelCtx)

for i := 0; i < totalRequests-1; i++ {
_ = rb.add(&request{request: prequest})
Expand Down

0 comments on commit 2bc0b62

Please sign in to comment.