From 32737dc708c1ecf0ec0ceaa4bbc0ac09286186fd Mon Sep 17 00:00:00 2001 From: Patrik Date: Mon, 2 Sep 2024 09:27:23 +0200 Subject: [PATCH] fix: validate page tokens for better error codes (#4021) --- courier/handler_test.go | 19 ++++++++++++------- .../sql/identity/persister_identity.go | 4 ++++ persistence/sql/persister_courier.go | 5 +++++ persistence/sql/persister_session.go | 5 +++++ x/err.go | 15 +++++++++------ 5 files changed, 35 insertions(+), 13 deletions(-) diff --git a/courier/handler_test.go b/courier/handler_test.go index 28a7ec55d8b2..f4ca48ba30d6 100644 --- a/courier/handler_test.go +++ b/courier/handler_test.go @@ -57,7 +57,7 @@ func TestHandler(t *testing.T) { conf.MustSet(ctx, config.ViperKeyAdminBaseURL, adminTS.URL) conf.MustSet(ctx, config.ViperKeyPublicBaseURL, mockServerURL.String()) - var get = func(t *testing.T, base *httptest.Server, href string, expectCode int) gjson.Result { + get := func(t *testing.T, base *httptest.Server, href string, expectCode int) gjson.Result { t.Helper() res, err := base.Client().Get(base.URL + href) require.NoError(t, err) @@ -69,7 +69,7 @@ func TestHandler(t *testing.T) { return gjson.ParseBytes(body) } - var getList = func(t *testing.T, tsName string, qs string) gjson.Result { + getList := func(t *testing.T, tsName string, qs string) gjson.Result { t.Helper() href := courier.AdminRouteListMessages + qs ts := adminTS @@ -109,12 +109,12 @@ func TestHandler(t *testing.T) { } require.NoError(t, reg.CourierPersister().AddMessage(context.Background(), &messages[i])) } - for i := 0; i < procCount; i++ { + for i := range procCount { require.NoError(t, reg.CourierPersister().SetMessageStatus(context.Background(), messages[i].ID, courier.MessageStatusProcessing)) } t.Run("paging", func(t *testing.T) { - t.Run("case=should return half of the messages", func(t *testing.T) { + t.Run("case=should return first half of the messages", func(t *testing.T) { qs := fmt.Sprintf("?page_token=%s&page_size=%d", defaultPageToken, msgCount/2) for _, tc := range tss { @@ -124,7 +124,7 @@ func TestHandler(t *testing.T) { }) } }) - t.Run("case=should return no message", func(t *testing.T) { + t.Run("case=should error with random page token", func(t *testing.T) { token := keysetpagination.MapPageToken{ "id": "1232", "created_at": time.Now().Add(time.Duration(-10) * time.Hour).Format("2006-01-02 15:04:05.99999-07:00"), @@ -133,8 +133,13 @@ func TestHandler(t *testing.T) { for _, tc := range tss { t.Run("endpoint="+tc.name, func(t *testing.T) { - parsed := getList(t, tc.name, qs) - assert.Len(t, parsed.Array(), 0) + path := courier.AdminRouteListMessages + qs + if tc.name == "public" { + path = x.AdminPrefix + path + } + resp, err := tc.s.Client().Get(tc.s.URL + path) + require.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) }) } }) diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 762dc9ab0871..807d3d67779d 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -799,6 +799,10 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity. attribute.Stringer("network.id", p.NetworkID(ctx)))...)) defer otelx.End(span, &err) + if _, err := uuid.FromString(paginator.Token().Parse("id")["id"]); err != nil { + return nil, nil, errors.WithStack(x.PageTokenInvalid) + } + nid := p.NetworkID(ctx) var is []identity.Identity diff --git a/persistence/sql/persister_courier.go b/persistence/sql/persister_courier.go index 456efea4fe70..ec9694924f6e 100644 --- a/persistence/sql/persister_courier.go +++ b/persistence/sql/persister_courier.go @@ -20,6 +20,7 @@ import ( "github.com/ory/kratos/courier" "github.com/ory/kratos/persistence/sql/update" + "github.com/ory/kratos/x" ) var _ courier.Persister = new(Persister) @@ -57,6 +58,10 @@ func (p *Persister) ListMessages(ctx context.Context, filter courier.ListCourier opts = append(opts, keysetpagination.WithColumn("created_at", "DESC")) paginator := keysetpagination.GetPaginator(opts...) + if _, err := uuid.FromString(paginator.Token().Parse("id")["id"]); err != nil { + return nil, 0, nil, errors.WithStack(x.PageTokenInvalid) + } + messages := make([]courier.Message, paginator.Size()) if err := q.Scope(keysetpagination.Paginate[courier.Message](paginator)). All(&messages); err != nil { diff --git a/persistence/sql/persister_session.go b/persistence/sql/persister_session.go index 37fccca0bd84..6279af56d26a 100644 --- a/persistence/sql/persister_session.go +++ b/persistence/sql/persister_session.go @@ -20,6 +20,7 @@ import ( "github.com/ory/kratos/identity" "github.com/ory/kratos/session" + "github.com/ory/kratos/x" "github.com/ory/kratos/x/events" "github.com/ory/x/otelx" "github.com/ory/x/pagination/keysetpagination" @@ -82,6 +83,10 @@ func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpt paginatorOpts = append(paginatorOpts, keysetpagination.WithColumn("created_at", "DESC")) paginator := keysetpagination.GetPaginator(paginatorOpts...) + if _, err := uuid.FromString(paginator.Token().Parse("id")["id"]); err != nil { + return nil, 0, nil, errors.WithStack(x.PageTokenInvalid) + } + if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { q := c.Where("nid = ?", nid) if active != nil { diff --git a/x/err.go b/x/err.go index 2a42b3eb540b..5b3868734cca 100644 --- a/x/err.go +++ b/x/err.go @@ -10,12 +10,15 @@ import ( "github.com/ory/herodot" ) -var PseudoPanic = herodot.DefaultError{ - StatusField: http.StatusText(http.StatusInternalServerError), - ErrorField: "Code Bug Detected", - ReasonField: "The code ended up at a place where it should not have. Please report this as an issue at https://github.com/ory/kratos", - CodeField: http.StatusInternalServerError, -} +var ( + PseudoPanic = herodot.DefaultError{ + StatusField: http.StatusText(http.StatusInternalServerError), + ErrorField: "Code Bug Detected", + ReasonField: "The code ended up at a place where it should not have. Please report this as an issue at https://github.com/ory/kratos", + CodeField: http.StatusInternalServerError, + } + PageTokenInvalid = herodot.ErrBadRequest.WithReason("The page token is invalid, do not craft your own page tokens") +) func RecoverStatusCode(err error, fallback int) int { var sc herodot.StatusCodeCarrier