Skip to content

Commit

Permalink
fix: validate page tokens for better error codes (#4021)
Browse files Browse the repository at this point in the history
  • Loading branch information
zepatrik authored Sep 2, 2024
1 parent ff6ed5b commit 32737dc
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 13 deletions.
19 changes: 12 additions & 7 deletions courier/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func TestHandler(t *testing.T) {
conf.MustSet(ctx, config.ViperKeyAdminBaseURL, adminTS.URL)
conf.MustSet(ctx, config.ViperKeyPublicBaseURL, mockServerURL.String())

var get = func(t *testing.T, base *httptest.Server, href string, expectCode int) gjson.Result {
get := func(t *testing.T, base *httptest.Server, href string, expectCode int) gjson.Result {
t.Helper()
res, err := base.Client().Get(base.URL + href)
require.NoError(t, err)
Expand All @@ -69,7 +69,7 @@ func TestHandler(t *testing.T) {
return gjson.ParseBytes(body)
}

var getList = func(t *testing.T, tsName string, qs string) gjson.Result {
getList := func(t *testing.T, tsName string, qs string) gjson.Result {
t.Helper()
href := courier.AdminRouteListMessages + qs
ts := adminTS
Expand Down Expand Up @@ -109,12 +109,12 @@ func TestHandler(t *testing.T) {
}
require.NoError(t, reg.CourierPersister().AddMessage(context.Background(), &messages[i]))
}
for i := 0; i < procCount; i++ {
for i := range procCount {
require.NoError(t, reg.CourierPersister().SetMessageStatus(context.Background(), messages[i].ID, courier.MessageStatusProcessing))
}

t.Run("paging", func(t *testing.T) {
t.Run("case=should return half of the messages", func(t *testing.T) {
t.Run("case=should return first half of the messages", func(t *testing.T) {
qs := fmt.Sprintf("?page_token=%s&page_size=%d", defaultPageToken, msgCount/2)

for _, tc := range tss {
Expand All @@ -124,7 +124,7 @@ func TestHandler(t *testing.T) {
})
}
})
t.Run("case=should return no message", func(t *testing.T) {
t.Run("case=should error with random page token", func(t *testing.T) {
token := keysetpagination.MapPageToken{
"id": "1232",
"created_at": time.Now().Add(time.Duration(-10) * time.Hour).Format("2006-01-02 15:04:05.99999-07:00"),
Expand All @@ -133,8 +133,13 @@ func TestHandler(t *testing.T) {

for _, tc := range tss {
t.Run("endpoint="+tc.name, func(t *testing.T) {
parsed := getList(t, tc.name, qs)
assert.Len(t, parsed.Array(), 0)
path := courier.AdminRouteListMessages + qs
if tc.name == "public" {
path = x.AdminPrefix + path
}
resp, err := tc.s.Client().Get(tc.s.URL + path)
require.NoError(t, err)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
})
}
})
Expand Down
4 changes: 4 additions & 0 deletions persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,10 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity.
attribute.Stringer("network.id", p.NetworkID(ctx)))...))
defer otelx.End(span, &err)

if _, err := uuid.FromString(paginator.Token().Parse("id")["id"]); err != nil {
return nil, nil, errors.WithStack(x.PageTokenInvalid)
}

nid := p.NetworkID(ctx)
var is []identity.Identity

Expand Down
5 changes: 5 additions & 0 deletions persistence/sql/persister_courier.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/ory/kratos/courier"
"github.com/ory/kratos/persistence/sql/update"
"github.com/ory/kratos/x"
)

var _ courier.Persister = new(Persister)
Expand Down Expand Up @@ -57,6 +58,10 @@ func (p *Persister) ListMessages(ctx context.Context, filter courier.ListCourier
opts = append(opts, keysetpagination.WithColumn("created_at", "DESC"))
paginator := keysetpagination.GetPaginator(opts...)

if _, err := uuid.FromString(paginator.Token().Parse("id")["id"]); err != nil {
return nil, 0, nil, errors.WithStack(x.PageTokenInvalid)
}

messages := make([]courier.Message, paginator.Size())
if err := q.Scope(keysetpagination.Paginate[courier.Message](paginator)).
All(&messages); err != nil {
Expand Down
5 changes: 5 additions & 0 deletions persistence/sql/persister_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/ory/kratos/identity"
"github.com/ory/kratos/session"
"github.com/ory/kratos/x"
"github.com/ory/kratos/x/events"
"github.com/ory/x/otelx"
"github.com/ory/x/pagination/keysetpagination"
Expand Down Expand Up @@ -82,6 +83,10 @@ func (p *Persister) ListSessions(ctx context.Context, active *bool, paginatorOpt
paginatorOpts = append(paginatorOpts, keysetpagination.WithColumn("created_at", "DESC"))
paginator := keysetpagination.GetPaginator(paginatorOpts...)

if _, err := uuid.FromString(paginator.Token().Parse("id")["id"]); err != nil {
return nil, 0, nil, errors.WithStack(x.PageTokenInvalid)
}

if err := p.Transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
q := c.Where("nid = ?", nid)
if active != nil {
Expand Down
15 changes: 9 additions & 6 deletions x/err.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@ import (
"github.com/ory/herodot"
)

var PseudoPanic = herodot.DefaultError{
StatusField: http.StatusText(http.StatusInternalServerError),
ErrorField: "Code Bug Detected",
ReasonField: "The code ended up at a place where it should not have. Please report this as an issue at https://github.com/ory/kratos",
CodeField: http.StatusInternalServerError,
}
var (
PseudoPanic = herodot.DefaultError{
StatusField: http.StatusText(http.StatusInternalServerError),
ErrorField: "Code Bug Detected",
ReasonField: "The code ended up at a place where it should not have. Please report this as an issue at https://github.com/ory/kratos",
CodeField: http.StatusInternalServerError,
}
PageTokenInvalid = herodot.ErrBadRequest.WithReason("The page token is invalid, do not craft your own page tokens")
)

func RecoverStatusCode(err error, fallback int) int {
var sc herodot.StatusCodeCarrier
Expand Down

0 comments on commit 32737dc

Please sign in to comment.