Skip to content

Commit

Permalink
feat: measure external latency
Browse files Browse the repository at this point in the history
  • Loading branch information
zepatrik committed Apr 11, 2024
1 parent acfef3d commit da14237
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 13 deletions.
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module github.com/ory/x

go 1.21
go 1.22

toolchain go1.22.2

require (
code.dny.dev/ssrf v0.2.0
Expand Down
33 changes: 33 additions & 0 deletions httpx/external_latency.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package httpx

import (
"net/http"

"github.com/ory/x/reqlog"
)

// MeasureExternalLatencyTransport is an http.RoundTripper that measures the latency of all requests as external latency.
type MeasureExternalLatencyTransport struct {
Transport http.RoundTripper
}

var _ http.RoundTripper = (*MeasureExternalLatencyTransport)(nil)

func (m *MeasureExternalLatencyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
upstreamHostPath := req.URL.Scheme + "://" + req.URL.Host + req.URL.Path
defer reqlog.StartMeasureExternalCall(req.Context(), "http_request", upstreamHostPath)()

t := m.Transport
if t == nil {
t = http.DefaultTransport
}
return t.RoundTrip(req)
}

// ClientWithExternalLatencyMiddleware adds a middleware to the client that measures the latency of all requests as external latency.
func ClientWithExternalLatencyMiddleware(c *http.Client) {
c.Transport = &MeasureExternalLatencyTransport{Transport: c.Transport}
}
23 changes: 12 additions & 11 deletions httpx/ssrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ func (n noInternalIPRoundTripper) RoundTrip(request *http.Request) (*http.Respon
}

var (
prohibitInternalAllowIPv6 http.RoundTripper
prohibitInternalProhibitIPv6 http.RoundTripper
allowInternalAllowIPv6 http.RoundTripper
allowInternalProhibitIPv6 http.RoundTripper
prohibitInternalAllowIPv6,
prohibitInternalProhibitIPv6,
allowInternalAllowIPv6,
allowInternalProhibitIPv6 http.RoundTripper
)

func init() {
Expand All @@ -64,7 +64,7 @@ func init() {
ssrf.WithAnyPort(),
ssrf.WithNetworks("tcp4", "tcp6"),
).Safe
prohibitInternalAllowIPv6 = t
prohibitInternalAllowIPv6 = &MeasureExternalLatencyTransport{Transport: t}
}

func init() {
Expand All @@ -76,7 +76,7 @@ func init() {
t.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return d.DialContext(ctx, "tcp4", addr)
}
prohibitInternalProhibitIPv6 = t
prohibitInternalProhibitIPv6 = &MeasureExternalLatencyTransport{Transport: t}
}

func init() {
Expand All @@ -96,7 +96,7 @@ func init() {
netip.MustParsePrefix("fc00::/7"), // Unique Local (RFC 4193)
),
).Safe
allowInternalAllowIPv6 = t
allowInternalAllowIPv6 = &MeasureExternalLatencyTransport{Transport: t}
}

func init() {
Expand All @@ -119,21 +119,22 @@ func init() {
t.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
return d.DialContext(ctx, "tcp4", addr)
}
allowInternalProhibitIPv6 = t
allowInternalProhibitIPv6 = &MeasureExternalLatencyTransport{Transport: t}
}

func newDefaultTransport() (*http.Transport, *net.Dialer) {
dialer := net.Dialer{
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}
return &http.Transport{
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: dialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}, &dialer
}
return transport, dialer
}
97 changes: 97 additions & 0 deletions reqlog/external_latency.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package reqlog

import (
"context"
"net/http"
"sync"
"time"
)

// ExternalCallsMiddleware is a middleware that sets up the request context to measure external calls.
// It has to be used before any other middleware that reads the final external latency.
func ExternalCallsMiddleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
container := contextContainer{
latencies: make([]externalLatency, 0),
}
next(rw, r.WithContext(
context.WithValue(r.Context(), internalLatencyKey, &container),
))
}

// MeasureExternalCall measures the duration of a function and records it as an external call.
// The wrapped function's return value is returned.
func MeasureExternalCall[T any](ctx context.Context, cause, detail string, f func() T) T {
defer StartMeasureExternalCall(ctx, cause, detail)()
return f()
}

// MeasureExternalCallErr measures the duration of a function and records it as an external call.
// The wrapped function's return value and error is returned.
func MeasureExternalCallErr[T any](ctx context.Context, cause, detail string, f func() (T, error)) (T, error) {
defer StartMeasureExternalCall(ctx, cause, detail)()
return f()
}

