From 01020f143c58dd21bdaedffdd74c5af63c41a03a Mon Sep 17 00:00:00 2001 From: zepatrik Date: Fri, 2 Aug 2024 15:53:34 +0200 Subject: [PATCH 1/4] fix: validate page tokens for better error codes --- internal/client-go/go.sum | 1 + persistence/sql/identity/persister_identity.go | 4 ++++ persistence/sql/persister_courier.go | 4 ++++ persistence/sql/persister_session.go | 4 ++++ 4 files changed, 13 insertions(+) diff --git a/internal/client-go/go.sum b/internal/client-go/go.sum index c966c8ddfd0..6cc3f5911d1 100644 --- a/internal/client-go/go.sum +++ b/internal/client-go/go.sum @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 984fb0199da..c38347e80df 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -799,6 +799,10 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity. attribute.Stringer("network.id", p.NetworkID(ctx)))...)) defer otelx.End(span, &err) + if _, err := uuid.FromString(paginator.Token().Encode()); err != nil { + return nil, nil, errors.WithStack(herodot.ErrBadRequest.WithReason("The page token is invalid, do not craft your own page tokens")) + } + nid := p.NetworkID(ctx) var is []identity.Identity diff --git a/persistence/sql/persister_courier.go b/persistence/sql/persister_courier.go index 456efea4fe7..427e19b6a08 100644 --- a/persistence/sql/persister_courier.go +++ b/persistence/sql/persister_courier.go @@ -57,6 +57,10 @@ func (p *Persister) ListMessages(ctx context.Context, filter courier.ListCourier opts = append(opts, keysetpagination.WithColumn("created_at", "DESC")) paginator := keysetpagination.GetPaginator(opts...) + if _, err := uuid.FromString(paginator.Token().Encode()); err != nil { + return nil, 0, nil, errors.WithStack(herodot.ErrBadRequest.WithReason("The page token is invalid, do not craft your own page tokens")) + } + messages := make([]courier.Message, paginator.Size()) if err := q.Scope(keysetpagination.Paginate[courier.Message](paginator)). All(&messages); err != nil { diff --git a/persistence/sql/persister_session.go b/persistence/sql/persister_session.go index 37fccca0bd8..1f4c2ab549c 100644 --- a/persistence/sql/persister_session.go +++ b/persistence/sql/persister_session.go @@ -82,6 +82,10 @@ func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpt paginatorOpts = append(paginatorOpts, keysetpagination.WithColumn("created_at", "DESC")) paginator := keysetpagination.GetPaginator(paginatorOpts...) + if _, err := uuid.FromString(paginator.Token().Encode()); err != nil { + return nil, 0, nil, errors.WithStack(herodot.ErrBadRequest.WithReason("The page token is invalid, do not craft your own page tokens")) + } + if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { q := c.Where("nid = ?", nid) if active != nil { From b3ec44cd70f14a1e6c58f6adbc349495c29333fc Mon Sep 17 00:00:00 2001 From: zepatrik Date: Fri, 9 Aug 2024 16:07:50 +0200 Subject: [PATCH 2/4] chore: extract error to variable --- persistence/sql/identity/persister_identity.go | 2 +- persistence/sql/persister_courier.go | 3 ++- persistence/sql/persister_session.go | 3 ++- x/err.go | 15 +++++++++------ 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index c38347e80df..0f415e5c1ad 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -800,7 +800,7 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity. defer otelx.End(span, &err) if _, err := uuid.FromString(paginator.Token().Encode()); err != nil { - return nil, nil, errors.WithStack(herodot.ErrBadRequest.WithReason("The page token is invalid, do not craft your own page tokens")) + return nil, nil, errors.WithStack(x.PageTokenInvalid) } nid := p.NetworkID(ctx) diff --git a/persistence/sql/persister_courier.go b/persistence/sql/persister_courier.go index 427e19b6a08..c63728869bc 100644 --- a/persistence/sql/persister_courier.go +++ b/persistence/sql/persister_courier.go @@ -20,6 +20,7 @@ import ( "github.com/ory/kratos/courier" "github.com/ory/kratos/persistence/sql/update" + "github.com/ory/kratos/x" ) var _ courier.Persister = new(Persister) @@ -58,7 +59,7 @@ func (p *Persister) ListMessages(ctx context.Context, filter courier.ListCourier paginator := keysetpagination.GetPaginator(opts...) if _, err := uuid.FromString(paginator.Token().Encode()); err != nil { - return nil, 0, nil, errors.WithStack(herodot.ErrBadRequest.WithReason("The page token is invalid, do not craft your own page tokens")) + return nil, 0, nil, errors.WithStack(x.PageTokenInvalid) } messages := make([]courier.Message, paginator.Size()) diff --git a/persistence/sql/persister_session.go b/persistence/sql/persister_session.go index 1f4c2ab549c..10bcadffa4b 100644 --- a/persistence/sql/persister_session.go +++ b/persistence/sql/persister_session.go @@ -20,6 +20,7 @@ import ( "github.com/ory/kratos/identity" "github.com/ory/kratos/session" + "github.com/ory/kratos/x" "github.com/ory/kratos/x/events" "github.com/ory/x/otelx" "github.com/ory/x/pagination/keysetpagination" @@ -83,7 +84,7 @@ func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpt paginator := keysetpagination.GetPaginator(paginatorOpts...) if _, err := uuid.FromString(paginator.Token().Encode()); err != nil { - return nil, 0, nil, errors.WithStack(herodot.ErrBadRequest.WithReason("The page token is invalid, do not craft your own page tokens")) + return nil, 0, nil, errors.WithStack(x.PageTokenInvalid) } if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error { diff --git a/x/err.go b/x/err.go index 2a42b3eb540..5b3868734cc 100644 --- a/x/err.go +++ b/x/err.go @@ -10,12 +10,15 @@ import ( "github.com/ory/herodot" ) -var PseudoPanic = herodot.DefaultError{ - StatusField: http.StatusText(http.StatusInternalServerError), - ErrorField: "Code Bug Detected", - ReasonField: "The code ended up at a place where it should not have. Please report this as an issue at https://github.com/ory/kratos", - CodeField: http.StatusInternalServerError, -} +var ( + PseudoPanic = herodot.DefaultError{ + StatusField: http.StatusText(http.StatusInternalServerError), + ErrorField: "Code Bug Detected", + ReasonField: "The code ended up at a place where it should not have. Please report this as an issue at https://github.com/ory/kratos", + CodeField: http.StatusInternalServerError, + } + PageTokenInvalid = herodot.ErrBadRequest.WithReason("The page token is invalid, do not craft your own page tokens") +) func RecoverStatusCode(err error, fallback int) int { var sc herodot.StatusCodeCarrier From 5d00bc9b240c58c0c58e29fa5c839da67648b813 Mon Sep 17 00:00:00 2001 From: zepatrik Date: Mon, 12 Aug 2024 10:14:12 +0200 Subject: [PATCH 3/4] fix: parse page tokens properly --- persistence/sql/identity/persister_identity.go | 2 +- persistence/sql/persister_courier.go | 2 +- persistence/sql/persister_session.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 0f415e5c1ad..acff91fb845 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -799,7 +799,7 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity. attribute.Stringer("network.id", p.NetworkID(ctx)))...)) defer otelx.End(span, &err) - if _, err := uuid.FromString(paginator.Token().Encode()); err != nil { + if _, err := uuid.FromString(paginator.Token().Parse("id")["id"]); err != nil { return nil, nil, errors.WithStack(x.PageTokenInvalid) } diff --git a/persistence/sql/persister_courier.go b/persistence/sql/persister_courier.go index c63728869bc..ec9694924f6 100644 --- a/persistence/sql/persister_courier.go +++ b/persistence/sql/persister_courier.go @@ -58,7 +58,7 @@ func (p *Persister) ListMessages(ctx context.Context, filter courier.ListCourier opts = append(opts, keysetpagination.WithColumn("created_at", "DESC")) paginator := keysetpagination.GetPaginator(opts...) - if _, err := uuid.FromString(paginator.Token().Encode()); err != nil { + if _, err := uuid.FromString(paginator.Token().Parse("id")["id"]); err != nil { return nil, 0, nil, errors.WithStack(x.PageTokenInvalid) } diff --git a/persistence/sql/persister_session.go b/persistence/sql/persister_session.go index 10bcadffa4b..6279af56d26 100644 --- a/persistence/sql/persister_session.go +++ b/persistence/sql/persister_session.go @@ -83,7 +83,7 @@ func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpt paginatorOpts = append(paginatorOpts, keysetpagination.WithColumn("created_at", "DESC")) paginator := keysetpagination.GetPaginator(paginatorOpts...) - if _, err := uuid.FromString(paginator.Token().Encode()); err != nil { + if _, err := uuid.FromString(paginator.Token().Parse("id")["id"]); err != nil { return nil, 0, nil, errors.WithStack(x.PageTokenInvalid) } From 97cea84d6742c2a3923f29c6ebf0c821d7f43a29 Mon Sep 17 00:00:00 2001 From: zepatrik Date: Tue, 27 Aug 2024 18:29:02 +0200 Subject: [PATCH 4/4] test: update assertions --- courier/handler_test.go | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/courier/handler_test.go b/courier/handler_test.go index 28a7ec55d8b..f4ca48ba30d 100644 --- a/courier/handler_test.go +++ b/courier/handler_test.go @@ -57,7 +57,7 @@ func TestHandler(t *testing.T) { conf.MustSet(ctx, config.ViperKeyAdminBaseURL, adminTS.URL) conf.MustSet(ctx, config.ViperKeyPublicBaseURL, mockServerURL.String()) - var get = func(t *testing.T, base *httptest.Server, href string, expectCode int) gjson.Result { + get := func(t *testing.T, base *httptest.Server, href string, expectCode int) gjson.Result { t.Helper() res, err := base.Client().Get(base.URL + href) require.NoError(t, err) @@ -69,7 +69,7 @@ func TestHandler(t *testing.T) { return gjson.ParseBytes(body) } - var getList = func(t *testing.T, tsName string, qs string) gjson.Result { + getList := func(t *testing.T, tsName string, qs string) gjson.Result { t.Helper() href := courier.AdminRouteListMessages + qs ts := adminTS @@ -109,12 +109,12 @@ func TestHandler(t *testing.T) { } require.NoError(t, reg.CourierPersister().AddMessage(context.Background(), &messages[i])) } - for i := 0; i < procCount; i++ { + for i := range procCount { require.NoError(t, reg.CourierPersister().SetMessageStatus(context.Background(), messages[i].ID, courier.MessageStatusProcessing)) } t.Run("paging", func(t *testing.T) { - t.Run("case=should return half of the messages", func(t *testing.T) { + t.Run("case=should return first half of the messages", func(t *testing.T) { qs := fmt.Sprintf("?page_token=%s&page_size=%d", defaultPageToken, msgCount/2) for _, tc := range tss { @@ -124,7 +124,7 @@ func TestHandler(t *testing.T) { }) } }) - t.Run("case=should return no message", func(t *testing.T) { + t.Run("case=should error with random page token", func(t *testing.T) { token := keysetpagination.MapPageToken{ "id": "1232", "created_at": time.Now().Add(time.Duration(-10) * time.Hour).Format("2006-01-02 15:04:05.99999-07:00"), @@ -133,8 +133,13 @@ func TestHandler(t *testing.T) { for _, tc := range tss { t.Run("endpoint="+tc.name, func(t *testing.T) { - parsed := getList(t, tc.name, qs) - assert.Len(t, parsed.Array(), 0) + path := courier.AdminRouteListMessages + qs + if tc.name == "public" { + path = x.AdminPrefix + path + } + resp, err := tc.s.Client().Get(tc.s.URL + path) + require.NoError(t, err) + assert.Equal(t, http.StatusBadRequest, resp.StatusCode) }) } })