Skip to content

Commit

Permalink
feat: added new middleware to validate request query values (#3993)
Browse files Browse the repository at this point in the history
* feat: added new middleware to validate request query values

* remove unused test

* use different regex for tests

* rename middlewares and add e2e tests
  • Loading branch information
javiermolinar authored Aug 22, 2024
1 parent 457a563 commit 8a0b0e7
Show file tree
Hide file tree
Showing 11 changed files with 129 additions and 22 deletions.
20 changes: 14 additions & 6 deletions integration/e2e/query_range_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,13 @@ sendLoop:
callQueryRange(t, tempo.Endpoint(3200), query, debugMode)
})
}

res := doRequest(t, tempo.Endpoint(3200), "{. a}")
require.Equal(t, 400, res.StatusCode)
}

func callQueryRange(t *testing.T, endpoint, query string, printBody bool) {
url := buildURL(endpoint, fmt.Sprintf("%s with(exemplars=true)", query))
req, err := http.NewRequest(http.MethodGet, url, nil)
require.NoError(t, err)

res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
res := doRequest(t, endpoint, query)
require.Equal(t, http.StatusOK, res.StatusCode)

// Read body and print it
Expand All @@ -89,6 +87,16 @@ func callQueryRange(t *testing.T, endpoint, query string, printBody bool) {
require.GreaterOrEqual(t, exemplarCount, 1)
}

func doRequest(t *testing.T, endpoint, query string) *http.Response {
url := buildURL(endpoint, fmt.Sprintf("%s with(exemplars=true)", query))
req, err := http.NewRequest(http.MethodGet, url, nil)
require.NoError(t, err)

res, err := http.DefaultClient.Do(req)
require.NoError(t, err)
return res
}

func buildURL(endpoint, query string) string {
return fmt.Sprintf(
"http://%s/api/metrics/query_range?query=%s&start=%d&end=%d&step=%s",
Expand Down
4 changes: 4 additions & 0 deletions modules/frontend/frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo
statusCodeWare := pipeline.NewStatusCodeAdjustWare()
traceIDStatusCodeWare := pipeline.NewStatusCodeAdjustWareWithAllowedCode(http.StatusNotFound)
urlDenyListWare := pipeline.NewURLDenyListWare(cfg.URLDenyList)
queryValidatorWare := pipeline.NewQueryValidatorWare()

tracePipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
Expand All @@ -106,6 +107,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo
searchPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
urlDenyListWare,
queryValidatorWare,
multiTenantMiddleware(cfg, logger),
newAsyncSearchSharder(reader, o, cfg.Search.Sharder, logger),
},
Expand Down Expand Up @@ -134,6 +136,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo
metricsPipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
urlDenyListWare,
queryValidatorWare,
multiTenantUnsupportedMiddleware(cfg, logger),
},
[]pipeline.Middleware{statusCodeWare, retryWare},
Expand All @@ -143,6 +146,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo
queryRangePipeline := pipeline.Build(
[]pipeline.AsyncMiddleware[combiner.PipelineResponse]{
urlDenyListWare,
queryValidatorWare,
multiTenantMiddleware(cfg, logger),
newAsyncQueryRangeSharder(reader, o, cfg.Metrics.Sharder, logger),
},
Expand Down
4 changes: 2 additions & 2 deletions modules/frontend/metrics_query_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ func newMetricsQueryInstantHTTPHandler(cfg Config, next pipeline.AsyncRoundTripp
if err != nil {
level.Error(logger).Log("msg", "query instant: query range combiner failed", "err", err)
return &http.Response{
StatusCode: http.StatusInternalServerError,
Status: http.StatusText(http.StatusInternalServerError),
StatusCode: http.StatusBadRequest,
Status: http.StatusText(http.StatusBadRequest),
Body: io.NopCloser(strings.NewReader(err.Error())),
}, nil
}
Expand Down
4 changes: 2 additions & 2 deletions modules/frontend/metrics_query_range_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ func newMetricsQueryRangeHTTPHandler(cfg Config, next pipeline.AsyncRoundTripper
if err != nil {
level.Error(logger).Log("msg", "query range: query range combiner failed", "err", err)
return &http.Response{
StatusCode: http.StatusInternalServerError,
Status: http.StatusText(http.StatusInternalServerError),
StatusCode: http.StatusBadRequest,
Status: http.StatusText(http.StatusBadRequest),
Body: io.NopCloser(strings.NewReader(err.Error())),
}, nil
}
Expand Down
50 changes: 50 additions & 0 deletions modules/frontend/pipeline/async_query_validator_middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package pipeline

import (
"fmt"
"net/url"

"github.com/grafana/tempo/modules/frontend/combiner"
"github.com/grafana/tempo/pkg/traceql"
)

type queryValidatorWare struct {
next AsyncRoundTripper[combiner.PipelineResponse]
}

func NewQueryValidatorWare() AsyncMiddleware[combiner.PipelineResponse] {
return AsyncMiddlewareFunc[combiner.PipelineResponse](func(next AsyncRoundTripper[combiner.PipelineResponse]) AsyncRoundTripper[combiner.PipelineResponse] {
return &queryValidatorWare{
next: next,
}
})
}

func (c queryValidatorWare) RoundTrip(req Request) (Responses[combiner.PipelineResponse], error) {
query := req.HTTPRequest().URL.Query()
err := c.validateTraceQLQuery(query)
if err != nil {
return NewBadRequest(err), nil
}
return c.next.RoundTrip(req)
}

func (c queryValidatorWare) validateTraceQLQuery(queryParams url.Values) error {
var traceQLQuery string
if queryParams.Has("q") {
traceQLQuery = queryParams.Get("q")
}
if queryParams.Has("query") {
traceQLQuery = queryParams.Get("query")
}
if traceQLQuery != "" {
expr, err := traceql.Parse(traceQLQuery)
if err == nil {
err = traceql.Validate(expr)
}
if err != nil {
return fmt.Errorf("invalid TraceQL query: %w", err)
}
}
return nil
}
52 changes: 52 additions & 0 deletions modules/frontend/pipeline/async_query_validator_middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package pipeline

import (
"bytes"
"context"
"io"
"net/http"
"testing"

"github.com/grafana/tempo/modules/frontend/combiner"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var nextFunc = AsyncRoundTripperFunc[combiner.PipelineResponse](func(_ Request) (Responses[combiner.PipelineResponse], error) {
return NewHTTPToAsyncResponse(&http.Response{
StatusCode: 200,
Body: io.NopCloser(bytes.NewReader([]byte{})),
}), nil
})

func TestQueryValidator(t *testing.T) {
roundTrip := NewQueryValidatorWare().Wrap(nextFunc)
statusCode := doRequest(t, "http://localhost:8080/api/search", roundTrip)
assert.Equal(t, 200, statusCode)
}

func TestQueryValidatorForAValidQuery(t *testing.T) {
roundTrip := NewQueryValidatorWare().Wrap(nextFunc)
statusCode := doRequest(t, "http://localhost:8080/api/search&q={}", roundTrip)
assert.Equal(t, 200, statusCode)
}

func TestQueryValidatorForAnInvalidTraceQLQuery(t *testing.T) {
roundTrip := NewQueryValidatorWare().Wrap(nextFunc)
statusCode := doRequest(t, "http://localhost:8080/api/search?q={. hi}", roundTrip)
assert.Equal(t, 400, statusCode)
}

func TestQueryValidatorForAnInvalidTraceQlQueryRegex(t *testing.T) {
roundTrip := NewQueryValidatorWare().Wrap(nextFunc)
statusCode := doRequest(t, "http://localhost:8080/api/search?query={span.a =~ \"[\"}", roundTrip)
assert.Equal(t, 400, statusCode)
}

func doRequest(t *testing.T, url string, rt AsyncRoundTripper[combiner.PipelineResponse]) int {
req, _ := http.NewRequest(http.MethodGet, url, nil)
resp, _ := rt.RoundTrip(NewHTTPRequest(req))
httpResponse, _, err := resp.Next(context.Background())
require.NoError(t, err)
return httpResponse.HTTPResponse().StatusCode
}
7 changes: 0 additions & 7 deletions pkg/api/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,13 +146,6 @@ func ParseSearchRequest(r *http.Request) (*tempopb.SearchRequest, error) {

query, queryFound := extractQueryParam(vals, urlParamQuery)
if queryFound {
// TODO hacky fix: we don't validate {} since this isn't handled correctly yet
if query != "{}" {
_, err := traceql.Parse(query)
if err != nil {
return nil, fmt.Errorf("invalid TraceQL query: %w", err)
}
}
req.Query = query
}

Expand Down
5 changes: 0 additions & 5 deletions pkg/api/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,6 @@ func TestQuerierParseSearchRequest(t *testing.T) {
SpansPerSpanSet: defaultSpansPerSpanSet,
},
},
{
name: "invalid traceql query",
urlQuery: "q=" + url.QueryEscape(`{ .foo="bar" `),
err: "invalid TraceQL query: parse error at line 1, col 14: syntax error: unexpected $end",
},
{
name: "traceql query and tags",
urlQuery: "q=" + url.QueryEscape(`{ .foo="bar" }`) + "&tags=" + url.QueryEscape("service.name=foo"),
Expand Down
5 changes: 5 additions & 0 deletions pkg/traceql/validate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package traceql

func Validate(expr *RootExpr) error {
return expr.validate()
}

0 comments on commit 8a0b0e7

Please sign in to comment.