diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 94990e8d4827..11b0dddfae49 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -28,7 +28,7 @@ jobs: - sdk-generate services: postgres: - image: postgres:9.6 + image: postgres:11.8 env: POSTGRES_DB: postgres POSTGRES_PASSWORD: test @@ -111,7 +111,7 @@ jobs: - sdk-generate services: postgres: - image: postgres:9.6 + image: postgres:11.8 env: POSTGRES_DB: postgres POSTGRES_PASSWORD: test @@ -222,7 +222,7 @@ jobs: - sdk-generate services: postgres: - image: postgres:9.6 + image: postgres:11.8 env: POSTGRES_DB: postgres POSTGRES_PASSWORD: test diff --git a/cmd/identities/delete_test.go b/cmd/identities/delete_test.go index 70b7282b365d..15a1993c7c94 100644 --- a/cmd/identities/delete_test.go +++ b/cmd/identities/delete_test.go @@ -23,15 +23,14 @@ import ( ) func TestDeleteCmd(t *testing.T) { - c := identities.NewDeleteIdentityCmd() - reg := setup(t, c) + reg, cmd := setup(t, identities.NewDeleteIdentityCmd) t.Run("case=deletes successfully", func(t *testing.T) { // create identity to delete i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) require.NoError(t, reg.Persister().CreateIdentity(context.Background(), i)) - stdOut := execNoErr(t, c, i.ID.String()) + stdOut := cmd.ExecNoErr(t, i.ID.String()) // expect ID and no error assert.Equal(t, i.ID.String(), gjson.Parse(stdOut).String()) @@ -44,7 +43,7 @@ func TestDeleteCmd(t *testing.T) { t.Run("case=deletes three identities", func(t *testing.T) { is, ids := makeIdentities(t, reg, 3) - stdOut := execNoErr(t, c, ids...) + stdOut := cmd.ExecNoErr(t, ids...) assert.Equal(t, `["`+strings.Join(ids, "\",\"")+"\"]\n", stdOut) @@ -55,7 +54,7 @@ func TestDeleteCmd(t *testing.T) { }) t.Run("case=fails with unknown ID", func(t *testing.T) { - stdErr := execErr(t, c, x.NewUUID().String()) + stdErr := cmd.ExecExpectedErr(t, x.NewUUID().String()) assert.Contains(t, stdErr, "Unable to locate the resource", stdErr) }) diff --git a/cmd/identities/get_test.go b/cmd/identities/get_test.go index eafe31520828..03a1291d5872 100644 --- a/cmd/identities/get_test.go +++ b/cmd/identities/get_test.go @@ -22,8 +22,7 @@ import ( ) func TestGetCmd(t *testing.T) { - c := identities.NewGetIdentityCmd() - reg := setup(t, c) + reg, cmd := setup(t, identities.NewGetIdentityCmd) t.Run("case=gets a single identity", func(t *testing.T) { i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) @@ -31,7 +30,7 @@ func TestGetCmd(t *testing.T) { i.MetadataAdmin = []byte(`"admin"`) require.NoError(t, reg.Persister().CreateIdentity(context.Background(), i)) - stdOut := execNoErr(t, c, i.ID.String()) + stdOut := cmd.ExecNoErr(t, i.ID.String()) ij, err := json.Marshal(identity.WithCredentialsMetadataAndAdminMetadataInJSON(*i)) require.NoError(t, err) @@ -42,7 +41,7 @@ func TestGetCmd(t *testing.T) { t.Run("case=gets three identities", func(t *testing.T) { is, ids := makeIdentities(t, reg, 3) - stdOut := execNoErr(t, c, ids...) + stdOut := cmd.ExecNoErr(t, ids...) isj, err := json.Marshal(is) require.NoError(t, err) @@ -51,7 +50,7 @@ func TestGetCmd(t *testing.T) { }) t.Run("case=fails with unknown ID", func(t *testing.T) { - stdErr := execErr(t, c, x.NewUUID().String()) + stdErr := cmd.ExecExpectedErr(t, x.NewUUID().String()) assert.Contains(t, stdErr, "Unable to locate the resource", stdErr) }) @@ -100,10 +99,9 @@ func TestGetCmd(t *testing.T) { di := i.CopyWithoutCredentials() di.SetCredentials(identity.CredentialsTypeOIDC, applyCredentials("uniqueIdentifier", "accessBar", "refreshBar", "idBar", false)) - require.NoError(t, c.Flags().Set(identities.FlagIncludeCreds, "oidc")) require.NoError(t, reg.Persister().CreateIdentity(context.Background(), i)) - stdOut := execNoErr(t, c, i.ID.String()) + stdOut := cmd.ExecNoErr(t, "--"+identities.FlagIncludeCreds, "oidc", i.ID.String()) ij, err := json.Marshal(identity.WithCredentialsAndAdminMetadataInJSON(*di)) require.NoError(t, err) diff --git a/cmd/identities/helpers_test.go b/cmd/identities/helpers_test.go index b0be50ba4d4c..5997b32c7623 100644 --- a/cmd/identities/helpers_test.go +++ b/cmd/identities/helpers_test.go @@ -4,15 +4,11 @@ package identities_test import ( - "bytes" "context" - "io" "testing" "github.com/ory/x/cmdx" - "github.com/pkg/errors" - "github.com/ory/kratos/identity" "github.com/spf13/cobra" @@ -25,44 +21,20 @@ import ( "github.com/ory/kratos/internal/testhelpers" ) -func setup(t *testing.T, cmd *cobra.Command) driver.Registry { +func setup(t *testing.T, newCmd func() *cobra.Command) (driver.Registry, *cmdx.CommandExecuter) { conf, reg := internal.NewFastRegistryWithMocks(t) _, admin := testhelpers.NewKratosServerWithCSRF(t, reg) testhelpers.SetDefaultIdentitySchema(conf, "file://./stubs/identity.schema.json") // setup command - cliclient.RegisterClientFlags(cmd.Flags()) - cmdx.RegisterFormatFlags(cmd.Flags()) - require.NoError(t, cmd.Flags().Set(cliclient.FlagEndpoint, admin.URL)) - require.NoError(t, cmd.Flags().Set(cmdx.FlagFormat, string(cmdx.FormatJSON))) - return reg -} - -func exec(cmd *cobra.Command, stdIn io.Reader, args ...string) (string, string, error) { - stdOut, stdErr := &bytes.Buffer{}, &bytes.Buffer{} - cmd.SetErr(stdErr) - cmd.SetOut(stdOut) - cmd.SetIn(stdIn) - defer cmd.SetIn(nil) - if args == nil { - args = []string{} + return reg, &cmdx.CommandExecuter{ + New: func() *cobra.Command { + cmd := newCmd() + cliclient.RegisterClientFlags(cmd.Flags()) + cmdx.RegisterFormatFlags(cmd.Flags()) + return cmd + }, + PersistentArgs: []string{"--" + cliclient.FlagEndpoint, admin.URL, "--" + cmdx.FlagFormat, string(cmdx.FormatJSON)}, } - cmd.SetArgs(args) - err := cmd.Execute() - return stdOut.String(), stdErr.String(), err -} - -func execNoErr(t *testing.T, cmd *cobra.Command, args ...string) string { - stdOut, stdErr, err := exec(cmd, nil, args...) - require.NoError(t, err, "stdout: %s\nstderr: %s", stdOut, stdErr) - require.Len(t, stdErr, 0, stdOut) - return stdOut -} - -func execErr(t *testing.T, cmd *cobra.Command, args ...string) string { - stdOut, stdErr, err := exec(cmd, nil, args...) - require.True(t, errors.Is(err, cmdx.ErrNoPrintButFail)) - require.Len(t, stdOut, 0, stdErr) - return stdErr } func makeIdentities(t *testing.T, reg driver.Registry, n int) (is []*identity.Identity, ids []string) { diff --git a/cmd/identities/import_test.go b/cmd/identities/import_test.go index 567cac256957..8db159de9cff 100644 --- a/cmd/identities/import_test.go +++ b/cmd/identities/import_test.go @@ -24,8 +24,7 @@ import ( ) func TestImportCmd(t *testing.T) { - c := identities.NewImportIdentitiesCmd() - reg := setup(t, c) + reg, cmd := setup(t, identities.NewImportIdentitiesCmd) t.Run("case=imports a new identity from file", func(t *testing.T) { i := kratos.CreateIdentityBody{ @@ -40,7 +39,7 @@ func TestImportCmd(t *testing.T) { require.NoError(t, err) require.NoError(t, f.Close()) - stdOut := execNoErr(t, c, f.Name()) + stdOut := cmd.ExecNoErr(t, f.Name()) id, err := uuid.FromString(gjson.Get(stdOut, "id").String()) require.NoError(t, err) @@ -67,7 +66,7 @@ func TestImportCmd(t *testing.T) { require.NoError(t, err) require.NoError(t, f.Close()) - stdOut := execNoErr(t, c, f.Name()) + stdOut := cmd.ExecNoErr(t, f.Name()) id, err := uuid.FromString(gjson.Get(stdOut, "0.id").String()) require.NoError(t, err) @@ -94,7 +93,7 @@ func TestImportCmd(t *testing.T) { ij, err := json.Marshal(i) require.NoError(t, err) - stdOut, stdErr, err := exec(c, bytes.NewBuffer(ij)) + stdOut, stdErr, err := cmd.Exec(bytes.NewBuffer(ij)) require.NoError(t, err, "%s %s", stdOut, stdErr) id, err := uuid.FromString(gjson.Get(stdOut, "0.id").String()) @@ -116,7 +115,7 @@ func TestImportCmd(t *testing.T) { ij, err := json.Marshal(i) require.NoError(t, err) - stdOut, stdErr, err := exec(c, bytes.NewBuffer(ij)) + stdOut, stdErr, err := cmd.Exec(bytes.NewBuffer(ij)) require.NoError(t, err, "%s %s", stdOut, stdErr) id, err := uuid.FromString(gjson.Get(stdOut, "id").String()) diff --git a/cmd/identities/list_test.go b/cmd/identities/list_test.go index e8f0c2769203..d1abb614cbde 100644 --- a/cmd/identities/list_test.go +++ b/cmd/identities/list_test.go @@ -20,8 +20,7 @@ import ( ) func TestListCmd(t *testing.T) { - c := identities.NewListIdentitiesCmd() - reg := setup(t, c) + reg, cmd := setup(t, identities.NewListIdentitiesCmd) var deleteIdentities = func(t *testing.T, is []*identity.Identity) { for _, i := range is { @@ -31,13 +30,11 @@ func TestListCmd(t *testing.T) { t.Run("case=lists all identities with default pagination", func(t *testing.T) { is, ids := makeIdentities(t, reg, 5) - require.NoError(t, c.Flags().Set(cmdx.FlagQuiet, "true")) t.Cleanup(func() { - require.NoError(t, c.Flags().Set(cmdx.FlagQuiet, "false")) deleteIdentities(t, is) }) - stdOut := cmdx.ExecNoErr(t, c) + stdOut := cmd.ExecNoErr(t, "--"+cmdx.FlagQuiet) for _, i := range ids { assert.Contains(t, stdOut, i) @@ -50,25 +47,20 @@ func TestListCmd(t *testing.T) { deleteIdentities(t, is) }) - first := cmdx.ExecNoErr(t, c, "--format", "json-pretty", "--page-size", "2") + first := cmd.ExecNoErr(t, "--format", "json-pretty", "--page-size", "2") nextPageToken := gjson.Get(first, "next_page_token").String() - results := gjson.Get(first, "identities").Array() + actualIDs := gjson.Get(first, "identities.#.id").Array() for nextPageToken != "" { - next := cmdx.ExecNoErr(t, c, "--format", "json-pretty", "--page-size", "2", "--page-token", nextPageToken) - results = append(results, gjson.Get(next, "identities").Array()...) + next := cmd.ExecNoErr(t, "--page-size", "2", "--page-token", nextPageToken) + actualIDs = append(actualIDs, gjson.Get(next, "identities.#.id").Array()...) nextPageToken = gjson.Get(next, "next_page_token").String() } - assert.Len(t, results, len(ids)) - for _, expected := range ids { - var found bool - for _, actual := range results { - if actual.Get("id").String() == expected { - found = true - break - } - } - require.True(t, found, "could not find id: %s", expected) + actualIDsString := make([]string, len(actualIDs)) + for i, id := range actualIDs { + actualIDsString[i] = id.Str } + + assert.ElementsMatch(t, ids, actualIDsString) }) } diff --git a/identity/handler.go b/identity/handler.go index a878a0b0c31f..c7505b0dae6c 100644 --- a/identity/handler.go +++ b/identity/handler.go @@ -133,11 +133,22 @@ type listIdentitiesResponse struct { type listIdentitiesParameters struct { migrationpagination.RequestParameters - // CredentialsIdentifier is the identifier (username, email) of the credentials to look up. + // CredentialsIdentifier is the identifier (username, email) of the credentials to look up using exact match. + // Only one of CredentialsIdentifier and CredentialsIdentifierSimilar can be used. // // required: false // in: query CredentialsIdentifier string `json:"credentials_identifier"` + + // This is an EXPERIMENTAL parameter that WILL CHANGE. Do NOT rely on consistent, deterministic behavior. + // THIS PARAMETER WILL BE REMOVED IN AN UPCOMING RELEASE WITHOUT ANY MIGRATION PATH. + // + // CredentialsIdentifierSimilar is the (partial) identifier (username, email) of the credentials to look up using similarity search. + // Only one of CredentialsIdentifier and CredentialsIdentifierSimilar can be used. + // + // required: false + // in: query + CredentialsIdentifierSimilar string `json:"preview_credentials_identifier_similar"` } // swagger:route GET /admin/identities identity listIdentities @@ -160,7 +171,13 @@ type listIdentitiesParameters struct { func (h *Handler) list(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { page, itemsPerPage := x.ParsePagination(r) - params := ListIdentityParameters{Expand: ExpandDefault, Page: page, PerPage: itemsPerPage, CredentialsIdentifier: r.URL.Query().Get("credentials_identifier")} + params := ListIdentityParameters{ + Expand: ExpandDefault, + Page: page, + PerPage: itemsPerPage, + CredentialsIdentifier: r.URL.Query().Get("credentials_identifier"), + CredentialsIdentifierSimilar: r.URL.Query().Get("preview_credentials_identifier_similar"), + } if params.CredentialsIdentifier != "" { params.Expand = ExpandEverything } diff --git a/identity/pool.go b/identity/pool.go index 1c0cbddb47dc..1cf4888e52ae 100644 --- a/identity/pool.go +++ b/identity/pool.go @@ -13,10 +13,11 @@ import ( type ( ListIdentityParameters struct { - Expand Expandables - CredentialsIdentifier string - Page int - PerPage int + Expand Expandables + CredentialsIdentifier string + CredentialsIdentifierSimilar string + Page int + PerPage int } Pool interface { diff --git a/identity/test/pool.go b/identity/test/pool.go index f695c6dda35d..5adf9fb364b6 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -636,7 +636,8 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, t.Run("case=list", func(t *testing.T) { is, err := p.ListIdentities(ctx, identity.ListIdentityParameters{Expand: identity.ExpandDefault, Page: 0, PerPage: 25}) require.NoError(t, err) - assert.Len(t, is, len(createdIDs)) + require.NotZero(t, len(is)) + require.Len(t, is, len(createdIDs)) for _, id := range createdIDs { var found bool for _, i := range is { @@ -686,7 +687,7 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, Expand: identity.ExpandEverything, }) require.NoError(t, err) - require.True(t, len(actual) > 0) + require.Greater(t, len(actual), 0) for c, ct := range []identity.CredentialsType{ identity.CredentialsTypePassword, @@ -705,6 +706,30 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, }) } + t.Run("similarity search", func(t *testing.T) { + actual, err := p.ListIdentities(ctx, identity.ListIdentityParameters{ + CredentialsIdentifierSimilar: "find-identity-by-identifier", + Expand: identity.ExpandCredentials, + }) + require.NoError(t, err) + assert.Len(t, actual, 3) + + outer: + for _, e := range append(expectedIdentities[:2], create) { + for _, a := range actual { + if e.ID == a.ID { + assertx.EqualAsJSONExcept(t, e, a, []string{"credentials.config", "created_at", "updated_at", "state_changed_at"}) + continue outer + } + } + actualCredentials := make([]map[identity.CredentialsType]identity.Credentials, len(actual)) + for k, a := range actual { + actualCredentials[k] = a.Credentials + } + t.Fatalf("expected identity %+v not found in actual result set %+v", e.Credentials, actualCredentials) + } + }) + t.Run("only webauthn and password", func(t *testing.T) { actual, err := p.ListIdentities(ctx, identity.ListIdentityParameters{ CredentialsIdentifier: "find-identity-by-identifier-oidc@ory.sh", diff --git a/internal/client-go/api_identity.go b/internal/client-go/api_identity.go index 864934fef723..93dcc20dddd1 100644 --- a/internal/client-go/api_identity.go +++ b/internal/client-go/api_identity.go @@ -2045,13 +2045,14 @@ func (a *IdentityApiService) GetSessionExecute(r IdentityApiApiGetSessionRequest } type IdentityApiApiListIdentitiesRequest struct { - ctx context.Context - ApiService IdentityApi - perPage *int64 - page *int64 - pageSize *int64 - pageToken *string - credentialsIdentifier *string + ctx context.Context + ApiService IdentityApi + perPage *int64 + page *int64 + pageSize *int64 + pageToken *string + credentialsIdentifier *string + previewCredentialsIdentifierSimilar *string } func (r IdentityApiApiListIdentitiesRequest) PerPage(perPage int64) IdentityApiApiListIdentitiesRequest { @@ -2074,6 +2075,10 @@ func (r IdentityApiApiListIdentitiesRequest) CredentialsIdentifier(credentialsId r.credentialsIdentifier = &credentialsIdentifier return r } +func (r IdentityApiApiListIdentitiesRequest) PreviewCredentialsIdentifierSimilar(previewCredentialsIdentifierSimilar string) IdentityApiApiListIdentitiesRequest { + r.previewCredentialsIdentifierSimilar = &previewCredentialsIdentifierSimilar + return r +} func (r IdentityApiApiListIdentitiesRequest) Execute() ([]Identity, *http.Response, error) { return r.ApiService.ListIdentitiesExecute(r) @@ -2132,6 +2137,9 @@ func (a *IdentityApiService) ListIdentitiesExecute(r IdentityApiApiListIdentitie if r.credentialsIdentifier != nil { localVarQueryParams.Add("credentials_identifier", parameterToString(*r.credentialsIdentifier, "")) } + if r.previewCredentialsIdentifierSimilar != nil { + localVarQueryParams.Add("preview_credentials_identifier_similar", parameterToString(*r.previewCredentialsIdentifierSimilar, "")) + } // to determine the Content-Type header localVarHTTPContentTypes := []string{} diff --git a/internal/httpclient/api_identity.go b/internal/httpclient/api_identity.go index 864934fef723..93dcc20dddd1 100644 --- a/internal/httpclient/api_identity.go +++ b/internal/httpclient/api_identity.go @@ -2045,13 +2045,14 @@ func (a *IdentityApiService) GetSessionExecute(r IdentityApiApiGetSessionRequest } type IdentityApiApiListIdentitiesRequest struct { - ctx context.Context - ApiService IdentityApi - perPage *int64 - page *int64 - pageSize *int64 - pageToken *string - credentialsIdentifier *string + ctx context.Context + ApiService IdentityApi + perPage *int64 + page *int64 + pageSize *int64 + pageToken *string + credentialsIdentifier *string + previewCredentialsIdentifierSimilar *string } func (r IdentityApiApiListIdentitiesRequest) PerPage(perPage int64) IdentityApiApiListIdentitiesRequest { @@ -2074,6 +2075,10 @@ func (r IdentityApiApiListIdentitiesRequest) CredentialsIdentifier(credentialsId r.credentialsIdentifier = &credentialsIdentifier return r } +func (r IdentityApiApiListIdentitiesRequest) PreviewCredentialsIdentifierSimilar(previewCredentialsIdentifierSimilar string) IdentityApiApiListIdentitiesRequest { + r.previewCredentialsIdentifierSimilar = &previewCredentialsIdentifierSimilar + return r +} func (r IdentityApiApiListIdentitiesRequest) Execute() ([]Identity, *http.Response, error) { return r.ApiService.ListIdentitiesExecute(r) @@ -2132,6 +2137,9 @@ func (a *IdentityApiService) ListIdentitiesExecute(r IdentityApiApiListIdentitie if r.credentialsIdentifier != nil { localVarQueryParams.Add("credentials_identifier", parameterToString(*r.credentialsIdentifier, "")) } + if r.previewCredentialsIdentifierSimilar != nil { + localVarQueryParams.Add("preview_credentials_identifier_similar", parameterToString(*r.previewCredentialsIdentifierSimilar, "")) + } // to determine the Content-Type header localVarHTTPContentTypes := []string{} diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index e2f1d23a882f..99469f365f1b 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -666,56 +666,91 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity. con := p.GetConnection(ctx) nid := p.NetworkID(ctx) - query := con.Where("identities.nid = ?", nid).Order("identities.id DESC") - // Credentials are not expanded through `EagerPreload` but manually after - // fetching the identities, hence we filter out the relevant expand options. - var expandExceptCredentials sqlxx.Expandables - for _, e := range params.Expand { - if e != identity.ExpandFieldCredentials { - expandExceptCredentials = append(expandExceptCredentials, e) + joins := "" + wheres := "" + args := []any{nid} + identifier := params.CredentialsIdentifier + identifierOperator := "=" + if identifier == "" && params.CredentialsIdentifierSimilar != "" { + identifier = params.CredentialsIdentifierSimilar + identifierOperator = "%" + switch con.Dialect.Name() { + case "postgres", "cockroach": + default: + identifier = "%" + identifier + "%" + identifierOperator = "LIKE" } } - if len(expandExceptCredentials) > 0 { - query = query.EagerPreload(expandExceptCredentials.ToEager()...) - } - if match := params.CredentialsIdentifier; len(match) > 0 { + if len(identifier) > 0 { // When filtering by credentials identifier, we most likely are looking for a username or email. It is therefore // important to normalize the identifier before querying the database. - match = NormalizeIdentifier(identity.CredentialsTypePassword, match) - query = query. - InnerJoin("identity_credentials ic", "ic.identity_id = identities.id"). - InnerJoin("identity_credential_types ict", "ict.id = ic.identity_credential_type_id"). - InnerJoin("identity_credential_identifiers ici", "ici.identity_credential_id = ic.id"). - Where("(ic.nid = ? AND ici.nid = ? AND ici.identifier = ?)", nid, nid, match). - Where("ict.name IN (?)", identity.CredentialsTypeWebAuthn, identity.CredentialsTypePassword). - Limit(1) - } else { - query = query.Paginate(params.Page+1, params.PerPage) - } - - if err := sqlcon.HandleError(query.All(&is)); err != nil { - return nil, err + identifier = NormalizeIdentifier(identity.CredentialsTypePassword, identifier) + + joins = ` +INNER JOIN identity_credentials ic ON ic.identity_id = identities.id +INNER JOIN identity_credential_types ict ON ict.id = ic.identity_credential_type_id +INNER JOIN identity_credential_identifiers ici ON ici.identity_credential_id = ic.id` + wheres = fmt.Sprintf(` +AND (ic.nid = ? AND ici.nid = ? AND ici.identifier %s ?) +AND ict.name IN (?, ?)`, identifierOperator) + args = append(args, nid, nid, identifier, identity.CredentialsTypeWebAuthn, identity.CredentialsTypePassword) + } + + // Follow up: add page token support here, will be easy. + paginator := pop.NewPaginator(params.Page+1, params.PerPage) + + if err := con.RawQuery(fmt.Sprintf(`SELECT DISTINCT identities.* +FROM identities AS identities +%s +WHERE identities.nid = ? +%s +ORDER BY identities.id DESC +LIMIT %d +OFFSET %d`, joins, wheres, paginator.PerPage, paginator.Offset), args...).All(&is); err != nil { + return nil, sqlcon.HandleError(err) } if len(is) == 0 { return is, nil } - if params.Expand.Has(identity.ExpandFieldCredentials) { - var ids []interface{} - for k := range is { - ids = append(ids, is[k].ID) - } - creds, err := QueryForCredentials(con, - Where{"identity_credentials.nid = ?", []interface{}{nid}}, - Where{"identity_credentials.identity_id IN (?)", ids}) - if err != nil { - return nil, err - } - for k := range is { - is[k].Credentials = creds[is[k].ID] + identitiesByID := make(map[uuid.UUID]*identity.Identity, len(is)) + identityIDs := make([]any, len(is)) + for k := range is { + identitiesByID[is[k].ID] = &is[k] + identityIDs[k] = is[k].ID + } + + for _, e := range params.Expand { + switch e { + case identity.ExpandFieldCredentials: + creds, err := QueryForCredentials(con, + Where{"identity_credentials.nid = ?", []any{nid}}, + Where{"identity_credentials.identity_id IN (?)", identityIDs}) + if err != nil { + return nil, err + } + for k := range is { + is[k].Credentials = creds[is[k].ID] + } + case identity.ExpandFieldVerifiableAddresses: + addrs := make([]identity.VerifiableAddress, 0) + if err := con.Where("nid = ?", nid).Where("identity_id IN (?)", identityIDs).Order("id").All(&addrs); err != nil { + return nil, sqlcon.HandleError(err) + } + for _, addr := range addrs { + identitiesByID[addr.IdentityID].VerifiableAddresses = append(identitiesByID[addr.IdentityID].VerifiableAddresses, addr) + } + case identity.ExpandFieldRecoveryAddresses: + addrs := make([]identity.RecoveryAddress, 0) + if err := con.Where("nid = ?", nid).Where("identity_id IN (?)", identityIDs).Order("id").All(&addrs); err != nil { + return nil, sqlcon.HandleError(err) + } + for _, addr := range addrs { + identitiesByID[addr.IdentityID].RecoveryAddresses = append(identitiesByID[addr.IdentityID].RecoveryAddresses, addr) + } } } diff --git a/persistence/sql/migrations/sql/20230920171028000000_identity_search_index.cockroach.down.sql b/persistence/sql/migrations/sql/20230920171028000000_identity_search_index.cockroach.down.sql new file mode 100644 index 000000000000..d1c8cdf26841 --- /dev/null +++ b/persistence/sql/migrations/sql/20230920171028000000_identity_search_index.cockroach.down.sql @@ -0,0 +1 @@ +DROP INDEX IF EXISTS identity_credential_identifiers_nid_identifier_gin; diff --git a/persistence/sql/migrations/sql/20230920171028000000_identity_search_index.cockroach.up.sql b/persistence/sql/migrations/sql/20230920171028000000_identity_search_index.cockroach.up.sql new file mode 100644 index 000000000000..f692e9943649 --- /dev/null +++ b/persistence/sql/migrations/sql/20230920171028000000_identity_search_index.cockroach.up.sql @@ -0,0 +1 @@ +CREATE INDEX IF NOT EXISTS identity_credential_identifiers_nid_identifier_gin ON identity_credential_identifiers USING GIN (nid, identifier gin_trgm_ops); diff --git a/persistence/sql/migrations/sql/20230920171028000000_identity_search_index.down.sql b/persistence/sql/migrations/sql/20230920171028000000_identity_search_index.down.sql new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/persistence/sql/migrations/sql/20230920171028000000_identity_search_index.postgres.down.sql b/persistence/sql/migrations/sql/20230920171028000000_identity_search_index.postgres.down.sql new file mode 100644 index 000000000000..159cb60805d1 --- /dev/null +++ b/persistence/sql/migrations/sql/20230920171028000000_identity_search_index.postgres.down.sql @@ -0,0 +1 @@ +DROP INDEX identity_credential_identifiers_nid_identifier_gin; diff --git a/persistence/sql/migrations/sql/20230920171028000000_identity_search_index.postgres.up.sql b/persistence/sql/migrations/sql/20230920171028000000_identity_search_index.postgres.up.sql new file mode 100644 index 000000000000..70d519fb44bc --- /dev/null +++ b/persistence/sql/migrations/sql/20230920171028000000_identity_search_index.postgres.up.sql @@ -0,0 +1,4 @@ +CREATE EXTENSION IF NOT EXISTS pg_trgm; +CREATE EXTENSION IF NOT EXISTS btree_gin; + +CREATE INDEX identity_credential_identifiers_nid_identifier_gin ON identity_credential_identifiers USING GIN (nid, identifier gin_trgm_ops); diff --git a/persistence/sql/migrations/sql/20230920171028000000_identity_search_index.up.sql b/persistence/sql/migrations/sql/20230920171028000000_identity_search_index.up.sql new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/persistence/sql/migrations/sql/20230920171028000000_identity_search_index.up.sql @@ -0,0 +1 @@ + diff --git a/quickstart-postgres.yml b/quickstart-postgres.yml index b97e3f499cfb..8c0dee47a2f3 100644 --- a/quickstart-postgres.yml +++ b/quickstart-postgres.yml @@ -10,7 +10,7 @@ services: - DSN=postgres://kratos:secret@postgresd:5432/kratos?sslmode=disable&max_conns=20&max_idle_conns=4 postgresd: - image: postgres:9.6 + image: postgres:11.8 ports: - "5432:5432" environment: diff --git a/spec/api.json b/spec/api.json index c0629d544021..f3bd1fa14d28 100644 --- a/spec/api.json +++ b/spec/api.json @@ -3459,12 +3459,20 @@ } }, { - "description": "CredentialsIdentifier is the identifier (username, email) of the credentials to look up.", + "description": "CredentialsIdentifier is the identifier (username, email) of the credentials to look up using exact match.\nOnly one of CredentialsIdentifier and CredentialsIdentifierSimilar can be used.", "in": "query", "name": "credentials_identifier", "schema": { "type": "string" } + }, + { + "description": "This is an EXPERIMENTAL parameter that WILL CHANGE. Do NOT rely on consistent, deterministic behavior.\nTHIS PARAMETER MIGHT BE REMOVED IN AN UPCOMING RELEASE WITHOUT ANY MIGRATION PATH.\n\nCredentialsIdentifierSimilar is the (partial) identifier (username, email) of the credentials to look up using similarity search.\nOnly one of CredentialsIdentifier and CredentialsIdentifierSimilar can be used.", + "in": "query", + "name": "preview_credentials_identifier_similar", + "schema": { + "type": "string" + } } ], "responses": { diff --git a/spec/swagger.json b/spec/swagger.json index 78549f804e7d..c5f8e624ae1d 100755 --- a/spec/swagger.json +++ b/spec/swagger.json @@ -222,9 +222,15 @@ }, { "type": "string", - "description": "CredentialsIdentifier is the identifier (username, email) of the credentials to look up.", + "description": "CredentialsIdentifier is the identifier (username, email) of the credentials to look up using exact match.\nOnly one of CredentialsIdentifier and CredentialsIdentifierSimilar can be used.", "name": "credentials_identifier", "in": "query" + }, + { + "type": "string", + "description": "This is an EXPERIMENTAL parameter that WILL CHANGE. Do NOT rely on consistent, deterministic behavior.\nTHIS PARAMETER MIGHT BE REMOVED IN AN UPCOMING RELEASE WITHOUT ANY MIGRATION PATH.\n\nCredentialsIdentifierSimilar is the (partial) identifier (username, email) of the credentials to look up using similarity search.\nOnly one of CredentialsIdentifier and CredentialsIdentifierSimilar can be used.", + "name": "preview_credentials_identifier_similar", + "in": "query" } ], "responses": {