diff --git a/integration/e2e/query_range_test.go b/integration/e2e/query_range_test.go index 782ee0798ce..821cb9b1c20 100644 --- a/integration/e2e/query_range_test.go +++ b/integration/e2e/query_range_test.go @@ -60,15 +60,13 @@ sendLoop: callQueryRange(t, tempo.Endpoint(3200), query, debugMode) }) } + + res := doRequest(t, tempo.Endpoint(3200), "{. a}") + require.Equal(t, 400, res.StatusCode) } func callQueryRange(t *testing.T, endpoint, query string, printBody bool) { - url := buildURL(endpoint, fmt.Sprintf("%s with(exemplars=true)", query)) - req, err := http.NewRequest(http.MethodGet, url, nil) - require.NoError(t, err) - - res, err := http.DefaultClient.Do(req) - require.NoError(t, err) + res := doRequest(t, endpoint, query) require.Equal(t, http.StatusOK, res.StatusCode) // Read body and print it @@ -89,6 +87,16 @@ func callQueryRange(t *testing.T, endpoint, query string, printBody bool) { require.GreaterOrEqual(t, exemplarCount, 1) } +func doRequest(t *testing.T, endpoint, query string) *http.Response { + url := buildURL(endpoint, fmt.Sprintf("%s with(exemplars=true)", query)) + req, err := http.NewRequest(http.MethodGet, url, nil) + require.NoError(t, err) + + res, err := http.DefaultClient.Do(req) + require.NoError(t, err) + return res +} + func buildURL(endpoint, query string) string { return fmt.Sprintf( "http://%s/api/metrics/query_range?query=%s&start=%d&end=%d&step=%s", diff --git a/modules/frontend/frontend.go b/modules/frontend/frontend.go index 0c67e135f97..2fc76f5e518 100644 --- a/modules/frontend/frontend.go +++ b/modules/frontend/frontend.go @@ -93,6 +93,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo statusCodeWare := pipeline.NewStatusCodeAdjustWare() traceIDStatusCodeWare := pipeline.NewStatusCodeAdjustWareWithAllowedCode(http.StatusNotFound) urlDenyListWare := pipeline.NewURLDenyListWare(cfg.URLDenyList) + queryValidatorWare := pipeline.NewQueryValidatorWare() tracePipeline := pipeline.Build( []pipeline.AsyncMiddleware[combiner.PipelineResponse]{ @@ -106,6 +107,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo searchPipeline := pipeline.Build( []pipeline.AsyncMiddleware[combiner.PipelineResponse]{ urlDenyListWare, + queryValidatorWare, multiTenantMiddleware(cfg, logger), newAsyncSearchSharder(reader, o, cfg.Search.Sharder, logger), }, @@ -134,6 +136,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo metricsPipeline := pipeline.Build( []pipeline.AsyncMiddleware[combiner.PipelineResponse]{ urlDenyListWare, + queryValidatorWare, multiTenantUnsupportedMiddleware(cfg, logger), }, []pipeline.Middleware{statusCodeWare, retryWare}, @@ -143,6 +146,7 @@ func New(cfg Config, next http.RoundTripper, o overrides.Interface, reader tempo queryRangePipeline := pipeline.Build( []pipeline.AsyncMiddleware[combiner.PipelineResponse]{ urlDenyListWare, + queryValidatorWare, multiTenantMiddleware(cfg, logger), newAsyncQueryRangeSharder(reader, o, cfg.Metrics.Sharder, logger), }, diff --git a/modules/frontend/metrics_query_handler.go b/modules/frontend/metrics_query_handler.go index a347dedb198..2038d5887eb 100644 --- a/modules/frontend/metrics_query_handler.go +++ b/modules/frontend/metrics_query_handler.go @@ -116,8 +116,8 @@ func newMetricsQueryInstantHTTPHandler(cfg Config, next pipeline.AsyncRoundTripp if err != nil { level.Error(logger).Log("msg", "query instant: query range combiner failed", "err", err) return &http.Response{ - StatusCode: http.StatusInternalServerError, - Status: http.StatusText(http.StatusInternalServerError), + StatusCode: http.StatusBadRequest, + Status: http.StatusText(http.StatusBadRequest), Body: io.NopCloser(strings.NewReader(err.Error())), }, nil } diff --git a/modules/frontend/metrics_query_range_handler.go b/modules/frontend/metrics_query_range_handler.go index 1b3e3c2f59a..dd8e760a466 100644 --- a/modules/frontend/metrics_query_range_handler.go +++ b/modules/frontend/metrics_query_range_handler.go @@ -87,8 +87,8 @@ func newMetricsQueryRangeHTTPHandler(cfg Config, next pipeline.AsyncRoundTripper if err != nil { level.Error(logger).Log("msg", "query range: query range combiner failed", "err", err) return &http.Response{ - StatusCode: http.StatusInternalServerError, - Status: http.StatusText(http.StatusInternalServerError), + StatusCode: http.StatusBadRequest, + Status: http.StatusText(http.StatusBadRequest), Body: io.NopCloser(strings.NewReader(err.Error())), }, nil } diff --git a/modules/frontend/pipeline/deny_list_middleware.go b/modules/frontend/pipeline/async_deny_list_middleware.go similarity index 100% rename from modules/frontend/pipeline/deny_list_middleware.go rename to modules/frontend/pipeline/async_deny_list_middleware.go diff --git a/modules/frontend/pipeline/deny_list_middleware_test.go b/modules/frontend/pipeline/async_deny_list_middleware_test.go similarity index 100% rename from modules/frontend/pipeline/deny_list_middleware_test.go rename to modules/frontend/pipeline/async_deny_list_middleware_test.go diff --git a/modules/frontend/pipeline/async_query_validator_middleware.go b/modules/frontend/pipeline/async_query_validator_middleware.go new file mode 100644 index 00000000000..6aa2fd3f0fa --- /dev/null +++ b/modules/frontend/pipeline/async_query_validator_middleware.go @@ -0,0 +1,50 @@ +package pipeline + +import ( + "fmt" + "net/url" + + "github.com/grafana/tempo/modules/frontend/combiner" + "github.com/grafana/tempo/pkg/traceql" +) + +type queryValidatorWare struct { + next AsyncRoundTripper[combiner.PipelineResponse] +} + +func NewQueryValidatorWare() AsyncMiddleware[combiner.PipelineResponse] { + return AsyncMiddlewareFunc[combiner.PipelineResponse](func(next AsyncRoundTripper[combiner.PipelineResponse]) AsyncRoundTripper[combiner.PipelineResponse] { + return &queryValidatorWare{ + next: next, + } + }) +} + +func (c queryValidatorWare) RoundTrip(req Request) (Responses[combiner.PipelineResponse], error) { + query := req.HTTPRequest().URL.Query() + err := c.validateTraceQLQuery(query) + if err != nil { + return NewBadRequest(err), nil + } + return c.next.RoundTrip(req) +} + +func (c queryValidatorWare) validateTraceQLQuery(queryParams url.Values) error { + var traceQLQuery string + if queryParams.Has("q") { + traceQLQuery = queryParams.Get("q") + } + if queryParams.Has("query") { + traceQLQuery = queryParams.Get("query") + } + if traceQLQuery != "" { + expr, err := traceql.Parse(traceQLQuery) + if err == nil { + err = traceql.Validate(expr) + } + if err != nil { + return fmt.Errorf("invalid TraceQL query: %w", err) + } + } + return nil +} diff --git a/modules/frontend/pipeline/async_query_validator_middleware_test.go b/modules/frontend/pipeline/async_query_validator_middleware_test.go new file mode 100644 index 00000000000..be711e8f4ec --- /dev/null +++ b/modules/frontend/pipeline/async_query_validator_middleware_test.go @@ -0,0 +1,52 @@ +package pipeline + +import ( + "bytes" + "context" + "io" + "net/http" + "testing" + + "github.com/grafana/tempo/modules/frontend/combiner" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var nextFunc = AsyncRoundTripperFunc[combiner.PipelineResponse](func(_ Request) (Responses[combiner.PipelineResponse], error) { + return NewHTTPToAsyncResponse(&http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader([]byte{})), + }), nil +}) + +func TestQueryValidator(t *testing.T) { + roundTrip := NewQueryValidatorWare().Wrap(nextFunc) + statusCode := doRequest(t, "http://localhost:8080/api/search", roundTrip) + assert.Equal(t, 200, statusCode) +} + +func TestQueryValidatorForAValidQuery(t *testing.T) { + roundTrip := NewQueryValidatorWare().Wrap(nextFunc) + statusCode := doRequest(t, "http://localhost:8080/api/search&q={}", roundTrip) + assert.Equal(t, 200, statusCode) +} + +func TestQueryValidatorForAnInvalidTraceQLQuery(t *testing.T) { + roundTrip := NewQueryValidatorWare().Wrap(nextFunc) + statusCode := doRequest(t, "http://localhost:8080/api/search?q={. hi}", roundTrip) + assert.Equal(t, 400, statusCode) +} + +func TestQueryValidatorForAnInvalidTraceQlQueryRegex(t *testing.T) { + roundTrip := NewQueryValidatorWare().Wrap(nextFunc) + statusCode := doRequest(t, "http://localhost:8080/api/search?query={span.a =~ \"[\"}", roundTrip) + assert.Equal(t, 400, statusCode) +} + +func doRequest(t *testing.T, url string, rt AsyncRoundTripper[combiner.PipelineResponse]) int { + req, _ := http.NewRequest(http.MethodGet, url, nil) + resp, _ := rt.RoundTrip(NewHTTPRequest(req)) + httpResponse, _, err := resp.Next(context.Background()) + require.NoError(t, err) + return httpResponse.HTTPResponse().StatusCode +} diff --git a/pkg/api/http.go b/pkg/api/http.go index 83029129145..285ad995fc1 100644 --- a/pkg/api/http.go +++ b/pkg/api/http.go @@ -146,13 +146,6 @@ func ParseSearchRequest(r *http.Request) (*tempopb.SearchRequest, error) { query, queryFound := extractQueryParam(vals, urlParamQuery) if queryFound { - // TODO hacky fix: we don't validate {} since this isn't handled correctly yet - if query != "{}" { - _, err := traceql.Parse(query) - if err != nil { - return nil, fmt.Errorf("invalid TraceQL query: %w", err) - } - } req.Query = query } diff --git a/pkg/api/http_test.go b/pkg/api/http_test.go index c3387192b01..53e9af55de8 100644 --- a/pkg/api/http_test.go +++ b/pkg/api/http_test.go @@ -103,11 +103,6 @@ func TestQuerierParseSearchRequest(t *testing.T) { SpansPerSpanSet: defaultSpansPerSpanSet, }, }, - { - name: "invalid traceql query", - urlQuery: "q=" + url.QueryEscape(`{ .foo="bar" `), - err: "invalid TraceQL query: parse error at line 1, col 14: syntax error: unexpected $end", - }, { name: "traceql query and tags", urlQuery: "q=" + url.QueryEscape(`{ .foo="bar" }`) + "&tags=" + url.QueryEscape("service.name=foo"), diff --git a/pkg/traceql/validate.go b/pkg/traceql/validate.go new file mode 100644 index 00000000000..b02c1b9bb39 --- /dev/null +++ b/pkg/traceql/validate.go @@ -0,0 +1,5 @@ +package traceql + +func Validate(expr *RootExpr) error { + return expr.validate() +}