Skip to content

Commit

Permalink
feat: allow fuzzy-search on credential identifiers (#3526)
Browse files Browse the repository at this point in the history
This PR adds the ability to search for sub-strings and similar strings in credential identifiers.

Note that the **postgres** and **CRDB** migrations create special indexes useful for this feature. To use [online schema changes](https://www.cockroachlabs.com/docs/v23.1/online-schema-changes) with CockroachDB, we recommend manually copying the index definition and running it before applying the migrations. The migration will then be a no-op.

If you run on **mysql** (or **sqlite**), no special index is created. If desired, you can create such an index manually, and it would be highly appreciated if you could contribute its definition.

This feature is a preview and its behavior will change! Similarity search is not expected to return deterministic results, but the results are useful for humans.
  • Loading branch information
zepatrik authored Oct 2, 2023
1 parent 39b0c3c commit 2cb3ea2
Show file tree
Hide file tree
Showing 21 changed files with 216 additions and 140 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
- sdk-generate
services:
postgres:
image: postgres:9.6
image: postgres:11.8
env:
POSTGRES_DB: postgres
POSTGRES_PASSWORD: test
Expand Down Expand Up @@ -111,7 +111,7 @@ jobs:
- sdk-generate
services:
postgres:
image: postgres:9.6
image: postgres:11.8
env:
POSTGRES_DB: postgres
POSTGRES_PASSWORD: test
Expand Down Expand Up @@ -222,7 +222,7 @@ jobs:
- sdk-generate
services:
postgres:
image: postgres:9.6
image: postgres:11.8
env:
POSTGRES_DB: postgres
POSTGRES_PASSWORD: test
Expand Down
9 changes: 4 additions & 5 deletions cmd/identities/delete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@ import (
)

func TestDeleteCmd(t *testing.T) {
c := identities.NewDeleteIdentityCmd()
reg := setup(t, c)
reg, cmd := setup(t, identities.NewDeleteIdentityCmd)

t.Run("case=deletes successfully", func(t *testing.T) {
// create identity to delete
i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
require.NoError(t, reg.Persister().CreateIdentity(context.Background(), i))

stdOut := execNoErr(t, c, i.ID.String())
stdOut := cmd.ExecNoErr(t, i.ID.String())

// expect ID and no error
assert.Equal(t, i.ID.String(), gjson.Parse(stdOut).String())
Expand All @@ -44,7 +43,7 @@ func TestDeleteCmd(t *testing.T) {
t.Run("case=deletes three identities", func(t *testing.T) {
is, ids := makeIdentities(t, reg, 3)

stdOut := execNoErr(t, c, ids...)
stdOut := cmd.ExecNoErr(t, ids...)

assert.Equal(t, `["`+strings.Join(ids, "\",\"")+"\"]\n", stdOut)

Expand All @@ -55,7 +54,7 @@ func TestDeleteCmd(t *testing.T) {
})

t.Run("case=fails with unknown ID", func(t *testing.T) {
stdErr := execErr(t, c, x.NewUUID().String())
stdErr := cmd.ExecExpectedErr(t, x.NewUUID().String())

assert.Contains(t, stdErr, "Unable to locate the resource", stdErr)
})
Expand Down
12 changes: 5 additions & 7 deletions cmd/identities/get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,15 @@ import (
)

func TestGetCmd(t *testing.T) {
c := identities.NewGetIdentityCmd()
reg := setup(t, c)
reg, cmd := setup(t, identities.NewGetIdentityCmd)

t.Run("case=gets a single identity", func(t *testing.T) {
i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
i.MetadataPublic = []byte(`"public"`)
i.MetadataAdmin = []byte(`"admin"`)
require.NoError(t, reg.Persister().CreateIdentity(context.Background(), i))

stdOut := execNoErr(t, c, i.ID.String())
stdOut := cmd.ExecNoErr(t, i.ID.String())

ij, err := json.Marshal(identity.WithCredentialsMetadataAndAdminMetadataInJSON(*i))
require.NoError(t, err)
Expand All @@ -42,7 +41,7 @@ func TestGetCmd(t *testing.T) {
t.Run("case=gets three identities", func(t *testing.T) {
is, ids := makeIdentities(t, reg, 3)

stdOut := execNoErr(t, c, ids...)
stdOut := cmd.ExecNoErr(t, ids...)

isj, err := json.Marshal(is)
require.NoError(t, err)
Expand All @@ -51,7 +50,7 @@ func TestGetCmd(t *testing.T) {
})

t.Run("case=fails with unknown ID", func(t *testing.T) {
stdErr := execErr(t, c, x.NewUUID().String())
stdErr := cmd.ExecExpectedErr(t, x.NewUUID().String())

assert.Contains(t, stdErr, "Unable to locate the resource", stdErr)
})
Expand Down Expand Up @@ -100,10 +99,9 @@ func TestGetCmd(t *testing.T) {
di := i.CopyWithoutCredentials()
di.SetCredentials(identity.CredentialsTypeOIDC, applyCredentials("uniqueIdentifier", "accessBar", "refreshBar", "idBar", false))

require.NoError(t, c.Flags().Set(identities.FlagIncludeCreds, "oidc"))
require.NoError(t, reg.Persister().CreateIdentity(context.Background(), i))

stdOut := execNoErr(t, c, i.ID.String())
stdOut := cmd.ExecNoErr(t, "--"+identities.FlagIncludeCreds, "oidc", i.ID.String())
ij, err := json.Marshal(identity.WithCredentialsAndAdminMetadataInJSON(*di))
require.NoError(t, err)

Expand Down
46 changes: 9 additions & 37 deletions cmd/identities/helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,11 @@
package identities_test

import (
"bytes"
"context"
"io"
"testing"

"github.com/ory/x/cmdx"

"github.com/pkg/errors"

"github.com/ory/kratos/identity"

"github.com/spf13/cobra"
Expand All @@ -25,44 +21,20 @@ import (
"github.com/ory/kratos/internal/testhelpers"
)

func setup(t *testing.T, cmd *cobra.Command) driver.Registry {
func setup(t *testing.T, newCmd func() *cobra.Command) (driver.Registry, *cmdx.CommandExecuter) {
conf, reg := internal.NewFastRegistryWithMocks(t)
_, admin := testhelpers.NewKratosServerWithCSRF(t, reg)
testhelpers.SetDefaultIdentitySchema(conf, "file://./stubs/identity.schema.json")
// setup command
cliclient.RegisterClientFlags(cmd.Flags())
cmdx.RegisterFormatFlags(cmd.Flags())
require.NoError(t, cmd.Flags().Set(cliclient.FlagEndpoint, admin.URL))
require.NoError(t, cmd.Flags().Set(cmdx.FlagFormat, string(cmdx.FormatJSON)))
return reg
}

func exec(cmd *cobra.Command, stdIn io.Reader, args ...string) (string, string, error) {
stdOut, stdErr := &bytes.Buffer{}, &bytes.Buffer{}
cmd.SetErr(stdErr)
cmd.SetOut(stdOut)
cmd.SetIn(stdIn)
defer cmd.SetIn(nil)
if args == nil {
args = []string{}
return reg, &cmdx.CommandExecuter{
New: func() *cobra.Command {
cmd := newCmd()
cliclient.RegisterClientFlags(cmd.Flags())
cmdx.RegisterFormatFlags(cmd.Flags())
return cmd
},
PersistentArgs: []string{"--" + cliclient.FlagEndpoint, admin.URL, "--" + cmdx.FlagFormat, string(cmdx.FormatJSON)},
}
cmd.SetArgs(args)
err := cmd.Execute()
return stdOut.String(), stdErr.String(), err
}

func execNoErr(t *testing.T, cmd *cobra.Command, args ...string) string {
stdOut, stdErr, err := exec(cmd, nil, args...)
require.NoError(t, err, "stdout: %s\nstderr: %s", stdOut, stdErr)
require.Len(t, stdErr, 0, stdOut)
return stdOut
}

func execErr(t *testing.T, cmd *cobra.Command, args ...string) string {
stdOut, stdErr, err := exec(cmd, nil, args...)
require.True(t, errors.Is(err, cmdx.ErrNoPrintButFail))
require.Len(t, stdOut, 0, stdErr)
return stdErr
}

func makeIdentities(t *testing.T, reg driver.Registry, n int) (is []*identity.Identity, ids []string) {
Expand Down
11 changes: 5 additions & 6 deletions cmd/identities/import_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ import (
)

func TestImportCmd(t *testing.T) {
c := identities.NewImportIdentitiesCmd()
reg := setup(t, c)
reg, cmd := setup(t, identities.NewImportIdentitiesCmd)

t.Run("case=imports a new identity from file", func(t *testing.T) {
i := kratos.CreateIdentityBody{
Expand All @@ -40,7 +39,7 @@ func TestImportCmd(t *testing.T) {
require.NoError(t, err)
require.NoError(t, f.Close())

stdOut := execNoErr(t, c, f.Name())
stdOut := cmd.ExecNoErr(t, f.Name())

id, err := uuid.FromString(gjson.Get(stdOut, "id").String())
require.NoError(t, err)
Expand All @@ -67,7 +66,7 @@ func TestImportCmd(t *testing.T) {
require.NoError(t, err)
require.NoError(t, f.Close())

stdOut := execNoErr(t, c, f.Name())
stdOut := cmd.ExecNoErr(t, f.Name())

id, err := uuid.FromString(gjson.Get(stdOut, "0.id").String())
require.NoError(t, err)
Expand All @@ -94,7 +93,7 @@ func TestImportCmd(t *testing.T) {
ij, err := json.Marshal(i)
require.NoError(t, err)

stdOut, stdErr, err := exec(c, bytes.NewBuffer(ij))
stdOut, stdErr, err := cmd.Exec(bytes.NewBuffer(ij))
require.NoError(t, err, "%s %s", stdOut, stdErr)

id, err := uuid.FromString(gjson.Get(stdOut, "0.id").String())
Expand All @@ -116,7 +115,7 @@ func TestImportCmd(t *testing.T) {
ij, err := json.Marshal(i)
require.NoError(t, err)

stdOut, stdErr, err := exec(c, bytes.NewBuffer(ij))
stdOut, stdErr, err := cmd.Exec(bytes.NewBuffer(ij))
require.NoError(t, err, "%s %s", stdOut, stdErr)

id, err := uuid.FromString(gjson.Get(stdOut, "id").String())
Expand Down
30 changes: 11 additions & 19 deletions cmd/identities/list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ import (
)

func TestListCmd(t *testing.T) {
c := identities.NewListIdentitiesCmd()
reg := setup(t, c)
reg, cmd := setup(t, identities.NewListIdentitiesCmd)

var deleteIdentities = func(t *testing.T, is []*identity.Identity) {
for _, i := range is {
Expand All @@ -31,13 +30,11 @@ func TestListCmd(t *testing.T) {

t.Run("case=lists all identities with default pagination", func(t *testing.T) {
is, ids := makeIdentities(t, reg, 5)
require.NoError(t, c.Flags().Set(cmdx.FlagQuiet, "true"))
t.Cleanup(func() {
require.NoError(t, c.Flags().Set(cmdx.FlagQuiet, "false"))
deleteIdentities(t, is)
})

stdOut := cmdx.ExecNoErr(t, c)
stdOut := cmd.ExecNoErr(t, "--"+cmdx.FlagQuiet)

for _, i := range ids {
assert.Contains(t, stdOut, i)
Expand All @@ -50,25 +47,20 @@ func TestListCmd(t *testing.T) {
deleteIdentities(t, is)
})

first := cmdx.ExecNoErr(t, c, "--format", "json-pretty", "--page-size", "2")
first := cmd.ExecNoErr(t, "--format", "json-pretty", "--page-size", "2")
nextPageToken := gjson.Get(first, "next_page_token").String()
results := gjson.Get(first, "identities").Array()
actualIDs := gjson.Get(first, "identities.#.id").Array()
for nextPageToken != "" {
next := cmdx.ExecNoErr(t, c, "--format", "json-pretty", "--page-size", "2", "--page-token", nextPageToken)
results = append(results, gjson.Get(next, "identities").Array()...)
next := cmd.ExecNoErr(t, "--page-size", "2", "--page-token", nextPageToken)
actualIDs = append(actualIDs, gjson.Get(next, "identities.#.id").Array()...)
nextPageToken = gjson.Get(next, "next_page_token").String()
}

assert.Len(t, results, len(ids))
for _, expected := range ids {
var found bool
for _, actual := range results {
if actual.Get("id").String() == expected {
found = true
break
}
}
require.True(t, found, "could not find id: %s", expected)
actualIDsString := make([]string, len(actualIDs))
for i, id := range actualIDs {
actualIDsString[i] = id.Str
}

assert.ElementsMatch(t, ids, actualIDsString)
})
}
21 changes: 19 additions & 2 deletions identity/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,22 @@ type listIdentitiesResponse struct {
type listIdentitiesParameters struct {
migrationpagination.RequestParameters

// CredentialsIdentifier is the identifier (username, email) of the credentials to look up.
// CredentialsIdentifier is the identifier (username, email) of the credentials to look up using exact match.
// Only one of CredentialsIdentifier and CredentialsIdentifierSimilar can be used.
//
// required: false
// in: query
CredentialsIdentifier string `json:"credentials_identifier"`

// This is an EXPERIMENTAL parameter that WILL CHANGE. Do NOT rely on consistent, deterministic behavior.
// THIS PARAMETER WILL BE REMOVED IN AN UPCOMING RELEASE WITHOUT ANY MIGRATION PATH.
//
// CredentialsIdentifierSimilar is the (partial) identifier (username, email) of the credentials to look up using similarity search.
// Only one of CredentialsIdentifier and CredentialsIdentifierSimilar can be used.
//
// required: false
// in: query
CredentialsIdentifierSimilar string `json:"preview_credentials_identifier_similar"`
}

// swagger:route GET /admin/identities identity listIdentities
Expand All @@ -160,7 +171,13 @@ type listIdentitiesParameters struct {
func (h *Handler) list(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
page, itemsPerPage := x.ParsePagination(r)

params := ListIdentityParameters{Expand: ExpandDefault, Page: page, PerPage: itemsPerPage, CredentialsIdentifier: r.URL.Query().Get("credentials_identifier")}
params := ListIdentityParameters{
Expand: ExpandDefault,
Page: page,
PerPage: itemsPerPage,
CredentialsIdentifier: r.URL.Query().Get("credentials_identifier"),
CredentialsIdentifierSimilar: r.URL.Query().Get("preview_credentials_identifier_similar"),
}
if params.CredentialsIdentifier != "" {
params.Expand = ExpandEverything
}
Expand Down
9 changes: 5 additions & 4 deletions identity/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ import (

type (
ListIdentityParameters struct {
Expand Expandables
CredentialsIdentifier string
Page int
PerPage int
Expand Expandables
CredentialsIdentifier string
CredentialsIdentifierSimilar string
Page int
PerPage int
}

Pool interface {
Expand Down
29 changes: 27 additions & 2 deletions identity/test/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,8 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister,
t.Run("case=list", func(t *testing.T) {
is, err := p.ListIdentities(ctx, identity.ListIdentityParameters{Expand: identity.ExpandDefault, Page: 0, PerPage: 25})
require.NoError(t, err)
assert.Len(t, is, len(createdIDs))
require.NotZero(t, len(is))
require.Len(t, is, len(createdIDs))
for _, id := range createdIDs {
var found bool
for _, i := range is {
Expand Down Expand Up @@ -686,7 +687,7 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister,
Expand: identity.ExpandEverything,
})
require.NoError(t, err)
require.True(t, len(actual) > 0)
require.Greater(t, len(actual), 0)

for c, ct := range []identity.CredentialsType{
identity.CredentialsTypePassword,
Expand All @@ -705,6 +706,30 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister,
})
}

t.Run("similarity search", func(t *testing.T) {
actual, err := p.ListIdentities(ctx, identity.ListIdentityParameters{
CredentialsIdentifierSimilar: "find-identity-by-identifier",
Expand: identity.ExpandCredentials,
})
require.NoError(t, err)
assert.Len(t, actual, 3)

outer:
for _, e := range append(expectedIdentities[:2], create) {
for _, a := range actual {
if e.ID == a.ID {
assertx.EqualAsJSONExcept(t, e, a, []string{"credentials.config", "created_at", "updated_at", "state_changed_at"})
continue outer
}
}
actualCredentials := make([]map[identity.CredentialsType]identity.Credentials, len(actual))
for k, a := range actual {
actualCredentials[k] = a.Credentials
}
t.Fatalf("expected identity %+v not found in actual result set %+v", e.Credentials, actualCredentials)
}
})

t.Run("only webauthn and password", func(t *testing.T) {
actual, err := p.ListIdentities(ctx, identity.ListIdentityParameters{
CredentialsIdentifier: "[email protected]",
Expand Down
Loading

0 comments on commit 2cb3ea2

Please sign in to comment.