diff --git a/lib/cloud/aws/policy.go b/lib/cloud/aws/policy.go index 51f44dab7258e..0208dd8ae6bba 100644 --- a/lib/cloud/aws/policy.go +++ b/lib/cloud/aws/policy.go @@ -84,25 +84,71 @@ type Statement struct { // Condition: // StringEquals: // "proxy.example.com:aud": "discover.teleport" - Conditions map[string]map[string]SliceOrString `json:"Condition,omitempty"` + Conditions Conditions `json:"Condition,omitempty"` // StatementID is an optional identifier for the statement. StatementID string `json:"Sid,omitempty"` } +// Conditions is a list of conditions that must be satisfied for an action to be allowed. +type Conditions map[string]StringOrMap + +// Equals returns true if conditions are equal. +func (a Conditions) Equals(b Conditions) bool { + if len(a) != len(b) { + return false + } + for conditionKindA, conditionOpA := range a { + conditionOpB := b[conditionKindA] + if !conditionOpA.Equals(conditionOpB) { + return false + } + } + return true +} + // ensureResource ensures that the statement contains the specified resource. // -// Returns true if the resource was already a part of the statement. +// Returns true if the resource was added to the statement or false if the +// resource was already part of the statement. func (s *Statement) ensureResource(resource string) bool { if slices.Contains(s.Resources, resource) { - return true + return false } s.Resources = append(s.Resources, resource) - return false + return true } -func (s *Statement) ensureResources(resources []string) { +func (s *Statement) ensureResources(resources []string) bool { + var updated bool for _, resource := range resources { - s.ensureResource(resource) + updated = s.ensureResource(resource) || updated + } + return updated +} + +// ensurePrincipal ensures that the statement contains the specified principal. +// +// Returns true if the principal was already a part of the statement. +func (s *Statement) ensurePrincipal(kind string, value string) bool { + if len(s.Principals) == 0 { + s.Principals = make(StringOrMap) + } + values := s.Principals[kind] + if slices.Contains(values, value) { + return false + } + values = append(values, value) + s.Principals[kind] = values + return true +} + +func (s *Statement) ensurePrincipals(principals StringOrMap) bool { + var updated bool + for kind, values := range principals { + for _, v := range values { + updated = s.ensurePrincipal(kind, v) || updated + } } + return updated } // EqualStatement returns whether the receive statement is the same. @@ -115,39 +161,17 @@ func (s *Statement) EqualStatement(other *Statement) bool { return false } - if len(s.Principals) != len(other.Principals) { + if !s.Principals.Equals(other.Principals) { return false } - for principalKind, principalList := range s.Principals { - expectedPrincipalList := other.Principals[principalKind] - if !slices.Equal(principalList, expectedPrincipalList) { - return false - } - } - if !slices.Equal(s.Resources, other.Resources) { return false } - if len(s.Conditions) != len(other.Conditions) { + if !s.Conditions.Equals(other.Conditions) { return false } - for conditionKind, conditionOp := range s.Conditions { - expectedConditionOp := other.Conditions[conditionKind] - - if len(conditionOp) != len(expectedConditionOp) { - return false - } - - for conditionOpKind, conditionOpList := range conditionOp { - expectedConditionOpList := expectedConditionOp[conditionOpKind] - if !slices.Equal(conditionOpList, expectedConditionOpList) { - return false - } - } - } - return true } @@ -174,33 +198,38 @@ func NewPolicyDocument(statements ...*Statement) *PolicyDocument { } } -// Ensure ensures that the policy document contains the specified resource -// action. +// EnsureResourceAction ensures that the policy document contains the specified +// resource action. // -// Returns true if the resource action was already a part of the policy and -// false otherwise. -func (p *PolicyDocument) Ensure(effect, action, resource string) bool { - if existingStatement := p.findStatement(effect, action); existingStatement != nil { +// Returns true if the resource action was added to the policy and false if it +// was already part of the policy. +func (p *PolicyDocument) EnsureResourceAction(effect, action, resource string, conditions Conditions) bool { + if existingStatement := p.findStatement(effect, action, conditions); existingStatement != nil { return existingStatement.ensureResource(resource) } // No statement yet for this resource action, add it. p.Statements = append(p.Statements, &Statement{ - Effect: effect, - Actions: []string{action}, - Resources: []string{resource}, + Effect: effect, + Actions: []string{action}, + Resources: []string{resource}, + Conditions: conditions, }) - return false + return true } // Delete deletes the specified resource action from the policy. -func (p *PolicyDocument) Delete(effect, action, resource string) { +func (p *PolicyDocument) DeleteResourceAction(effect, action, resource string, conditions Conditions) { var statements []*Statement for _, s := range p.Statements { if s.Effect != effect { statements = append(statements, s) continue } + if !s.Conditions.Equals(conditions) { + statements = append(statements, s) + continue + } var resources []string for _, a := range s.Actions { for _, r := range s.Resources { @@ -225,7 +254,8 @@ func (p *PolicyDocument) Delete(effect, action, resource string) { // // The main benefit of using this function (versus appending to p.Statements // directly) is to avoid duplications. -func (p *PolicyDocument) EnsureStatements(statements ...*Statement) { +func (p *PolicyDocument) EnsureStatements(statements ...*Statement) bool { + var updated bool for _, statement := range statements { if statement == nil { continue @@ -234,8 +264,9 @@ func (p *PolicyDocument) EnsureStatements(statements ...*Statement) { // Try to find an existing statement by the action, and add the resources there. var newActions []string for _, action := range statement.Actions { - if existingStatement := p.findStatement(statement.Effect, action); existingStatement != nil { - existingStatement.ensureResources(statement.Resources) + if existingStatement := p.findStatement(statement.Effect, action, statement.Conditions); existingStatement != nil { + updated = existingStatement.ensureResources(statement.Resources) || updated + updated = existingStatement.ensurePrincipals(statement.Principals) || updated } else { newActions = append(newActions, action) } @@ -244,12 +275,21 @@ func (p *PolicyDocument) EnsureStatements(statements ...*Statement) { // Add the leftover actions as a new statement. if len(newActions) > 0 { p.Statements = append(p.Statements, &Statement{ - Effect: statement.Effect, - Actions: newActions, - Resources: statement.Resources, + Effect: statement.Effect, + Actions: newActions, + Resources: statement.Resources, + Conditions: statement.Conditions, + Principals: statement.Principals, }) + updated = true } } + return updated +} + +// IsEmpty returns whether the policy document is empty. +func (p *PolicyDocument) IsEmpty() bool { + return len(p.Statements) == 0 } // Marshal formats the PolicyDocument in a "friendly" format, which can be @@ -264,27 +304,30 @@ func (p *PolicyDocument) Marshal() (string, error) { } // ForEach loops through each action and resource of each statement. -func (p *PolicyDocument) ForEach(fn func(effect, action, resource string)) { +func (p *PolicyDocument) ForEach(fn func(effect, action, resource string, conditions Conditions)) { for _, statement := range p.Statements { for _, action := range statement.Actions { for _, resource := range statement.Resources { - fn(statement.Effect, action, resource) + fn(statement.Effect, action, resource, statement.Conditions) } } } } -func (p *PolicyDocument) findStatement(effect, action string) *Statement { +func (p *PolicyDocument) findStatement(effect, action string, conditions Conditions) *Statement { for _, s := range p.Statements { if s.Effect != effect { continue } - if slices.Contains(s.Actions, action) { - return s + if !slices.Contains(s.Actions, action) { + continue + } + if !s.Conditions.Equals(conditions) { + continue } + return s } return nil - } // SliceOrString defines a type that can be either a single string or a slice. @@ -337,6 +380,21 @@ func (s SliceOrString) MarshalJSON() ([]byte, error) { // https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements_principal.html#principal-anonymous type StringOrMap map[string]SliceOrString +// Equals returns true if this StringOrMap is equal to another StringOrMap. +func (s StringOrMap) Equals(other StringOrMap) bool { + if len(s) != len(other) { + return false + } + + for key, list := range s { + otherList := other[key] + if !slices.Equal(list, otherList) { + return false + } + } + return true +} + // UnmarshalJSON implements json.Unmarshaller. // If it contains a string and not a map, it will create a map with a single entry: // { "str": [] } diff --git a/lib/cloud/aws/policy_statements.go b/lib/cloud/aws/policy_statements.go index a6fb3f3556739..644981e3ccdab 100644 --- a/lib/cloud/aws/policy_statements.go +++ b/lib/cloud/aws/policy_statements.go @@ -161,7 +161,7 @@ func StatementForAWSAppAccess() *Statement { "sts:AssumeRole", }, Resources: allResources, - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "StringEquals": { "iam:ResourceTag/" + requiredTag: SliceOrString{"true"}, }, @@ -198,7 +198,7 @@ func StatementForAWSOIDCRoleTrustRelationship(accountID, providerURL string, aud Principals: map[string]SliceOrString{ "Federated": []string{federatedARN}, }, - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "StringEquals": { federatedAudience: audiences, }, diff --git a/lib/cloud/aws/policy_test.go b/lib/cloud/aws/policy_test.go index b76cc70684de7..d9085a7f586a3 100644 --- a/lib/cloud/aws/policy_test.go +++ b/lib/cloud/aws/policy_test.go @@ -317,7 +317,7 @@ func TestMarshalPolicyDocument(t *testing.T) { Principals: map[string]SliceOrString{ "Federated": {"arn:aws:iam::123456789012:oidc-provider/proxy.example.com"}, }, - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "StringEquals": { "proxy.example.com:aud": SliceOrString{"discover.teleport"}, }, @@ -353,8 +353,8 @@ func TestIAMPolicy(t *testing.T) { policy := NewPolicyDocument() // Add a new action/resource. - alreadyExisted := policy.Ensure(EffectAllow, "action-1", "resource-1") - require.False(t, alreadyExisted) + updated := policy.EnsureResourceAction(EffectAllow, "action-1", "resource-1", nil) + require.True(t, updated) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, Statements: []*Statement{ @@ -367,8 +367,8 @@ func TestIAMPolicy(t *testing.T) { }, policy) // Add the same action/resource. - alreadyExisted = policy.Ensure(EffectAllow, "action-1", "resource-1") - require.True(t, alreadyExisted) + updated = policy.EnsureResourceAction(EffectAllow, "action-1", "resource-1", nil) + require.False(t, updated) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, Statements: []*Statement{ @@ -381,8 +381,8 @@ func TestIAMPolicy(t *testing.T) { }, policy) // Add a new resource to existing action. - alreadyExisted = policy.Ensure(EffectAllow, "action-1", "resource-2") - require.False(t, alreadyExisted) + updated = policy.EnsureResourceAction(EffectAllow, "action-1", "resource-2", nil) + require.True(t, updated) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, Statements: []*Statement{ @@ -395,8 +395,8 @@ func TestIAMPolicy(t *testing.T) { }, policy) // Add another action/resource. - alreadyExisted = policy.Ensure(EffectAllow, "action-2", "resource-3") - require.False(t, alreadyExisted) + updated = policy.EnsureResourceAction(EffectAllow, "action-2", "resource-3", nil) + require.True(t, updated) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, Statements: []*Statement{ @@ -414,7 +414,7 @@ func TestIAMPolicy(t *testing.T) { }, policy) // Delete existing resource action. - policy.Delete(EffectAllow, "action-1", "resource-1") + policy.DeleteResourceAction(EffectAllow, "action-1", "resource-1", nil) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, Statements: []*Statement{ @@ -432,7 +432,7 @@ func TestIAMPolicy(t *testing.T) { }, policy) // Delete last resource from first action, statement should get removed as well. - policy.Delete(EffectAllow, "action-1", "resource-2") + policy.DeleteResourceAction(EffectAllow, "action-1", "resource-2", nil) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, Statements: []*Statement{ @@ -445,7 +445,7 @@ func TestIAMPolicy(t *testing.T) { }, policy) // Delete last resource action, policy should be empty. - policy.Delete(EffectAllow, "action-2", "resource-3") + policy.DeleteResourceAction(EffectAllow, "action-2", "resource-3", nil) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, }, policy) @@ -466,7 +466,7 @@ func TestIAMPolicy(t *testing.T) { }, }, } - policy.Delete(EffectAllow, "action-1", "resource-1") + policy.DeleteResourceAction(EffectAllow, "action-1", "resource-1", nil) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, }, policy) @@ -487,7 +487,7 @@ func TestIAMPolicy(t *testing.T) { }, }, } - policy.Delete(EffectAllow, "action-1", "resource-2") + policy.DeleteResourceAction(EffectAllow, "action-1", "resource-2", nil) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, Statements: []*Statement{ @@ -532,6 +532,19 @@ func TestPolicyEnsureStatements(t *testing.T) { Actions: []string{"action-1"}, Resources: []string{"resource-3"}, }, + // Existing action with different condition and new principals + &Statement{ + Effect: EffectAllow, + Actions: []string{"action-1"}, + Principals: StringOrMap{ + "Federated": []string{"arn:aws:iam::123456789012:oidc-provider/example.com"}, + }, + Conditions: Conditions{ + "StringEquals": StringOrMap{ + "example.com:aud": []string{"discover.teleport"}, + }, + }, + }, // New actions and new resources. &Statement{ Effect: EffectAllow, @@ -565,6 +578,18 @@ func TestPolicyEnsureStatements(t *testing.T) { Actions: []string{"action-2"}, Resources: []string{"resource-1", "resource-4"}, }, + { + Effect: EffectAllow, + Actions: []string{"action-1"}, + Principals: StringOrMap{ + "Federated": []string{"arn:aws:iam::123456789012:oidc-provider/example.com"}, + }, + Conditions: Conditions{ + "StringEquals": StringOrMap{ + "example.com:aud": []string{"discover.teleport"}, + }, + }, + }, { Effect: EffectAllow, Actions: []string{"action-3", "action-4"}, @@ -931,13 +956,13 @@ func TestEqualStatement(t *testing.T) { { name: "different number of conditions", statementA: &Statement{ - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3600"}}, "StringLike": {"s3:prefix": []string{"janedoe/*"}}, }, }, statementB: &Statement{ - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3601"}}, }, }, @@ -946,12 +971,12 @@ func TestEqualStatement(t *testing.T) { { name: "different conditions", statementA: &Statement{ - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3600"}}, }, }, statementB: &Statement{ - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3601"}}, }, }, @@ -960,12 +985,12 @@ func TestEqualStatement(t *testing.T) { { name: "different condition values", statementA: &Statement{ - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3600", "3601"}}, }, }, statementB: &Statement{ - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3600"}}, }, }, @@ -990,7 +1015,7 @@ func TestEqualStatement(t *testing.T) { }, Actions: []string{"s3:GetObject"}, Resources: []string{"arn:aws:s3:::my-bucket/my-prefix/*"}, - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "StringLike": {"s3:prefix": []string{"my-prefix/*"}}, }, }, @@ -1001,7 +1026,7 @@ func TestEqualStatement(t *testing.T) { }, Actions: []string{"s3:GetObject"}, Resources: []string{"arn:aws:s3:::my-bucket/my-prefix/*"}, - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "StringLike": {"s3:prefix": []string{"my-prefix/*"}}, }, }, diff --git a/lib/cloud/provisioning/awsactions/common.go b/lib/cloud/provisioning/awsactions/common.go new file mode 100644 index 0000000000000..bd4a23f97c64b --- /dev/null +++ b/lib/cloud/provisioning/awsactions/common.go @@ -0,0 +1,32 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package awsactions + +import ( + "encoding/json" + + "github.com/gravitational/trace" +) + +func formatDetails(in any) (string, error) { + const prefix = "" + const indent = " " + out, err := json.MarshalIndent(in, prefix, indent) + return string(out), trace.Wrap(err) +} diff --git a/lib/cloud/provisioning/awsactions/create_oidc_idp.go b/lib/cloud/provisioning/awsactions/create_oidc_idp.go new file mode 100644 index 0000000000000..d90425d80f06c --- /dev/null +++ b/lib/cloud/provisioning/awsactions/create_oidc_idp.go @@ -0,0 +1,81 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package awsactions + +import ( + "context" + "log/slog" + + "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/gravitational/trace" + + awslib "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/provisioning" + "github.com/gravitational/teleport/lib/integrations/awsoidc/tags" +) + +// OpenIDConnectProviderCreator can create an OpenID Connect Identity Provider +// (OIDC IdP) in AWS IAM. +type OpenIDConnectProviderCreator interface { + // CreateOpenIDConnectProvider creates an AWS IAM OIDC IdP. + CreateOpenIDConnectProvider(ctx context.Context, params *iam.CreateOpenIDConnectProviderInput, optFns ...func(*iam.Options)) (*iam.CreateOpenIDConnectProviderOutput, error) +} + +// CreateOIDCProvider wraps a [OpenIDConnectPRoviderCreator] in a +// [provisioning.Action] that creates an OIDC IdP in AWS IAM when invoked. +func CreateOIDCProvider( + clt OpenIDConnectProviderCreator, + thumbprints []string, + issuerURL string, + clientIDs []string, + tags tags.AWSTags, +) (*provisioning.Action, error) { + input := &iam.CreateOpenIDConnectProviderInput{ + ThumbprintList: thumbprints, + Url: &issuerURL, + ClientIDList: clientIDs, + Tags: tags.ToIAMTags(), + } + details, err := formatDetails(input) + if err != nil { + return nil, trace.Wrap(err) + } + + config := provisioning.ActionConfig{ + Name: "CreateOpenIDConnectProvider", + Summary: "Create an OpenID Connect identity provider in AWS IAM for your Teleport cluster", + Details: details, + RunnerFn: func(ctx context.Context) error { + slog.InfoContext(ctx, "Creating OpenID Connect identity provider") + _, err = clt.CreateOpenIDConnectProvider(ctx, input) + if err != nil { + awsErr := awslib.ConvertIAMv2Error(err) + if trace.IsAlreadyExists(awsErr) { + slog.InfoContext(ctx, "OpenID Connect identity provider already exists") + return nil + } + + return trace.Wrap(err) + } + return nil + }, + } + action, err := provisioning.NewAction(config) + return action, trace.Wrap(err) +} diff --git a/lib/cloud/provisioning/awsactions/create_role.go b/lib/cloud/provisioning/awsactions/create_role.go new file mode 100644 index 0000000000000..c1ed634706fc4 --- /dev/null +++ b/lib/cloud/provisioning/awsactions/create_role.go @@ -0,0 +1,215 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package awsactions + +import ( + "context" + "fmt" + "log/slog" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/iam" + iamtypes "github.com/aws/aws-sdk-go-v2/service/iam/types" + "github.com/gravitational/trace" + + awslib "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/provisioning" + "github.com/gravitational/teleport/lib/integrations/awsoidc/tags" +) + +// RoleCreator can create an IAM role. +type RoleCreator interface { + // CreateRole creates a new IAM Role. + CreateRole(ctx context.Context, params *iam.CreateRoleInput, optFns ...func(*iam.Options)) (*iam.CreateRoleOutput, error) +} + +// RoleGetter can get an IAM role. +type RoleGetter interface { + // GetRole retrieves information about the specified role, including the + // role's path, GUID, ARN, and the role's trust policy that grants + // permission to assume the role. + GetRole(ctx context.Context, params *iam.GetRoleInput, optFns ...func(*iam.Options)) (*iam.GetRoleOutput, error) +} + +// AssumeRolePolicyUpdater can update an IAM role's trust policy. +type AssumeRolePolicyUpdater interface { + // UpdateAssumeRolePolicy updates the policy that grants an IAM entity + // permission to assume a role. + // This is typically referred to as the "role trust policy". + UpdateAssumeRolePolicy(ctx context.Context, params *iam.UpdateAssumeRolePolicyInput, optFns ...func(*iam.Options)) (*iam.UpdateAssumeRolePolicyOutput, error) +} + +// RoleTagger can tag an AWS IAM role. +type RoleTagger interface { + // TagRole adds one or more tags to an IAM role. The role can be a regular + // role or a service-linked role. If a tag with the same key name already + // exists, then that tag is overwritten with the new value. + TagRole(ctx context.Context, params *iam.TagRoleInput, optFns ...func(*iam.Options)) (*iam.TagRoleOutput, error) +} + +// CreateRole returns a [provisioning.Action] that creates or updates an IAM +// role when invoked. +func CreateRole( + clt interface { + AssumeRolePolicyUpdater + RoleCreator + RoleGetter + RoleTagger + }, + roleName string, + description string, + trustPolicy *awslib.PolicyDocument, + tags tags.AWSTags, +) (*provisioning.Action, error) { + trustPolicyJSON, err := trustPolicy.Marshal() + if err != nil { + return nil, trace.Wrap(err) + } + input := &iam.CreateRoleInput{ + RoleName: &roleName, + Description: &description, + AssumeRolePolicyDocument: &trustPolicyJSON, + Tags: tags.ToIAMTags(), + } + type createRoleInput struct { + // AssumeRolePolicyDocument shadows the input's field of the same name + // to marshal the trust policy doc as unescpaed JSON. + AssumeRolePolicyDocument *awslib.PolicyDocument + *iam.CreateRoleInput + } + details, err := formatDetails(createRoleInput{ + AssumeRolePolicyDocument: trustPolicy, + CreateRoleInput: input, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + config := provisioning.ActionConfig{ + Name: "CreateRole", + Summary: fmt.Sprintf("Create IAM role %q with a custom trust policy", roleName), + Details: details, + RunnerFn: func(ctx context.Context) error { + slog.InfoContext(ctx, "Checking for existing IAM role", + "role", roleName, + ) + getRoleOut, err := clt.GetRole(ctx, &iam.GetRoleInput{ + RoleName: &roleName, + }) + if err != nil { + convertedErr := awslib.ConvertIAMv2Error(err) + if !trace.IsNotFound(convertedErr) { + return trace.Wrap(convertedErr) + } + slog.InfoContext(ctx, "Creating IAM role", "role", roleName) + _, err = clt.CreateRole(ctx, input) + if err != nil { + return trace.Wrap(awslib.ConvertIAMv2Error(err)) + } + return nil + } + + slog.InfoContext(ctx, "IAM role already exists", + "role", roleName, + ) + existingTrustPolicy, err := awslib.ParsePolicyDocument(aws.ToString(getRoleOut.Role.AssumeRolePolicyDocument)) + if err != nil { + return trace.Wrap(err) + } + err = ensureTrustPolicy(ctx, clt, roleName, trustPolicy, existingTrustPolicy) + if err != nil { + return trace.Wrap(err) + } + + err = ensureTags(ctx, clt, roleName, tags, getRoleOut.Role.Tags) + if err != nil { + // Tagging an existing role after we update it is a + // nice-to-have, but not a need-to-have. + slog.WarnContext(ctx, "Failed to update IAM role tags", + "role", roleName, + "error", err, + "tags", tags.ToMap(), + ) + } + return nil + }, + } + action, err := provisioning.NewAction(config) + return action, trace.Wrap(err) +} + +func ensureTrustPolicy( + ctx context.Context, + clt AssumeRolePolicyUpdater, + roleName string, + trustPolicy *awslib.PolicyDocument, + existingTrustPolicy *awslib.PolicyDocument, +) error { + slog.InfoContext(ctx, "Checking IAM role trust policy", + "role", roleName, + ) + + if !existingTrustPolicy.EnsureStatements(trustPolicy.Statements...) { + slog.InfoContext(ctx, "IAM role trust policy does not require update", + "role", roleName, + ) + return nil + } + + slog.InfoContext(ctx, "Updating IAM role trust policy", + "role", roleName, + ) + trustPolicyJSON, err := existingTrustPolicy.Marshal() + if err != nil { + return trace.Wrap(err) + } + + _, err = clt.UpdateAssumeRolePolicy(ctx, &iam.UpdateAssumeRolePolicyInput{ + RoleName: &roleName, + PolicyDocument: &trustPolicyJSON, + }) + return trace.Wrap(err) +} + +func ensureTags( + ctx context.Context, + clt RoleTagger, + roleName string, + tags tags.AWSTags, + existingTags []iamtypes.Tag, +) error { + slog.InfoContext(ctx, "Checking for tags on IAM role", + "role", roleName, + ) + if tags.MatchesIAMTags(existingTags) { + slog.InfoContext(ctx, "IAM role is already tagged", + "role", roleName, + ) + return nil + } + + slog.InfoContext(ctx, "Updating IAM role tags", + "role", roleName, + ) + _, err := clt.TagRole(ctx, &iam.TagRoleInput{ + RoleName: &roleName, + Tags: tags.ToIAMTags(), + }) + return trace.Wrap(err) +} diff --git a/lib/cloud/provisioning/operations.go b/lib/cloud/provisioning/operations.go new file mode 100644 index 0000000000000..8b16a445ccace --- /dev/null +++ b/lib/cloud/provisioning/operations.go @@ -0,0 +1,219 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package provisioning + +import ( + "context" + "fmt" + "io" + "log/slog" + "os" + "regexp" + "strings" + "text/template" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/utils/prompt" +) + +// validName is used to ensure that [OperationConfig] and [ActionConfig] names +// start with a letter and only consist of letters, numbers, and hyphen +// characters thereafter. +var validName = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9-]+$`) + +// OperationConfig is the configuration for an operation. +type OperationConfig struct { + // Name is the operation name. Must consist of only letters and hyphens. + Name string + // Actions is the list of actions that make up the operation. + Actions []Action + // AutoConfirm is whether to skip the operation plan confirmation prompt. + AutoConfirm bool + // Output is an [io.Writer] where the operation plan and confirmation prompt + // are written to. + // Defaults to [os.Stdout]. + Output io.Writer +} + +// checkAndSetDefaults validates the operation config and sets defaults. +func (c *OperationConfig) checkAndSetDefaults() error { + c.Name = strings.TrimSpace(c.Name) + if c.Name == "" { + return trace.BadParameter("missing operation name") + } + if !validName.MatchString(c.Name) { + return trace.BadParameter( + "operation name %q does not match regex used for validation %q", + c.Name, validName.String(), + ) + } + if len(c.Actions) == 0 { + return trace.BadParameter("missing operation actions") + } + if c.Output == nil { + c.Output = os.Stdout + } + return nil +} + +// Action wraps a runnable function to provide a name, summary, and detailed +// explanation of the behavior. +type Action struct { + // config is an unexported value-type to prevent mutation after it's been + // validated by the checkAndSetDefaults func. + config ActionConfig +} + +// GetName returns the action's configured name. +func (a *Action) GetName() string { + return a.config.Name +} + +// GetSummary returns the action's configured summary. +func (a *Action) GetSummary() string { + return a.config.Summary +} + +// GetDetails returns the action's configured details. +func (a *Action) GetDetails() string { + return a.config.Details +} + +// Run runs the action. +func (a *Action) Run(ctx context.Context) error { + return a.config.RunnerFn(ctx) +} + +// NewAction creates a new [Action]. +func NewAction(config ActionConfig) (*Action, error) { + if err := config.checkAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + return &Action{ + config: config, + }, nil +} + +// ActionConfig is the configuration for an [Action]. +type ActionConfig struct { + // Name is the action name. + Name string + // Summary is the action summary in prose. + Summary string + // Details is the detailed explanation of the action, explaining exactly + // what it will do. + Details string + // RunnerFn is a function that actually runs the action. + RunnerFn func(context.Context) error +} + +// checkAndSetDefaults validates the action config and sets defaults. +func (c *ActionConfig) checkAndSetDefaults() error { + c.Name = strings.TrimSpace(c.Name) + c.Summary = strings.TrimSpace(c.Summary) + c.Details = strings.TrimSpace(c.Details) + if c.Name == "" { + return trace.BadParameter("missing action name") + } + if !validName.MatchString(c.Name) { + return trace.BadParameter( + "action name %q does not match regex used for validation %q", + c.Name, validName.String(), + ) + } + if c.Summary == "" { + return trace.BadParameter("missing action summary") + } + if c.Details == "" { + return trace.BadParameter("missing action details") + } + if c.RunnerFn == nil { + return trace.BadParameter("missing action runner") + } + + return nil +} + +// Run writes the operation plan, optionally prompts for user confirmation, +// then executes the operation plan. +func Run(ctx context.Context, config OperationConfig) error { + if err := config.checkAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + + if err := writeOperationPlan(config); err != nil { + return trace.Wrap(err) + } + + if !config.AutoConfirm { + question := fmt.Sprintf("Do you want %q to perform these actions?", config.Name) + ok, err := prompt.Confirmation(ctx, config.Output, prompt.Stdin(), question) + if err != nil { + return trace.Wrap(err) + } + if !ok { + return trace.BadParameter("operation %q canceled", config.Name) + } + } + + enumerateSteps := len(config.Actions) > 1 + for i, action := range config.Actions { + if enumerateSteps { + slog.InfoContext(ctx, "Running", "step", i+1, "action", action.config.Name) + } else { + slog.InfoContext(ctx, "Running", "action", action.config.Name) + } + + if err := action.Run(ctx); err != nil { + if enumerateSteps { + return trace.Wrap(err, "step %d %q failed", i+1, action.config.Name) + } + return trace.Wrap(err, "%q failed", action.config.Name) + } + } + slog.InfoContext(ctx, "Success!", "operation", config.Name) + return nil +} + +// writeOperationPlan writes the operational plan to the given [io.Writer] as +// a structured summary of the operation and the actions that compose it. +func writeOperationPlan(config OperationConfig) error { + data := map[string]any{ + "config": config, + "showStepNumbers": len(config.Actions) > 1, + } + return trace.Wrap(operationPlanTemplate.Execute(config.Output, data)) +} + +var operationPlanTemplate = template.Must(template.New("plan"). + Funcs(template.FuncMap{ + // used to enumerate the action steps starting from 1 instead of 0. + "addOne": func(x int) int { return x + 1 }, + }). + Parse(` +{{- printf "%q" .config.Name }} will perform the following actions: + +{{ $global := . }} +{{- range $index, $action := .config.Actions }} +{{- if $global.showStepNumbers }}{{ $index | addOne }}. {{ end -}}{{$action.GetSummary}}. +{{$action.GetName}}: {{$action.GetDetails}} + +{{end -}} +`)) diff --git a/lib/cloud/provisioning/operations_test.go b/lib/cloud/provisioning/operations_test.go new file mode 100644 index 0000000000000..56f8af546c7b8 --- /dev/null +++ b/lib/cloud/provisioning/operations_test.go @@ -0,0 +1,372 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package provisioning + +import ( + "bytes" + "context" + "os" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" +) + +func withErrCheck(fn func(*testing.T, error)) require.ErrorAssertionFunc { + return func(t require.TestingT, err error, msgAndArgs ...interface{}) { + tt := t.(*testing.T) + tt.Helper() + fn(tt, err) + } +} + +var badParameterCheck = withErrCheck(func(t *testing.T, err error) { + t.Helper() + require.True(t, trace.IsBadParameter(err), `expected "bad parameter", but got %v`, err) +}) + +func TestOperationCheckAndSetDefaults(t *testing.T) { + stubAction := Action{} + baseCfg := OperationConfig{ + Name: "name", + Actions: []Action{stubAction}, + AutoConfirm: true, + Output: os.Stdout, + } + tests := []struct { + desc string + getConfig func() OperationConfig + want OperationConfig + wantErrContains string + }{ + { + desc: "valid config", + getConfig: func() OperationConfig { return baseCfg }, + want: baseCfg, + }, + { + desc: "defaults output to stdout", + getConfig: func() OperationConfig { + cfg := baseCfg + cfg.Output = nil + return cfg + }, + want: baseCfg, + }, + { + desc: "trims trailing and leading spaces", + getConfig: func() OperationConfig { + cfg := baseCfg + cfg.Name = "\t\n " + cfg.Name + "\t\n " + return cfg + }, + want: baseCfg, + }, + { + desc: "missing name is an error", + getConfig: func() OperationConfig { + cfg := baseCfg + cfg.Name = "" + return cfg + }, + wantErrContains: "missing operation name", + }, + { + desc: "spaces in name is an error", + getConfig: func() OperationConfig { + cfg := baseCfg + cfg.Name = "some thing" + return cfg + }, + wantErrContains: "does not match regex", + }, + { + desc: "missing actions is an error", + getConfig: func() OperationConfig { + cfg := baseCfg + cfg.Actions = nil + return cfg + }, + wantErrContains: "missing operation actions", + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + got := test.getConfig() + err := got.checkAndSetDefaults() + if test.wantErrContains != "" { + require.Error(t, err) + require.ErrorContains(t, err, test.wantErrContains) + return + } + require.Empty(t, cmp.Diff(test.want, got, + cmp.AllowUnexported(Action{}), + cmpopts.IgnoreFields(OperationConfig{}, "Output"), + )) + require.NotNil(t, got.Output) + }) + } +} + +func TestNewAction(t *testing.T) { + noOpRunnerFn := func(context.Context) error { return nil } + baseCfg := ActionConfig{ + Name: "name", + Summary: "summary", + Details: "details", + RunnerFn: noOpRunnerFn, + } + tests := []struct { + desc string + getConfig func() ActionConfig + errCheck require.ErrorAssertionFunc + want Action + }{ + { + desc: "valid config", + getConfig: func() ActionConfig { return baseCfg }, + errCheck: require.NoError, + want: Action{config: baseCfg}, + }, + { + desc: "trims trailing and leading spaces", + getConfig: func() ActionConfig { + cfg := baseCfg + cfg.Name = "\t\n " + cfg.Name + "\t\n " + return cfg + }, + errCheck: require.NoError, + want: Action{config: baseCfg}, + }, + { + desc: "missing name is an error", + getConfig: func() ActionConfig { + cfg := baseCfg + cfg.Name = "" + return cfg + }, + errCheck: badParameterCheck, + }, + { + desc: "spaces in name is an error", + getConfig: func() ActionConfig { + cfg := baseCfg + cfg.Name = "some thing" + return cfg + }, + errCheck: badParameterCheck, + }, + { + desc: "missing summary is an error", + getConfig: func() ActionConfig { + cfg := baseCfg + cfg.Summary = "" + return cfg + }, + errCheck: badParameterCheck, + }, + { + desc: "missing details is an error", + getConfig: func() ActionConfig { + cfg := baseCfg + cfg.Details = "" + return cfg + }, + errCheck: badParameterCheck, + }, + { + desc: "missing runner is an error", + getConfig: func() ActionConfig { + cfg := baseCfg + cfg.RunnerFn = nil + return cfg + }, + errCheck: badParameterCheck, + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + _, err := NewAction(test.getConfig()) + test.errCheck(t, err) + }) + } +} + +func TestRun(t *testing.T) { + actionA := Action{ + config: ActionConfig{ + Name: "actionA", + Summary: "actionA summary", + Details: "actionA details", + RunnerFn: func(context.Context) error { return nil }, + }, + } + actionB := Action{ + config: ActionConfig{ + Name: "actionB", + Summary: "actionB summary", + Details: "actionB details", + RunnerFn: func(context.Context) error { return nil }, + }, + } + failingAction := Action{ + config: ActionConfig{ + Name: "actionC", + Summary: "actionC summary", + Details: "actionC details", + RunnerFn: func(context.Context) error { return trace.AccessDenied("access denied") }, + }, + } + + tests := []struct { + desc string + config OperationConfig + errCheck require.ErrorAssertionFunc + }{ + { + desc: "success", + config: OperationConfig{ + Name: "op-name", + Actions: []Action{ + actionA, + actionB, + }, + AutoConfirm: true, + }, + errCheck: require.NoError, + }, + { + desc: "failed action does not enumerate the failed step", + config: OperationConfig{ + Name: "op-name", + Actions: []Action{ + failingAction, + }, + AutoConfirm: true, + }, + errCheck: withErrCheck(func(t *testing.T, err error) { + t.Helper() + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + // the error should indicate the failed action + require.Contains(t, err.Error(), `"actionC" failed`) + // but it should not number the step, since there's only + // one action. + require.NotContains(t, err.Error(), `1`) + }), + }, + { + desc: "failed action fails the entire operation", + config: OperationConfig{ + Name: "op-name", + Actions: []Action{ + actionA, + failingAction, + actionB, + }, + AutoConfirm: true, + }, + errCheck: withErrCheck(func(t *testing.T, err error) { + t.Helper() + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + // error should indicate step number and action name that failed + require.Contains(t, err.Error(), `step 2 "actionC" failed`) + }), + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + err := Run(ctx, test.config) + test.errCheck(t, err) + }) + } +} + +func TestWriteOperationPlan(t *testing.T) { + actionA := Action{ + config: ActionConfig{ + Name: "actionA", + Summary: "", + Details: "", + RunnerFn: func(context.Context) error { return trace.NotImplemented("not implemented") }, + }, + } + actionB := Action{ + config: ActionConfig{ + Name: "actionB", + Summary: "", + Details: "", + RunnerFn: func(context.Context) error { return trace.NotImplemented("not implemented") }, + }, + } + tests := []struct { + desc string + config OperationConfig + want string + }{ + { + desc: "operation with only one action does not enumerate the step", + config: OperationConfig{ + Name: "op-name", + Actions: []Action{actionA}, + AutoConfirm: true, + }, + want: strings.TrimLeft(` +"op-name" will perform the following actions: + +. +actionA: + +`, "\n"), + }, + { + desc: "operation with multiple actions enumerates the steps", + config: OperationConfig{ + Name: "op-name", + Actions: []Action{actionA, actionB}, + AutoConfirm: true, + }, + want: strings.TrimLeft(` +"op-name" will perform the following actions: + +1. . +actionA: + +2. . +actionB: + +`, "\n"), + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + var got bytes.Buffer + test.config.Output = &got + require.NoError(t, writeOperationPlan(test.config)) + require.Empty(t, cmp.Diff(test.want, got.String())) + }) + } +} diff --git a/lib/config/configuration.go b/lib/config/configuration.go index ac3e56ff59e8b..c1cc43fc2783f 100644 --- a/lib/config/configuration.go +++ b/lib/config/configuration.go @@ -380,6 +380,8 @@ type IntegrationConfAWSOIDCIdP struct { // ProxyPublicURL is the IdP Issuer URL (Teleport Proxy Public Address). // Eg, https://.teleport.sh ProxyPublicURL string + // AutoConfirm skips user confirmation of the operation plan if true. + AutoConfirm bool } // IntegrationConfListDatabasesIAM contains the arguments of diff --git a/lib/configurators/aws/aws.go b/lib/configurators/aws/aws.go index 4e31b202e4d88..cece7047dfc13 100644 --- a/lib/configurators/aws/aws.go +++ b/lib/configurators/aws/aws.go @@ -523,7 +523,7 @@ func buildCommonActions(config ConfiguratorConfig, targetCfg targetConfig) ([]co // If the policy has no statements means that the agent doesn't require // any IAM permission. In this case, return without errors and with empty // actions. - if len(policy.Document.Statements) == 0 { + if policy.Document.IsEmpty() { return []configurators.ConfiguratorAction{}, nil } diff --git a/lib/integrations/awsoidc/idp_iam_config.go b/lib/integrations/awsoidc/idp_iam_config.go index 24d14ffdcb2ee..6fff791c34b5f 100644 --- a/lib/integrations/awsoidc/idp_iam_config.go +++ b/lib/integrations/awsoidc/idp_iam_config.go @@ -20,7 +20,7 @@ package awsoidc import ( "context" - "log/slog" + "io" "net/http" "net/url" @@ -32,6 +32,8 @@ import ( "github.com/gravitational/teleport/api/types" awslib "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/provisioning" + "github.com/gravitational/teleport/lib/cloud/provisioning/awsactions" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/integrations/awsoidc/tags" ) @@ -59,6 +61,9 @@ type IdPIAMConfigureRequest struct { // Eg, https://.teleport.sh, https://proxy.example.org:443, https://teleport.ec2.aws:3080 ProxyPublicAddress string + // AutoConfirm skips user confirmation of the operation plan if true. + AutoConfirm bool + // issuer is the above value but only contains the host. // Eg, .teleport.sh, proxy.example.org issuer string @@ -70,6 +75,12 @@ type IdPIAMConfigureRequest struct { IntegrationRole string ownershipTags tags.AWSTags + + // stdout is used to override stdout output in tests. + stdout io.Writer + // fakeThumbprint is used to override thumbprint in output tests, to produce + // consistent output. + fakeThumbprint string } // CheckAndSetDefaults ensures the required fields are present. @@ -109,21 +120,11 @@ func (r *IdPIAMConfigureRequest) CheckAndSetDefaults() error { // There is no guarantee that the client is thread safe. type IdPIAMConfigureClient interface { CallerIdentityGetter - - // CreateOpenIDConnectProvider creates an IAM OIDC IdP. - CreateOpenIDConnectProvider(ctx context.Context, params *iam.CreateOpenIDConnectProviderInput, optFns ...func(*iam.Options)) (*iam.CreateOpenIDConnectProviderOutput, error) - - // CreateRole creates a new IAM Role. - CreateRole(ctx context.Context, params *iam.CreateRoleInput, optFns ...func(*iam.Options)) (*iam.CreateRoleOutput, error) - - // GetRole retrieves information about the specified role, including the role's path, - // GUID, ARN, and the role's trust policy that grants permission to assume the - // role. - GetRole(ctx context.Context, params *iam.GetRoleInput, optFns ...func(*iam.Options)) (*iam.GetRoleOutput, error) - - // UpdateAssumeRolePolicy updates the policy that grants an IAM entity permission to assume a role. - // This is typically referred to as the "role trust policy". - UpdateAssumeRolePolicy(ctx context.Context, params *iam.UpdateAssumeRolePolicyInput, optFns ...func(*iam.Options)) (*iam.UpdateAssumeRolePolicyOutput, error) + awsactions.AssumeRolePolicyUpdater + awsactions.OpenIDConnectProviderCreator + awsactions.RoleCreator + awsactions.RoleGetter + awsactions.RoleTagger } type defaultIdPIAMConfigureClient struct { @@ -183,99 +184,53 @@ func ConfigureIdPIAM(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMC req.AccountID = aws.ToString(callerIdentity.Account) } - slog.InfoContext(ctx, "Creating IAM OpenID Connect Provider", "url", req.issuerURL) - if err := ensureOIDCIdPIAM(ctx, clt, req); err != nil { - return trace.Wrap(err) - } - - slog.InfoContext(ctx, "Creating IAM Role", "role", req.IntegrationRole) - if err := upsertIdPIAMRole(ctx, clt, req); err != nil { - return trace.Wrap(err) - } - - return nil -} - -func ensureOIDCIdPIAM(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMConfigureRequest) error { - thumbprint, err := ThumbprintIdP(ctx, req.ProxyPublicAddress) - if err != nil { - return trace.Wrap(err) - } - - _, err = clt.CreateOpenIDConnectProvider(ctx, &iam.CreateOpenIDConnectProviderInput{ - ThumbprintList: []string{thumbprint}, - Url: &req.issuerURL, - ClientIDList: []string{types.IntegrationAWSOIDCAudience}, - Tags: req.ownershipTags.ToIAMTags(), - }) + createOIDCIdP, err := createOIDCIdPAction(ctx, clt, req) if err != nil { - awsErr := awslib.ConvertIAMv2Error(err) - if trace.IsAlreadyExists(awsErr) { - return nil - } - return trace.Wrap(err) } - return nil -} - -func createIdPIAMRole(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMConfigureRequest) error { - integrationRoleAssumeRoleDocument, err := awslib.NewPolicyDocument( - awslib.StatementForAWSOIDCRoleTrustRelationship(req.AccountID, req.issuer, []string{types.IntegrationAWSOIDCAudience}), - ).Marshal() + createIdPIAMRole, err := createIdPIAMRoleAction(clt, req) if err != nil { return trace.Wrap(err) } - _, err = clt.CreateRole(ctx, &iam.CreateRoleInput{ - RoleName: &req.IntegrationRole, - Description: aws.String(descriptionOIDCIdPRole), - AssumeRolePolicyDocument: &integrationRoleAssumeRoleDocument, - Tags: req.ownershipTags.ToIAMTags(), - }) - return trace.Wrap(err) + return trace.Wrap(provisioning.Run(ctx, provisioning.OperationConfig{ + Name: "awsoidc-idp", + Actions: []provisioning.Action{ + *createOIDCIdP, + *createIdPIAMRole, + }, + AutoConfirm: req.AutoConfirm, + Output: req.stdout, + })) } -func upsertIdPIAMRole(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMConfigureRequest) error { - getRoleOut, err := clt.GetRole(ctx, &iam.GetRoleInput{ - RoleName: &req.IntegrationRole, - }) - if err != nil { - convertedErr := awslib.ConvertIAMv2Error(err) - if !trace.IsNotFound(convertedErr) { - return trace.Wrap(convertedErr) - } - - return trace.Wrap(createIdPIAMRole(ctx, clt, req)) - } - - if !req.ownershipTags.MatchesIAMTags(getRoleOut.Role.Tags) { - return trace.BadParameter("IAM Role %q already exists but is not managed by Teleport. "+ - "Add the following tags to allow Teleport to manage this Role: %s", req.IntegrationRole, req.ownershipTags) - } - - trustRelationshipDoc, err := awslib.ParsePolicyDocument(aws.ToString(getRoleOut.Role.AssumeRolePolicyDocument)) - if err != nil { - return trace.Wrap(err) - } - - trustRelationshipForIdP := awslib.StatementForAWSOIDCRoleTrustRelationship(req.AccountID, req.issuer, []string{types.IntegrationAWSOIDCAudience}) - for _, existingStatement := range trustRelationshipDoc.Statements { - if existingStatement.EqualStatement(trustRelationshipForIdP) { - return nil +func createOIDCIdPAction(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMConfigureRequest) (*provisioning.Action, error) { + var thumbprint string + if req.fakeThumbprint != "" { + // only happens in tests. + thumbprint = req.fakeThumbprint + } else { + var err error + thumbprint, err = ThumbprintIdP(ctx, req.ProxyPublicAddress) + if err != nil { + return nil, trace.Wrap(err) } } - trustRelationshipDoc.Statements = append(trustRelationshipDoc.Statements, trustRelationshipForIdP) - trustRelationshipDocString, err := trustRelationshipDoc.Marshal() - if err != nil { - return trace.Wrap(err) - } + clientIDs := []string{types.IntegrationAWSOIDCAudience} + thumbprints := []string{thumbprint} + return awsactions.CreateOIDCProvider(clt, thumbprints, req.issuerURL, clientIDs, req.ownershipTags) +} - _, err = clt.UpdateAssumeRolePolicy(ctx, &iam.UpdateAssumeRolePolicyInput{ - RoleName: &req.IntegrationRole, - PolicyDocument: &trustRelationshipDocString, - }) - return trace.Wrap(err) +func createIdPIAMRoleAction(clt IdPIAMConfigureClient, req IdPIAMConfigureRequest) (*provisioning.Action, error) { + integrationRoleAssumeRoleDocument := awslib.NewPolicyDocument( + awslib.StatementForAWSOIDCRoleTrustRelationship(req.AccountID, req.issuer, []string{types.IntegrationAWSOIDCAudience}), + ) + return awsactions.CreateRole(clt, + req.IntegrationRole, + descriptionOIDCIdPRole, + integrationRoleAssumeRoleDocument, + req.ownershipTags, + ) } diff --git a/lib/integrations/awsoidc/idp_iam_config_test.go b/lib/integrations/awsoidc/idp_iam_config_test.go index 03bfcca1373c0..a7430d1bee953 100644 --- a/lib/integrations/awsoidc/idp_iam_config_test.go +++ b/lib/integrations/awsoidc/idp_iam_config_test.go @@ -19,6 +19,7 @@ package awsoidc import ( + "bytes" "context" "fmt" "net/http/httptest" @@ -35,6 +36,7 @@ import ( "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/integrations/awsoidc/tags" + "github.com/gravitational/teleport/lib/utils/golden" ) func TestIdPIAMConfigReqDefaults(t *testing.T) { @@ -44,6 +46,7 @@ func TestIdPIAMConfigReqDefaults(t *testing.T) { IntegrationName: "myintegration", IntegrationRole: "integrationrole", ProxyPublicAddress: "https://proxy.example.com", + AutoConfirm: true, } } @@ -69,6 +72,7 @@ func TestIdPIAMConfigReqDefaults(t *testing.T) { "teleport.dev/integration": "myintegration", "teleport.dev/origin": "integration_awsoidc", }, + AutoConfirm: true, }, }, { @@ -166,6 +170,7 @@ func TestConfigureIdPIAM(t *testing.T) { IntegrationName: "myintegration", IntegrationRole: "integrationrole", ProxyPublicAddress: tlsServer.URL, + AutoConfirm: true, } } @@ -195,15 +200,7 @@ func TestConfigureIdPIAM(t *testing.T) { errCheck: require.NoError, }, { - name: "role exists, no ownership tags", - mockAccountID: "123456789012", - mockExistingIdPUrl: []string{}, - mockExistingRoles: map[string]mockRole{"integrationrole": {}}, - req: baseIdPIAMConfigReqWithTLServer, - errCheck: badParameterCheck, - }, - { - name: "role exists, ownership tags, no assume role", + name: "role exists with empty trust policy", mockAccountID: "123456789012", mockExistingIdPUrl: []string{}, mockExistingRoles: map[string]mockRole{"integrationrole": { @@ -225,14 +222,12 @@ func TestConfigureIdPIAM(t *testing.T) { }, }, { - name: "role exists, ownership tags, with existing assume role", + name: "role exists with existing trust policy and without matching tags", mockAccountID: "123456789012", mockExistingIdPUrl: []string{}, mockExistingRoles: map[string]mockRole{"integrationrole": { tags: []iamtypes.Tag{ - {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, - {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, - {Key: aws.String("teleport.dev/integration"), Value: aws.String("myintegration")}, + {Key: aws.String("teleport.dev/origin"), Value: aws.String("should be overwritten")}, }, assumeRolePolicyDoc: policyDocWithStatementsJSON( assumeRoleStatementJSON("some-other-issuer"), @@ -247,10 +242,20 @@ func TestConfigureIdPIAM(t *testing.T) { assumeRoleStatementJSON(tlsServerIssuer), ) require.JSONEq(t, *expectedAssumeRolePolicyDoc, aws.ToString(role.assumeRolePolicyDoc)) + gotTags := map[string]string{} + for _, tag := range role.tags { + gotTags[aws.ToString(tag.Key)] = aws.ToString(tag.Value) + } + wantTags := map[string]string{ + "teleport.dev/origin": "integration_awsoidc", + "teleport.dev/cluster": "mycluster", + "teleport.dev/integration": "myintegration", + } + require.Equal(t, wantTags, gotTags) }, }, { - name: "role exists, ownership tags, assume role already exists", + name: "role exists with matching trust policy", mockAccountID: "123456789012", mockExistingIdPUrl: []string{}, mockExistingRoles: map[string]mockRole{"integrationrole": { @@ -291,6 +296,32 @@ func TestConfigureIdPIAM(t *testing.T) { } } +func TestConfigureIdPIAMOutput(t *testing.T) { + ctx := context.Background() + var buf bytes.Buffer + req := IdPIAMConfigureRequest{ + Cluster: "mycluster", + IntegrationName: "myintegration", + IntegrationRole: "integrationrole", + ProxyPublicAddress: "https://example.com", + AutoConfirm: true, + stdout: &buf, + fakeThumbprint: "15dbd260c7465ecca6de2c0b2181187f66ee0d1a", + } + + clt := mockIdPIAMConfigClient{ + CallerIdentityGetter: mockSTSClient{accountID: "123456789012"}, + existingRoles: map[string]mockRole{}, + existingIDPUrl: []string{}, + } + + require.NoError(t, ConfigureIdPIAM(ctx, &clt, req)) + if golden.ShouldSet() { + golden.Set(t, buf.Bytes()) + } + require.Equal(t, string(golden.Get(t)), buf.String()) +} + type mockRole struct { assumeRolePolicyDoc *string tags []iamtypes.Tag @@ -366,6 +397,25 @@ func (m *mockIdPIAMConfigClient) UpdateAssumeRolePolicy(ctx context.Context, par return &iam.UpdateAssumeRolePolicyOutput{}, nil } +func (m *mockIdPIAMConfigClient) TagRole(ctx context.Context, params *iam.TagRoleInput, _ ...func(*iam.Options)) (*iam.TagRoleOutput, error) { + roleName := aws.ToString(params.RoleName) + role, found := m.existingRoles[roleName] + if !found { + return nil, trace.NotFound("role not found") + } + + tags := tags.AWSTags{} + for _, existingTag := range role.tags { + tags[*existingTag.Key] = *existingTag.Value + } + for _, newTag := range params.Tags { + tags[*newTag.Key] = *newTag.Value + } + role.tags = tags.ToIAMTags() + m.existingRoles[roleName] = role + return &iam.TagRoleOutput{}, nil +} + func TestNewIdPIAMConfigureClient(t *testing.T) { t.Run("no aws_region env var, returns an error", func(t *testing.T) { _, err := NewIdPIAMConfigureClient(context.Background()) diff --git a/lib/integrations/awsoidc/tags/tags.go b/lib/integrations/awsoidc/tags/tags.go index 5a810f759a64b..fe4ab77991f28 100644 --- a/lib/integrations/awsoidc/tags/tags.go +++ b/lib/integrations/awsoidc/tags/tags.go @@ -19,8 +19,10 @@ package tags import ( + "cmp" "fmt" "maps" + "slices" "strings" athenatypes "github.com/aws/aws-sdk-go-v2/service/athena/types" @@ -67,6 +69,9 @@ func (d AWSTags) ToECSTags() []ecstypes.Tag { Value: &v, }) } + slices.SortFunc(ecsTags, func(a, b ecstypes.Tag) int { + return cmp.Compare(*a.Key, *b.Key) + }) return ecsTags } @@ -79,6 +84,9 @@ func (d AWSTags) ToEC2Tags() []ec2types.Tag { Value: &v, }) } + slices.SortFunc(ec2Tags, func(a, b ec2types.Tag) int { + return cmp.Compare(*a.Key, *b.Key) + }) return ec2Tags } @@ -125,6 +133,9 @@ func (d AWSTags) ToIAMTags() []iamtypes.Tag { Value: &v, }) } + slices.SortFunc(iamTags, func(a, b iamtypes.Tag) int { + return cmp.Compare(*a.Key, *b.Key) + }) return iamTags } @@ -137,6 +148,9 @@ func (d AWSTags) ToS3Tags() []s3types.Tag { Value: &v, }) } + slices.SortFunc(s3Tags, func(a, b s3types.Tag) int { + return cmp.Compare(*a.Key, *b.Key) + }) return s3Tags } @@ -149,6 +163,9 @@ func (d AWSTags) ToAthenaTags() []athenatypes.Tag { Value: &v, }) } + slices.SortFunc(athenaTags, func(a, b athenatypes.Tag) int { + return cmp.Compare(*a.Key, *b.Key) + }) return athenaTags } @@ -161,6 +178,9 @@ func (d AWSTags) ToSSMTags() []ssmtypes.Tag { Value: &v, }) } + slices.SortFunc(ssmTags, func(a, b ssmtypes.Tag) int { + return cmp.Compare(*a.Key, *b.Key) + }) return ssmTags } diff --git a/lib/integrations/awsoidc/testdata/TestConfigureIdPIAMOutput.golden b/lib/integrations/awsoidc/testdata/TestConfigureIdPIAMOutput.golden new file mode 100644 index 0000000000000..9cabf4b0d00d1 --- /dev/null +++ b/lib/integrations/awsoidc/testdata/TestConfigureIdPIAMOutput.golden @@ -0,0 +1,67 @@ +"awsoidc-idp" will perform the following actions: + +1. Create an OpenID Connect identity provider in AWS IAM for your Teleport cluster. +CreateOpenIDConnectProvider: { + "Url": "https://example.com", + "ClientIDList": [ + "discover.teleport" + ], + "Tags": [ + { + "Key": "teleport.dev/cluster", + "Value": "mycluster" + }, + { + "Key": "teleport.dev/integration", + "Value": "myintegration" + }, + { + "Key": "teleport.dev/origin", + "Value": "integration_awsoidc" + } + ], + "ThumbprintList": [ + "15dbd260c7465ecca6de2c0b2181187f66ee0d1a" + ] +} + +2. Create IAM role "integrationrole" with a custom trust policy. +CreateRole: { + "AssumeRolePolicyDocument": { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": "sts:AssumeRoleWithWebIdentity", + "Principal": { + "Federated": "arn:aws:iam::123456789012:oidc-provider/example.com" + }, + "Condition": { + "StringEquals": { + "example.com:aud": "discover.teleport" + } + } + } + ] + }, + "RoleName": "integrationrole", + "Description": "Used by Teleport to provide access to AWS resources.", + "MaxSessionDuration": null, + "Path": null, + "PermissionsBoundary": null, + "Tags": [ + { + "Key": "teleport.dev/cluster", + "Value": "mycluster" + }, + { + "Key": "teleport.dev/integration", + "Value": "myintegration" + }, + { + "Key": "teleport.dev/origin", + "Value": "integration_awsoidc" + } + ] +} + diff --git a/lib/srv/db/cloud/aws.go b/lib/srv/db/cloud/aws.go index 2b8a597029c76..5f69145230434 100644 --- a/lib/srv/db/cloud/aws.go +++ b/lib/srv/db/cloud/aws.go @@ -172,12 +172,12 @@ func (r *awsClient) ensureIAMPolicy(ctx context.Context) error { return trace.Wrap(err) } var changed bool - dbIAM.ForEach(func(effect, action, resource string) { - if policy.Ensure(effect, action, resource) { - r.log.Debugf("Permission %q for %q is already part of policy.", action, resource) - } else { + dbIAM.ForEach(func(effect, action, resource string, conditions awslib.Conditions) { + if policy.EnsureResourceAction(effect, action, resource, conditions) { r.log.Debugf("Adding permission %q for %q to policy.", action, resource) changed = true + } else { + r.log.Debugf("Permission %q for %q is already part of policy.", action, resource) } }) if !changed { @@ -206,11 +206,11 @@ func (r *awsClient) deleteIAMPolicy(ctx context.Context) error { if err != nil { return trace.Wrap(err) } - dbIAM.ForEach(func(effect, action, resource string) { - policy.Delete(effect, action, resource) + dbIAM.ForEach(func(effect, action, resource string, conditions awslib.Conditions) { + policy.DeleteResourceAction(effect, action, resource, conditions) }) // If policy is empty now, delete it as IAM policy can't be empty. - if len(policy.Statements) == 0 { + if policy.IsEmpty() { return r.detachIAMPolicy(ctx) } return r.updateIAMPolicy(ctx, policy) diff --git a/tool/teleport/common/integration_configure.go b/tool/teleport/common/integration_configure.go index dceb1bd2ca529..c542356226264 100644 --- a/tool/teleport/common/integration_configure.go +++ b/tool/teleport/common/integration_configure.go @@ -151,6 +151,7 @@ func onIntegrationConfAWSOIDCIdP(ctx context.Context, clf config.CommandLineFlag IntegrationName: clf.IntegrationConfAWSOIDCIdPArguments.Name, IntegrationRole: clf.IntegrationConfAWSOIDCIdPArguments.Role, ProxyPublicAddress: clf.IntegrationConfAWSOIDCIdPArguments.ProxyPublicURL, + AutoConfirm: clf.IntegrationConfAWSOIDCIdPArguments.AutoConfirm, } return trace.Wrap(awsoidc.ConfigureIdPIAM(ctx, iamClient, confReq)) } diff --git a/tool/teleport/common/teleport.go b/tool/teleport/common/teleport.go index 7047f1eb62759..45376e1ec8d41 100644 --- a/tool/teleport/common/teleport.go +++ b/tool/teleport/common/teleport.go @@ -515,6 +515,7 @@ func Run(options Options) (app *kingpin.Application, executedCommand string, con IntegrationConfAWSOIDCIdPArguments.Role) integrationConfAWSOIDCIdPCmd.Flag("proxy-public-url", "Proxy Public URL (eg https://mytenant.teleport.sh).").Required().StringVar(&ccf. IntegrationConfAWSOIDCIdPArguments.ProxyPublicURL) + integrationConfAWSOIDCIdPCmd.Flag("confirm", "Do not prompt user and auto-confirm all actions.").BoolVar(&ccf.IntegrationConfAWSOIDCIdPArguments.AutoConfirm) integrationConfAWSOIDCIdPCmd.Flag("insecure", "Insecure mode disables certificate validation.").BoolVar(&ccf.InsecureMode) integrationConfListDatabasesCmd := integrationConfigureCmd.Command("listdatabases-iam", "Adds required IAM permissions to List RDS Databases (Instances and Clusters).")