// StartMeasureExternalCall starts measuring the duration of an external call.
// The returned function has to be called to record the duration.
func StartMeasureExternalCall(ctx context.Context, cause, detail string) func() {
container, ok := ctx.Value(internalLatencyKey).(*contextContainer)
if !ok {
return func() {}
}

start := time.Now()
return func() {
container.Lock()
defer container.Unlock()
container.latencies = append(container.latencies, externalLatency{
Took: time.Since(start),
Cause: cause,
Detail: detail,
})
}
}

// TotalExternalLatency returns the total duration of all external calls.
func TotalExternalLatency(ctx context.Context) (total time.Duration) {
if _, ok := ctx.Value(disableExternalLatencyMeasurement).(bool); ok {
return 0
}
container, ok := ctx.Value(internalLatencyKey).(*contextContainer)
if !ok {
return 0
}

container.Lock()
defer container.Unlock()
for _, l := range container.latencies {
total += l.Took
}
return total
}

// WithDisableExternalLatencyMeasurement returns a context that does not measure external latencies.
// Use this when you want to disable external latency measurements for a specific request.
func WithDisableExternalLatencyMeasurement(ctx context.Context) context.Context {
return context.WithValue(ctx, disableExternalLatencyMeasurement, true)
}

type (
externalLatency = struct {
Took time.Duration
Cause, Detail string
}
contextContainer = struct {
latencies []externalLatency
sync.Mutex
}
contextKey int
)

const (
internalLatencyKey contextKey = 1
disableExternalLatencyMeasurement contextKey = 2
)
95 changes: 95 additions & 0 deletions reqlog/external_latency_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright © 2024 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package reqlog

import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"golang.org/x/sync/errgroup"

"github.com/ory/x/assertx"
)

func TestExternalLatencyMiddleware(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ExternalCallsMiddleware(w, r, func(w http.ResponseWriter, r *http.Request) {
var (
wg sync.WaitGroup
res0, res1 string
err error
)

wg.Add(3)
go func() {
res0 = MeasureExternalCall(r.Context(), "", "", func() string {
time.Sleep(100 * time.Millisecond)
return "foo"
})
wg.Done()
}()
go func() {
res1, err = MeasureExternalCallErr(r.Context(), "", "", func() (string, error) {
time.Sleep(100 * time.Millisecond)
return "bar", nil
})
wg.Done()
}()
go func() {
_ = MeasureExternalCall(WithDisableExternalLatencyMeasurement(r.Context()), "", "", func() error {
time.Sleep(100 * time.Millisecond)
return nil
})
wg.Done()
}()
wg.Wait()
total := TotalExternalLatency(r.Context())
_ = json.NewEncoder(w).Encode(map[string]any{
"res0": res0,
"res1": res1,
"err": err,
"total": total,
})
})
}))
defer ts.Close()

bodies := make([][]byte, 100)
eg := errgroup.Group{}
for i := range bodies {
eg.Go(func() error {
res, err := http.Get(ts.URL)
if err != nil {
return err
}
defer res.Body.Close()
bodies[i], err = io.ReadAll(res.Body)
if err != nil {
return err
}
return nil
})
}

require.NoError(t, eg.Wait())

for _, body := range bodies {
assertx.EqualAsJSONExcept(t, map[string]any{
"res0": "foo",
"res1": "bar",
"err": nil,
}, json.RawMessage(body), []string{"total"})

actualTotal := gjson.GetBytes(body, "total").Int()
assert.GreaterOrEqual(t, actualTotal, int64(200*time.Millisecond), string(body))
}
}
9 changes: 8 additions & 1 deletion reqlog/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,18 @@ func DefaultBefore(entry *logrusx.Logger, req *http.Request, remoteAddr string)

// DefaultAfter is the default func assigned to *Middleware.After
func DefaultAfter(entry *logrusx.Logger, req *http.Request, res negroni.ResponseWriter, latency time.Duration, name string) *logrusx.Logger {
return entry.WithRequest(req).WithField("http_response", map[string]interface{}{
e := entry.WithRequest(req).WithField("http_response", map[string]any{
"status": res.Status(),
"size": res.Size(),
"text_status": http.StatusText(res.Status()),
"took": latency,
"headers": entry.HTTPHeadersRedacted(res.Header()),
})
if el := TotalExternalLatency(req.Context()); el > 0 {
e = e.WithFields(map[string]any{
"took_internal": latency - el,
"took_external": el,
})
}
return e
}

0 comments on commit da14237

Please sign in to comment.