From da14237f6f1b7fd66ba7ab2b43cb5bcb0fa1436f Mon Sep 17 00:00:00 2001 From: zepatrik Date: Thu, 11 Apr 2024 15:19:37 +0200 Subject: [PATCH] feat: measure external latency --- go.mod | 4 +- httpx/external_latency.go | 33 +++++++++++ httpx/ssrf.go | 23 ++++---- reqlog/external_latency.go | 97 +++++++++++++++++++++++++++++++++ reqlog/external_latency_test.go | 95 ++++++++++++++++++++++++++++++++ reqlog/middleware.go | 9 ++- 6 files changed, 248 insertions(+), 13 deletions(-) create mode 100644 httpx/external_latency.go create mode 100644 reqlog/external_latency.go create mode 100644 reqlog/external_latency_test.go diff --git a/go.mod b/go.mod index 740cb8be..9b46f6e9 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module github.com/ory/x -go 1.21 +go 1.22 + +toolchain go1.22.2 require ( code.dny.dev/ssrf v0.2.0 diff --git a/httpx/external_latency.go b/httpx/external_latency.go new file mode 100644 index 00000000..ac4f77f5 --- /dev/null +++ b/httpx/external_latency.go @@ -0,0 +1,33 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package httpx + +import ( + "net/http" + + "github.com/ory/x/reqlog" +) + +// MeasureExternalLatencyTransport is an http.RoundTripper that measures the latency of all requests as external latency. +type MeasureExternalLatencyTransport struct { + Transport http.RoundTripper +} + +var _ http.RoundTripper = (*MeasureExternalLatencyTransport)(nil) + +func (m *MeasureExternalLatencyTransport) RoundTrip(req *http.Request) (*http.Response, error) { + upstreamHostPath := req.URL.Scheme + "://" + req.URL.Host + req.URL.Path + defer reqlog.StartMeasureExternalCall(req.Context(), "http_request", upstreamHostPath)() + + t := m.Transport + if t == nil { + t = http.DefaultTransport + } + return t.RoundTrip(req) +} + +// ClientWithExternalLatencyMiddleware adds a middleware to the client that measures the latency of all requests as external latency. +func ClientWithExternalLatencyMiddleware(c *http.Client) { + c.Transport = &MeasureExternalLatencyTransport{Transport: c.Transport} +} diff --git a/httpx/ssrf.go b/httpx/ssrf.go index b92d579e..66a010ef 100644 --- a/httpx/ssrf.go +++ b/httpx/ssrf.go @@ -52,10 +52,10 @@ func (n noInternalIPRoundTripper) RoundTrip(request *http.Request) (*http.Respon } var ( - prohibitInternalAllowIPv6 http.RoundTripper - prohibitInternalProhibitIPv6 http.RoundTripper - allowInternalAllowIPv6 http.RoundTripper - allowInternalProhibitIPv6 http.RoundTripper + prohibitInternalAllowIPv6, + prohibitInternalProhibitIPv6, + allowInternalAllowIPv6, + allowInternalProhibitIPv6 http.RoundTripper ) func init() { @@ -64,7 +64,7 @@ func init() { ssrf.WithAnyPort(), ssrf.WithNetworks("tcp4", "tcp6"), ).Safe - prohibitInternalAllowIPv6 = t + prohibitInternalAllowIPv6 = &MeasureExternalLatencyTransport{Transport: t} } func init() { @@ -76,7 +76,7 @@ func init() { t.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { return d.DialContext(ctx, "tcp4", addr) } - prohibitInternalProhibitIPv6 = t + prohibitInternalProhibitIPv6 = &MeasureExternalLatencyTransport{Transport: t} } func init() { @@ -96,7 +96,7 @@ func init() { netip.MustParsePrefix("fc00::/7"), // Unique Local (RFC 4193) ), ).Safe - allowInternalAllowIPv6 = t + allowInternalAllowIPv6 = &MeasureExternalLatencyTransport{Transport: t} } func init() { @@ -119,15 +119,15 @@ func init() { t.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { return d.DialContext(ctx, "tcp4", addr) } - allowInternalProhibitIPv6 = t + allowInternalProhibitIPv6 = &MeasureExternalLatencyTransport{Transport: t} } func newDefaultTransport() (*http.Transport, *net.Dialer) { - dialer := net.Dialer{ + dialer := &net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, } - return &http.Transport{ + transport := &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: dialer.DialContext, ForceAttemptHTTP2: true, @@ -135,5 +135,6 @@ func newDefaultTransport() (*http.Transport, *net.Dialer) { IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, - }, &dialer + } + return transport, dialer } diff --git a/reqlog/external_latency.go b/reqlog/external_latency.go new file mode 100644 index 00000000..b2393d1e --- /dev/null +++ b/reqlog/external_latency.go @@ -0,0 +1,97 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package reqlog + +import ( + "context" + "net/http" + "sync" + "time" +) + +// ExternalCallsMiddleware is a middleware that sets up the request context to measure external calls. +// It has to be used before any other middleware that reads the final external latency. +func ExternalCallsMiddleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) { + container := contextContainer{ + latencies: make([]externalLatency, 0), + } + next(rw, r.WithContext( + context.WithValue(r.Context(), internalLatencyKey, &container), + )) +} + +// MeasureExternalCall measures the duration of a function and records it as an external call. +// The wrapped function's return value is returned. +func MeasureExternalCall[T any](ctx context.Context, cause, detail string, f func() T) T { + defer StartMeasureExternalCall(ctx, cause, detail)() + return f() +} + +// MeasureExternalCallErr measures the duration of a function and records it as an external call. +// The wrapped function's return value and error is returned. +func MeasureExternalCallErr[T any](ctx context.Context, cause, detail string, f func() (T, error)) (T, error) { + defer StartMeasureExternalCall(ctx, cause, detail)() + return f() +} + +// StartMeasureExternalCall starts measuring the duration of an external call. +// The returned function has to be called to record the duration. +func StartMeasureExternalCall(ctx context.Context, cause, detail string) func() { + container, ok := ctx.Value(internalLatencyKey).(*contextContainer) + if !ok { + return func() {} + } + + start := time.Now() + return func() { + container.Lock() + defer container.Unlock() + container.latencies = append(container.latencies, externalLatency{ + Took: time.Since(start), + Cause: cause, + Detail: detail, + }) + } +} + +// TotalExternalLatency returns the total duration of all external calls. +func TotalExternalLatency(ctx context.Context) (total time.Duration) { + if _, ok := ctx.Value(disableExternalLatencyMeasurement).(bool); ok { + return 0 + } + container, ok := ctx.Value(internalLatencyKey).(*contextContainer) + if !ok { + return 0 + } + + container.Lock() + defer container.Unlock() + for _, l := range container.latencies { + total += l.Took + } + return total +} + +// WithDisableExternalLatencyMeasurement returns a context that does not measure external latencies. +// Use this when you want to disable external latency measurements for a specific request. +func WithDisableExternalLatencyMeasurement(ctx context.Context) context.Context { + return context.WithValue(ctx, disableExternalLatencyMeasurement, true) +} + +type ( + externalLatency = struct { + Took time.Duration + Cause, Detail string + } + contextContainer = struct { + latencies []externalLatency + sync.Mutex + } + contextKey int +) + +const ( + internalLatencyKey contextKey = 1 + disableExternalLatencyMeasurement contextKey = 2 +) diff --git a/reqlog/external_latency_test.go b/reqlog/external_latency_test.go new file mode 100644 index 00000000..44758b80 --- /dev/null +++ b/reqlog/external_latency_test.go @@ -0,0 +1,95 @@ +// Copyright © 2024 Ory Corp +// SPDX-License-Identifier: Apache-2.0 + +package reqlog + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "golang.org/x/sync/errgroup" + + "github.com/ory/x/assertx" +) + +func TestExternalLatencyMiddleware(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ExternalCallsMiddleware(w, r, func(w http.ResponseWriter, r *http.Request) { + var ( + wg sync.WaitGroup + res0, res1 string + err error + ) + + wg.Add(3) + go func() { + res0 = MeasureExternalCall(r.Context(), "", "", func() string { + time.Sleep(100 * time.Millisecond) + return "foo" + }) + wg.Done() + }() + go func() { + res1, err = MeasureExternalCallErr(r.Context(), "", "", func() (string, error) { + time.Sleep(100 * time.Millisecond) + return "bar", nil + }) + wg.Done() + }() + go func() { + _ = MeasureExternalCall(WithDisableExternalLatencyMeasurement(r.Context()), "", "", func() error { + time.Sleep(100 * time.Millisecond) + return nil + }) + wg.Done() + }() + wg.Wait() + total := TotalExternalLatency(r.Context()) + _ = json.NewEncoder(w).Encode(map[string]any{ + "res0": res0, + "res1": res1, + "err": err, + "total": total, + }) + }) + })) + defer ts.Close() + + bodies := make([][]byte, 100) + eg := errgroup.Group{} + for i := range bodies { + eg.Go(func() error { + res, err := http.Get(ts.URL) + if err != nil { + return err + } + defer res.Body.Close() + bodies[i], err = io.ReadAll(res.Body) + if err != nil { + return err + } + return nil + }) + } + + require.NoError(t, eg.Wait()) + + for _, body := range bodies { + assertx.EqualAsJSONExcept(t, map[string]any{ + "res0": "foo", + "res1": "bar", + "err": nil, + }, json.RawMessage(body), []string{"total"}) + + actualTotal := gjson.GetBytes(body, "total").Int() + assert.GreaterOrEqual(t, actualTotal, int64(200*time.Millisecond), string(body)) + } +} diff --git a/reqlog/middleware.go b/reqlog/middleware.go index 18592b10..5ed20e8b 100644 --- a/reqlog/middleware.go +++ b/reqlog/middleware.go @@ -161,11 +161,18 @@ func DefaultBefore(entry *logrusx.Logger, req *http.Request, remoteAddr string) // DefaultAfter is the default func assigned to *Middleware.After func DefaultAfter(entry *logrusx.Logger, req *http.Request, res negroni.ResponseWriter, latency time.Duration, name string) *logrusx.Logger { - return entry.WithRequest(req).WithField("http_response", map[string]interface{}{ + e := entry.WithRequest(req).WithField("http_response", map[string]any{ "status": res.Status(), "size": res.Size(), "text_status": http.StatusText(res.Status()), "took": latency, "headers": entry.HTTPHeadersRedacted(res.Header()), }) + if el := TotalExternalLatency(req.Context()); el > 0 { + e = e.WithFields(map[string]any{ + "took_internal": latency - el, + "took_external": el, + }) + } + return e }