-
Notifications
You must be signed in to change notification settings - Fork 108
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
248 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
module github.com/ory/x | ||
|
||
go 1.21 | ||
go 1.22 | ||
|
||
toolchain go1.22.2 | ||
|
||
require ( | ||
code.dny.dev/ssrf v0.2.0 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// Copyright © 2024 Ory Corp | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package httpx | ||
|
||
import ( | ||
"net/http" | ||
|
||
"github.com/ory/x/reqlog" | ||
) | ||
|
||
// MeasureExternalLatencyTransport is an http.RoundTripper that measures the latency of all requests as external latency. | ||
type MeasureExternalLatencyTransport struct { | ||
Transport http.RoundTripper | ||
} | ||
|
||
var _ http.RoundTripper = (*MeasureExternalLatencyTransport)(nil) | ||
|
||
func (m *MeasureExternalLatencyTransport) RoundTrip(req *http.Request) (*http.Response, error) { | ||
upstreamHostPath := req.URL.Scheme + "://" + req.URL.Host + req.URL.Path | ||
defer reqlog.StartMeasureExternalCall(req.Context(), "http_request", upstreamHostPath)() | ||
|
||
t := m.Transport | ||
if t == nil { | ||
t = http.DefaultTransport | ||
} | ||
return t.RoundTrip(req) | ||
} | ||
|
||
// ClientWithExternalLatencyMiddleware adds a middleware to the client that measures the latency of all requests as external latency. | ||
func ClientWithExternalLatencyMiddleware(c *http.Client) { | ||
c.Transport = &MeasureExternalLatencyTransport{Transport: c.Transport} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
// Copyright © 2024 Ory Corp | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package reqlog | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
"sync" | ||
"time" | ||
) | ||
|
||
// ExternalCallsMiddleware is a middleware that sets up the request context to measure external calls. | ||
// It has to be used before any other middleware that reads the final external latency. | ||
func ExternalCallsMiddleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) { | ||
container := contextContainer{ | ||
latencies: make([]externalLatency, 0), | ||
} | ||
next(rw, r.WithContext( | ||
context.WithValue(r.Context(), internalLatencyKey, &container), | ||
)) | ||
} | ||
|
||
// MeasureExternalCall measures the duration of a function and records it as an external call. | ||
// The wrapped function's return value is returned. | ||
func MeasureExternalCall[T any](ctx context.Context, cause, detail string, f func() T) T { | ||
defer StartMeasureExternalCall(ctx, cause, detail)() | ||
return f() | ||
} | ||
|
||
// MeasureExternalCallErr measures the duration of a function and records it as an external call. | ||
// The wrapped function's return value and error is returned. | ||
func MeasureExternalCallErr[T any](ctx context.Context, cause, detail string, f func() (T, error)) (T, error) { | ||
defer StartMeasureExternalCall(ctx, cause, detail)() | ||
return f() | ||
} | ||
|
||
// StartMeasureExternalCall starts measuring the duration of an external call. | ||
// The returned function has to be called to record the duration. | ||
func StartMeasureExternalCall(ctx context.Context, cause, detail string) func() { | ||
container, ok := ctx.Value(internalLatencyKey).(*contextContainer) | ||
if !ok { | ||
return func() {} | ||
} | ||
|
||
start := time.Now() | ||
return func() { | ||
container.Lock() | ||
defer container.Unlock() | ||
container.latencies = append(container.latencies, externalLatency{ | ||
Took: time.Since(start), | ||
Cause: cause, | ||
Detail: detail, | ||
}) | ||
} | ||
} | ||
|
||
// TotalExternalLatency returns the total duration of all external calls. | ||
func TotalExternalLatency(ctx context.Context) (total time.Duration) { | ||
if _, ok := ctx.Value(disableExternalLatencyMeasurement).(bool); ok { | ||
return 0 | ||
} | ||
container, ok := ctx.Value(internalLatencyKey).(*contextContainer) | ||
if !ok { | ||
return 0 | ||
} | ||
|
||
container.Lock() | ||
defer container.Unlock() | ||
for _, l := range container.latencies { | ||
total += l.Took | ||
} | ||
return total | ||
} | ||
|
||
// WithDisableExternalLatencyMeasurement returns a context that does not measure external latencies. | ||
// Use this when you want to disable external latency measurements for a specific request. | ||
func WithDisableExternalLatencyMeasurement(ctx context.Context) context.Context { | ||
return context.WithValue(ctx, disableExternalLatencyMeasurement, true) | ||
} | ||
|
||
type ( | ||
externalLatency = struct { | ||
Took time.Duration | ||
Cause, Detail string | ||
} | ||
contextContainer = struct { | ||
latencies []externalLatency | ||
sync.Mutex | ||
} | ||
contextKey int | ||
) | ||
|
||
const ( | ||
internalLatencyKey contextKey = 1 | ||
disableExternalLatencyMeasurement contextKey = 2 | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
// Copyright © 2024 Ory Corp | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package reqlog | ||
|
||
import ( | ||
"encoding/json" | ||
"io" | ||
"net/http" | ||
"net/http/httptest" | ||
"sync" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
"github.com/tidwall/gjson" | ||
"golang.org/x/sync/errgroup" | ||
|
||
"github.com/ory/x/assertx" | ||
) | ||
|
||
func TestExternalLatencyMiddleware(t *testing.T) { | ||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
ExternalCallsMiddleware(w, r, func(w http.ResponseWriter, r *http.Request) { | ||
var ( | ||
wg sync.WaitGroup | ||
res0, res1 string | ||
err error | ||
) | ||
|
||
wg.Add(3) | ||
go func() { | ||
res0 = MeasureExternalCall(r.Context(), "", "", func() string { | ||
time.Sleep(100 * time.Millisecond) | ||
return "foo" | ||
}) | ||
wg.Done() | ||
}() | ||
go func() { | ||
res1, err = MeasureExternalCallErr(r.Context(), "", "", func() (string, error) { | ||
time.Sleep(100 * time.Millisecond) | ||
return "bar", nil | ||
}) | ||
wg.Done() | ||
}() | ||
go func() { | ||
_ = MeasureExternalCall(WithDisableExternalLatencyMeasurement(r.Context()), "", "", func() error { | ||
time.Sleep(100 * time.Millisecond) | ||
return nil | ||
}) | ||
wg.Done() | ||
}() | ||
wg.Wait() | ||
total := TotalExternalLatency(r.Context()) | ||
_ = json.NewEncoder(w).Encode(map[string]any{ | ||
"res0": res0, | ||
"res1": res1, | ||
"err": err, | ||
"total": total, | ||
}) | ||
}) | ||
})) | ||
defer ts.Close() | ||
|
||
bodies := make([][]byte, 100) | ||
eg := errgroup.Group{} | ||
for i := range bodies { | ||
eg.Go(func() error { | ||
res, err := http.Get(ts.URL) | ||
if err != nil { | ||
return err | ||
} | ||
defer res.Body.Close() | ||
bodies[i], err = io.ReadAll(res.Body) | ||
if err != nil { | ||
return err | ||
} | ||
return nil | ||
}) | ||
} | ||
|
||
require.NoError(t, eg.Wait()) | ||
|
||
for _, body := range bodies { | ||
assertx.EqualAsJSONExcept(t, map[string]any{ | ||
"res0": "foo", | ||
"res1": "bar", | ||
"err": nil, | ||
}, json.RawMessage(body), []string{"total"}) | ||
|
||
actualTotal := gjson.GetBytes(body, "total").Int() | ||
assert.GreaterOrEqual(t, actualTotal, int64(200*time.Millisecond), string(body)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters