From a3e1316260ac02c1ffbee9b8433682a099659a95 Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Thu, 3 Oct 2024 13:50:02 -0700 Subject: [PATCH] make configure awsoidc-idp actions transparent (#46747) Applies to the integration command that the web UI discover flows tell users to run in AWS CloudShell to setup the AWS OIDC identity provider: teleport integration configure awsoidc-idp The command describes itself, its actions, and the desired state after it runs. It then prompts the user (by default) to confirm the action plan before proceeding. The confirmation prompt can be overridden with cli flag --confirm if desired. The IAM role it configures is no longer required to have the "ownership" tags that teleport applies if it's created by teleport, since the user is now prompted for confirmation before making changes. This allows a user to configure an existing IAM role without tagging the role for configuration by teleport. The command will still attempt to ensure the IAM role it configures has teleport tags, but failing to do so is only a warning. --- lib/cloud/aws/policy.go | 164 +++++--- lib/cloud/aws/policy_statements.go | 4 +- lib/cloud/aws/policy_test.go | 69 ++-- lib/cloud/provisioning/awsactions/common.go | 32 ++ .../awsactions/create_oidc_idp.go | 81 ++++ .../provisioning/awsactions/create_role.go | 215 ++++++++++ lib/cloud/provisioning/operations.go | 219 +++++++++++ lib/cloud/provisioning/operations_test.go | 372 ++++++++++++++++++ lib/config/configuration.go | 2 + lib/configurators/aws/aws.go | 2 +- lib/integrations/awsoidc/idp_iam_config.go | 149 +++---- .../awsoidc/idp_iam_config_test.go | 78 +++- lib/integrations/awsoidc/tags/tags.go | 20 + .../testdata/TestConfigureIdPIAMOutput.golden | 67 ++++ lib/srv/db/cloud/aws.go | 14 +- tool/teleport/common/integration_configure.go | 1 + tool/teleport/common/teleport.go | 1 + 17 files changed, 1294 insertions(+), 196 deletions(-) create mode 100644 lib/cloud/provisioning/awsactions/common.go create mode 100644 lib/cloud/provisioning/awsactions/create_oidc_idp.go create mode 100644 lib/cloud/provisioning/awsactions/create_role.go create mode 100644 lib/cloud/provisioning/operations.go create mode 100644 lib/cloud/provisioning/operations_test.go create mode 100644 lib/integrations/awsoidc/testdata/TestConfigureIdPIAMOutput.golden diff --git a/lib/cloud/aws/policy.go b/lib/cloud/aws/policy.go index 51f44dab7258e..0208dd8ae6bba 100644 --- a/lib/cloud/aws/policy.go +++ b/lib/cloud/aws/policy.go @@ -84,25 +84,71 @@ type Statement struct { // Condition: // StringEquals: // "proxy.example.com:aud": "discover.teleport" - Conditions map[string]map[string]SliceOrString `json:"Condition,omitempty"` + Conditions Conditions `json:"Condition,omitempty"` // StatementID is an optional identifier for the statement. StatementID string `json:"Sid,omitempty"` } +// Conditions is a list of conditions that must be satisfied for an action to be allowed. +type Conditions map[string]StringOrMap + +// Equals returns true if conditions are equal. +func (a Conditions) Equals(b Conditions) bool { + if len(a) != len(b) { + return false + } + for conditionKindA, conditionOpA := range a { + conditionOpB := b[conditionKindA] + if !conditionOpA.Equals(conditionOpB) { + return false + } + } + return true +} + // ensureResource ensures that the statement contains the specified resource. // -// Returns true if the resource was already a part of the statement. +// Returns true if the resource was added to the statement or false if the +// resource was already part of the statement. func (s *Statement) ensureResource(resource string) bool { if slices.Contains(s.Resources, resource) { - return true + return false } s.Resources = append(s.Resources, resource) - return false + return true } -func (s *Statement) ensureResources(resources []string) { +func (s *Statement) ensureResources(resources []string) bool { + var updated bool for _, resource := range resources { - s.ensureResource(resource) + updated = s.ensureResource(resource) || updated + } + return updated +} + +// ensurePrincipal ensures that the statement contains the specified principal. +// +// Returns true if the principal was already a part of the statement. +func (s *Statement) ensurePrincipal(kind string, value string) bool { + if len(s.Principals) == 0 { + s.Principals = make(StringOrMap) + } + values := s.Principals[kind] + if slices.Contains(values, value) { + return false + } + values = append(values, value) + s.Principals[kind] = values + return true +} + +func (s *Statement) ensurePrincipals(principals StringOrMap) bool { + var updated bool + for kind, values := range principals { + for _, v := range values { + updated = s.ensurePrincipal(kind, v) || updated + } } + return updated } // EqualStatement returns whether the receive statement is the same. @@ -115,39 +161,17 @@ func (s *Statement) EqualStatement(other *Statement) bool { return false } - if len(s.Principals) != len(other.Principals) { + if !s.Principals.Equals(other.Principals) { return false } - for principalKind, principalList := range s.Principals { - expectedPrincipalList := other.Principals[principalKind] - if !slices.Equal(principalList, expectedPrincipalList) { - return false - } - } - if !slices.Equal(s.Resources, other.Resources) { return false } - if len(s.Conditions) != len(other.Conditions) { + if !s.Conditions.Equals(other.Conditions) { return false } - for conditionKind, conditionOp := range s.Conditions { - expectedConditionOp := other.Conditions[conditionKind] - - if len(conditionOp) != len(expectedConditionOp) { - return false - } - - for conditionOpKind, conditionOpList := range conditionOp { - expectedConditionOpList := expectedConditionOp[conditionOpKind] - if !slices.Equal(conditionOpList, expectedConditionOpList) { - return false - } - } - } - return true } @@ -174,33 +198,38 @@ func NewPolicyDocument(statements ...*Statement) *PolicyDocument { } } -// Ensure ensures that the policy document contains the specified resource -// action. +// EnsureResourceAction ensures that the policy document contains the specified +// resource action. // -// Returns true if the resource action was already a part of the policy and -// false otherwise. -func (p *PolicyDocument) Ensure(effect, action, resource string) bool { - if existingStatement := p.findStatement(effect, action); existingStatement != nil { +// Returns true if the resource action was added to the policy and false if it +// was already part of the policy. +func (p *PolicyDocument) EnsureResourceAction(effect, action, resource string, conditions Conditions) bool { + if existingStatement := p.findStatement(effect, action, conditions); existingStatement != nil { return existingStatement.ensureResource(resource) } // No statement yet for this resource action, add it. p.Statements = append(p.Statements, &Statement{ - Effect: effect, - Actions: []string{action}, - Resources: []string{resource}, + Effect: effect, + Actions: []string{action}, + Resources: []string{resource}, + Conditions: conditions, }) - return false + return true } // Delete deletes the specified resource action from the policy. -func (p *PolicyDocument) Delete(effect, action, resource string) { +func (p *PolicyDocument) DeleteResourceAction(effect, action, resource string, conditions Conditions) { var statements []*Statement for _, s := range p.Statements { if s.Effect != effect { statements = append(statements, s) continue } + if !s.Conditions.Equals(conditions) { + statements = append(statements, s) + continue + } var resources []string for _, a := range s.Actions { for _, r := range s.Resources { @@ -225,7 +254,8 @@ func (p *PolicyDocument) Delete(effect, action, resource string) { // // The main benefit of using this function (versus appending to p.Statements // directly) is to avoid duplications. -func (p *PolicyDocument) EnsureStatements(statements ...*Statement) { +func (p *PolicyDocument) EnsureStatements(statements ...*Statement) bool { + var updated bool for _, statement := range statements { if statement == nil { continue @@ -234,8 +264,9 @@ func (p *PolicyDocument) EnsureStatements(statements ...*Statement) { // Try to find an existing statement by the action, and add the resources there. var newActions []string for _, action := range statement.Actions { - if existingStatement := p.findStatement(statement.Effect, action); existingStatement != nil { - existingStatement.ensureResources(statement.Resources) + if existingStatement := p.findStatement(statement.Effect, action, statement.Conditions); existingStatement != nil { + updated = existingStatement.ensureResources(statement.Resources) || updated + updated = existingStatement.ensurePrincipals(statement.Principals) || updated } else { newActions = append(newActions, action) } @@ -244,12 +275,21 @@ func (p *PolicyDocument) EnsureStatements(statements ...*Statement) { // Add the leftover actions as a new statement. if len(newActions) > 0 { p.Statements = append(p.Statements, &Statement{ - Effect: statement.Effect, - Actions: newActions, - Resources: statement.Resources, + Effect: statement.Effect, + Actions: newActions, + Resources: statement.Resources, + Conditions: statement.Conditions, + Principals: statement.Principals, }) + updated = true } } + return updated +} + +// IsEmpty returns whether the policy document is empty. +func (p *PolicyDocument) IsEmpty() bool { + return len(p.Statements) == 0 } // Marshal formats the PolicyDocument in a "friendly" format, which can be @@ -264,27 +304,30 @@ func (p *PolicyDocument) Marshal() (string, error) { } // ForEach loops through each action and resource of each statement. -func (p *PolicyDocument) ForEach(fn func(effect, action, resource string)) { +func (p *PolicyDocument) ForEach(fn func(effect, action, resource string, conditions Conditions)) { for _, statement := range p.Statements { for _, action := range statement.Actions { for _, resource := range statement.Resources { - fn(statement.Effect, action, resource) + fn(statement.Effect, action, resource, statement.Conditions) } } } } -func (p *PolicyDocument) findStatement(effect, action string) *Statement { +func (p *PolicyDocument) findStatement(effect, action string, conditions Conditions) *Statement { for _, s := range p.Statements { if s.Effect != effect { continue } - if slices.Contains(s.Actions, action) { - return s + if !slices.Contains(s.Actions, action) { + continue + } + if !s.Conditions.Equals(conditions) { + continue } + return s } return nil - } // SliceOrString defines a type that can be either a single string or a slice. @@ -337,6 +380,21 @@ func (s SliceOrString) MarshalJSON() ([]byte, error) { // https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements_principal.html#principal-anonymous type StringOrMap map[string]SliceOrString +// Equals returns true if this StringOrMap is equal to another StringOrMap. +func (s StringOrMap) Equals(other StringOrMap) bool { + if len(s) != len(other) { + return false + } + + for key, list := range s { + otherList := other[key] + if !slices.Equal(list, otherList) { + return false + } + } + return true +} + // UnmarshalJSON implements json.Unmarshaller. // If it contains a string and not a map, it will create a map with a single entry: // { "str": [] } diff --git a/lib/cloud/aws/policy_statements.go b/lib/cloud/aws/policy_statements.go index a6fb3f3556739..644981e3ccdab 100644 --- a/lib/cloud/aws/policy_statements.go +++ b/lib/cloud/aws/policy_statements.go @@ -161,7 +161,7 @@ func StatementForAWSAppAccess() *Statement { "sts:AssumeRole", }, Resources: allResources, - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "StringEquals": { "iam:ResourceTag/" + requiredTag: SliceOrString{"true"}, }, @@ -198,7 +198,7 @@ func StatementForAWSOIDCRoleTrustRelationship(accountID, providerURL string, aud Principals: map[string]SliceOrString{ "Federated": []string{federatedARN}, }, - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "StringEquals": { federatedAudience: audiences, }, diff --git a/lib/cloud/aws/policy_test.go b/lib/cloud/aws/policy_test.go index b76cc70684de7..d9085a7f586a3 100644 --- a/lib/cloud/aws/policy_test.go +++ b/lib/cloud/aws/policy_test.go @@ -317,7 +317,7 @@ func TestMarshalPolicyDocument(t *testing.T) { Principals: map[string]SliceOrString{ "Federated": {"arn:aws:iam::123456789012:oidc-provider/proxy.example.com"}, }, - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "StringEquals": { "proxy.example.com:aud": SliceOrString{"discover.teleport"}, }, @@ -353,8 +353,8 @@ func TestIAMPolicy(t *testing.T) { policy := NewPolicyDocument() // Add a new action/resource. - alreadyExisted := policy.Ensure(EffectAllow, "action-1", "resource-1") - require.False(t, alreadyExisted) + updated := policy.EnsureResourceAction(EffectAllow, "action-1", "resource-1", nil) + require.True(t, updated) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, Statements: []*Statement{ @@ -367,8 +367,8 @@ func TestIAMPolicy(t *testing.T) { }, policy) // Add the same action/resource. - alreadyExisted = policy.Ensure(EffectAllow, "action-1", "resource-1") - require.True(t, alreadyExisted) + updated = policy.EnsureResourceAction(EffectAllow, "action-1", "resource-1", nil) + require.False(t, updated) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, Statements: []*Statement{ @@ -381,8 +381,8 @@ func TestIAMPolicy(t *testing.T) { }, policy) // Add a new resource to existing action. - alreadyExisted = policy.Ensure(EffectAllow, "action-1", "resource-2") - require.False(t, alreadyExisted) + updated = policy.EnsureResourceAction(EffectAllow, "action-1", "resource-2", nil) + require.True(t, updated) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, Statements: []*Statement{ @@ -395,8 +395,8 @@ func TestIAMPolicy(t *testing.T) { }, policy) // Add another action/resource. - alreadyExisted = policy.Ensure(EffectAllow, "action-2", "resource-3") - require.False(t, alreadyExisted) + updated = policy.EnsureResourceAction(EffectAllow, "action-2", "resource-3", nil) + require.True(t, updated) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, Statements: []*Statement{ @@ -414,7 +414,7 @@ func TestIAMPolicy(t *testing.T) { }, policy) // Delete existing resource action. - policy.Delete(EffectAllow, "action-1", "resource-1") + policy.DeleteResourceAction(EffectAllow, "action-1", "resource-1", nil) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, Statements: []*Statement{ @@ -432,7 +432,7 @@ func TestIAMPolicy(t *testing.T) { }, policy) // Delete last resource from first action, statement should get removed as well. - policy.Delete(EffectAllow, "action-1", "resource-2") + policy.DeleteResourceAction(EffectAllow, "action-1", "resource-2", nil) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, Statements: []*Statement{ @@ -445,7 +445,7 @@ func TestIAMPolicy(t *testing.T) { }, policy) // Delete last resource action, policy should be empty. - policy.Delete(EffectAllow, "action-2", "resource-3") + policy.DeleteResourceAction(EffectAllow, "action-2", "resource-3", nil) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, }, policy) @@ -466,7 +466,7 @@ func TestIAMPolicy(t *testing.T) { }, }, } - policy.Delete(EffectAllow, "action-1", "resource-1") + policy.DeleteResourceAction(EffectAllow, "action-1", "resource-1", nil) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, }, policy) @@ -487,7 +487,7 @@ func TestIAMPolicy(t *testing.T) { }, }, } - policy.Delete(EffectAllow, "action-1", "resource-2") + policy.DeleteResourceAction(EffectAllow, "action-1", "resource-2", nil) require.Equal(t, &PolicyDocument{ Version: PolicyVersion, Statements: []*Statement{ @@ -532,6 +532,19 @@ func TestPolicyEnsureStatements(t *testing.T) { Actions: []string{"action-1"}, Resources: []string{"resource-3"}, }, + // Existing action with different condition and new principals + &Statement{ + Effect: EffectAllow, + Actions: []string{"action-1"}, + Principals: StringOrMap{ + "Federated": []string{"arn:aws:iam::123456789012:oidc-provider/example.com"}, + }, + Conditions: Conditions{ + "StringEquals": StringOrMap{ + "example.com:aud": []string{"discover.teleport"}, + }, + }, + }, // New actions and new resources. &Statement{ Effect: EffectAllow, @@ -565,6 +578,18 @@ func TestPolicyEnsureStatements(t *testing.T) { Actions: []string{"action-2"}, Resources: []string{"resource-1", "resource-4"}, }, + { + Effect: EffectAllow, + Actions: []string{"action-1"}, + Principals: StringOrMap{ + "Federated": []string{"arn:aws:iam::123456789012:oidc-provider/example.com"}, + }, + Conditions: Conditions{ + "StringEquals": StringOrMap{ + "example.com:aud": []string{"discover.teleport"}, + }, + }, + }, { Effect: EffectAllow, Actions: []string{"action-3", "action-4"}, @@ -931,13 +956,13 @@ func TestEqualStatement(t *testing.T) { { name: "different number of conditions", statementA: &Statement{ - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3600"}}, "StringLike": {"s3:prefix": []string{"janedoe/*"}}, }, }, statementB: &Statement{ - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3601"}}, }, }, @@ -946,12 +971,12 @@ func TestEqualStatement(t *testing.T) { { name: "different conditions", statementA: &Statement{ - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3600"}}, }, }, statementB: &Statement{ - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3601"}}, }, }, @@ -960,12 +985,12 @@ func TestEqualStatement(t *testing.T) { { name: "different condition values", statementA: &Statement{ - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3600", "3601"}}, }, }, statementB: &Statement{ - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3600"}}, }, }, @@ -990,7 +1015,7 @@ func TestEqualStatement(t *testing.T) { }, Actions: []string{"s3:GetObject"}, Resources: []string{"arn:aws:s3:::my-bucket/my-prefix/*"}, - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "StringLike": {"s3:prefix": []string{"my-prefix/*"}}, }, }, @@ -1001,7 +1026,7 @@ func TestEqualStatement(t *testing.T) { }, Actions: []string{"s3:GetObject"}, Resources: []string{"arn:aws:s3:::my-bucket/my-prefix/*"}, - Conditions: map[string]map[string]SliceOrString{ + Conditions: map[string]StringOrMap{ "StringLike": {"s3:prefix": []string{"my-prefix/*"}}, }, }, diff --git a/lib/cloud/provisioning/awsactions/common.go b/lib/cloud/provisioning/awsactions/common.go new file mode 100644 index 0000000000000..bd4a23f97c64b --- /dev/null +++ b/lib/cloud/provisioning/awsactions/common.go @@ -0,0 +1,32 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package awsactions + +import ( + "encoding/json" + + "github.com/gravitational/trace" +) + +func formatDetails(in any) (string, error) { + const prefix = "" + const indent = " " + out, err := json.MarshalIndent(in, prefix, indent) + return string(out), trace.Wrap(err) +} diff --git a/lib/cloud/provisioning/awsactions/create_oidc_idp.go b/lib/cloud/provisioning/awsactions/create_oidc_idp.go new file mode 100644 index 0000000000000..d90425d80f06c --- /dev/null +++ b/lib/cloud/provisioning/awsactions/create_oidc_idp.go @@ -0,0 +1,81 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package awsactions + +import ( + "context" + "log/slog" + + "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/gravitational/trace" + + awslib "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/provisioning" + "github.com/gravitational/teleport/lib/integrations/awsoidc/tags" +) + +// OpenIDConnectProviderCreator can create an OpenID Connect Identity Provider +// (OIDC IdP) in AWS IAM. +type OpenIDConnectProviderCreator interface { + // CreateOpenIDConnectProvider creates an AWS IAM OIDC IdP. + CreateOpenIDConnectProvider(ctx context.Context, params *iam.CreateOpenIDConnectProviderInput, optFns ...func(*iam.Options)) (*iam.CreateOpenIDConnectProviderOutput, error) +} + +// CreateOIDCProvider wraps a [OpenIDConnectPRoviderCreator] in a +// [provisioning.Action] that creates an OIDC IdP in AWS IAM when invoked. +func CreateOIDCProvider( + clt OpenIDConnectProviderCreator, + thumbprints []string, + issuerURL string, + clientIDs []string, + tags tags.AWSTags, +) (*provisioning.Action, error) { + input := &iam.CreateOpenIDConnectProviderInput{ + ThumbprintList: thumbprints, + Url: &issuerURL, + ClientIDList: clientIDs, + Tags: tags.ToIAMTags(), + } + details, err := formatDetails(input) + if err != nil { + return nil, trace.Wrap(err) + } + + config := provisioning.ActionConfig{ + Name: "CreateOpenIDConnectProvider", + Summary: "Create an OpenID Connect identity provider in AWS IAM for your Teleport cluster", + Details: details, + RunnerFn: func(ctx context.Context) error { + slog.InfoContext(ctx, "Creating OpenID Connect identity provider") + _, err = clt.CreateOpenIDConnectProvider(ctx, input) + if err != nil { + awsErr := awslib.ConvertIAMv2Error(err) + if trace.IsAlreadyExists(awsErr) { + slog.InfoContext(ctx, "OpenID Connect identity provider already exists") + return nil + } + + return trace.Wrap(err) + } + return nil + }, + } + action, err := provisioning.NewAction(config) + return action, trace.Wrap(err) +} diff --git a/lib/cloud/provisioning/awsactions/create_role.go b/lib/cloud/provisioning/awsactions/create_role.go new file mode 100644 index 0000000000000..c1ed634706fc4 --- /dev/null +++ b/lib/cloud/provisioning/awsactions/create_role.go @@ -0,0 +1,215 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package awsactions + +import ( + "context" + "fmt" + "log/slog" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/iam" + iamtypes "github.com/aws/aws-sdk-go-v2/service/iam/types" + "github.com/gravitational/trace" + + awslib "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/provisioning" + "github.com/gravitational/teleport/lib/integrations/awsoidc/tags" +) + +// RoleCreator can create an IAM role. +type RoleCreator interface { + // CreateRole creates a new IAM Role. + CreateRole(ctx context.Context, params *iam.CreateRoleInput, optFns ...func(*iam.Options)) (*iam.CreateRoleOutput, error) +} + +// RoleGetter can get an IAM role. +type RoleGetter interface { + // GetRole retrieves information about the specified role, including the + // role's path, GUID, ARN, and the role's trust policy that grants + // permission to assume the role. + GetRole(ctx context.Context, params *iam.GetRoleInput, optFns ...func(*iam.Options)) (*iam.GetRoleOutput, error) +} + +// AssumeRolePolicyUpdater can update an IAM role's trust policy. +type AssumeRolePolicyUpdater interface { + // UpdateAssumeRolePolicy updates the policy that grants an IAM entity + // permission to assume a role. + // This is typically referred to as the "role trust policy". + UpdateAssumeRolePolicy(ctx context.Context, params *iam.UpdateAssumeRolePolicyInput, optFns ...func(*iam.Options)) (*iam.UpdateAssumeRolePolicyOutput, error) +} + +// RoleTagger can tag an AWS IAM role. +type RoleTagger interface { + // TagRole adds one or more tags to an IAM role. The role can be a regular + // role or a service-linked role. If a tag with the same key name already + // exists, then that tag is overwritten with the new value. + TagRole(ctx context.Context, params *iam.TagRoleInput, optFns ...func(*iam.Options)) (*iam.TagRoleOutput, error) +} + +// CreateRole returns a [provisioning.Action] that creates or updates an IAM +// role when invoked. +func CreateRole( + clt interface { + AssumeRolePolicyUpdater + RoleCreator + RoleGetter + RoleTagger + }, + roleName string, + description string, + trustPolicy *awslib.PolicyDocument, + tags tags.AWSTags, +) (*provisioning.Action, error) { + trustPolicyJSON, err := trustPolicy.Marshal() + if err != nil { + return nil, trace.Wrap(err) + } + input := &iam.CreateRoleInput{ + RoleName: &roleName, + Description: &description, + AssumeRolePolicyDocument: &trustPolicyJSON, + Tags: tags.ToIAMTags(), + } + type createRoleInput struct { + // AssumeRolePolicyDocument shadows the input's field of the same name + // to marshal the trust policy doc as unescpaed JSON. + AssumeRolePolicyDocument *awslib.PolicyDocument + *iam.CreateRoleInput + } + details, err := formatDetails(createRoleInput{ + AssumeRolePolicyDocument: trustPolicy, + CreateRoleInput: input, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + config := provisioning.ActionConfig{ + Name: "CreateRole", + Summary: fmt.Sprintf("Create IAM role %q with a custom trust policy", roleName), + Details: details, + RunnerFn: func(ctx context.Context) error { + slog.InfoContext(ctx, "Checking for existing IAM role", + "role", roleName, + ) + getRoleOut, err := clt.GetRole(ctx, &iam.GetRoleInput{ + RoleName: &roleName, + }) + if err != nil { + convertedErr := awslib.ConvertIAMv2Error(err) + if !trace.IsNotFound(convertedErr) { + return trace.Wrap(convertedErr) + } + slog.InfoContext(ctx, "Creating IAM role", "role", roleName) + _, err = clt.CreateRole(ctx, input) + if err != nil { + return trace.Wrap(awslib.ConvertIAMv2Error(err)) + } + return nil + } + + slog.InfoContext(ctx, "IAM role already exists", + "role", roleName, + ) + existingTrustPolicy, err := awslib.ParsePolicyDocument(aws.ToString(getRoleOut.Role.AssumeRolePolicyDocument)) + if err != nil { + return trace.Wrap(err) + } + err = ensureTrustPolicy(ctx, clt, roleName, trustPolicy, existingTrustPolicy) + if err != nil { + return trace.Wrap(err) + } + + err = ensureTags(ctx, clt, roleName, tags, getRoleOut.Role.Tags) + if err != nil { + // Tagging an existing role after we update it is a + // nice-to-have, but not a need-to-have. + slog.WarnContext(ctx, "Failed to update IAM role tags", + "role", roleName, + "error", err, + "tags", tags.ToMap(), + ) + } + return nil + }, + } + action, err := provisioning.NewAction(config) + return action, trace.Wrap(err) +} + +func ensureTrustPolicy( + ctx context.Context, + clt AssumeRolePolicyUpdater, + roleName string, + trustPolicy *awslib.PolicyDocument, + existingTrustPolicy *awslib.PolicyDocument, +) error { + slog.InfoContext(ctx, "Checking IAM role trust policy", + "role", roleName, + ) + + if !existingTrustPolicy.EnsureStatements(trustPolicy.Statements...) { + slog.InfoContext(ctx, "IAM role trust policy does not require update", + "role", roleName, + ) + return nil + } + + slog.InfoContext(ctx, "Updating IAM role trust policy", + "role", roleName, + ) + trustPolicyJSON, err := existingTrustPolicy.Marshal() + if err != nil { + return trace.Wrap(err) + } + + _, err = clt.UpdateAssumeRolePolicy(ctx, &iam.UpdateAssumeRolePolicyInput{ + RoleName: &roleName, + PolicyDocument: &trustPolicyJSON, + }) + return trace.Wrap(err) +} + +func ensureTags( + ctx context.Context, + clt RoleTagger, + roleName string, + tags tags.AWSTags, + existingTags []iamtypes.Tag, +) error { + slog.InfoContext(ctx, "Checking for tags on IAM role", + "role", roleName, + ) + if tags.MatchesIAMTags(existingTags) { + slog.InfoContext(ctx, "IAM role is already tagged", + "role", roleName, + ) + return nil + } + + slog.InfoContext(ctx, "Updating IAM role tags", + "role", roleName, + ) + _, err := clt.TagRole(ctx, &iam.TagRoleInput{ + RoleName: &roleName, + Tags: tags.ToIAMTags(), + }) + return trace.Wrap(err) +} diff --git a/lib/cloud/provisioning/operations.go b/lib/cloud/provisioning/operations.go new file mode 100644 index 0000000000000..8b16a445ccace --- /dev/null +++ b/lib/cloud/provisioning/operations.go @@ -0,0 +1,219 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package provisioning + +import ( + "context" + "fmt" + "io" + "log/slog" + "os" + "regexp" + "strings" + "text/template" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/utils/prompt" +) + +// validName is used to ensure that [OperationConfig] and [ActionConfig] names +// start with a letter and only consist of letters, numbers, and hyphen +// characters thereafter. +var validName = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9-]+$`) + +// OperationConfig is the configuration for an operation. +type OperationConfig struct { + // Name is the operation name. Must consist of only letters and hyphens. + Name string + // Actions is the list of actions that make up the operation. + Actions []Action + // AutoConfirm is whether to skip the operation plan confirmation prompt. + AutoConfirm bool + // Output is an [io.Writer] where the operation plan and confirmation prompt + // are written to. + // Defaults to [os.Stdout]. + Output io.Writer +} + +// checkAndSetDefaults validates the operation config and sets defaults. +func (c *OperationConfig) checkAndSetDefaults() error { + c.Name = strings.TrimSpace(c.Name) + if c.Name == "" { + return trace.BadParameter("missing operation name") + } + if !validName.MatchString(c.Name) { + return trace.BadParameter( + "operation name %q does not match regex used for validation %q", + c.Name, validName.String(), + ) + } + if len(c.Actions) == 0 { + return trace.BadParameter("missing operation actions") + } + if c.Output == nil { + c.Output = os.Stdout + } + return nil +} + +// Action wraps a runnable function to provide a name, summary, and detailed +// explanation of the behavior. +type Action struct { + // config is an unexported value-type to prevent mutation after it's been + // validated by the checkAndSetDefaults func. + config ActionConfig +} + +// GetName returns the action's configured name. +func (a *Action) GetName() string { + return a.config.Name +} + +// GetSummary returns the action's configured summary. +func (a *Action) GetSummary() string { + return a.config.Summary +} + +// GetDetails returns the action's configured details. +func (a *Action) GetDetails() string { + return a.config.Details +} + +// Run runs the action. +func (a *Action) Run(ctx context.Context) error { + return a.config.RunnerFn(ctx) +} + +// NewAction creates a new [Action]. +func NewAction(config ActionConfig) (*Action, error) { + if err := config.checkAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + return &Action{ + config: config, + }, nil +} + +// ActionConfig is the configuration for an [Action]. +type ActionConfig struct { + // Name is the action name. + Name string + // Summary is the action summary in prose. + Summary string + // Details is the detailed explanation of the action, explaining exactly + // what it will do. + Details string + // RunnerFn is a function that actually runs the action. + RunnerFn func(context.Context) error +} + +// checkAndSetDefaults validates the action config and sets defaults. +func (c *ActionConfig) checkAndSetDefaults() error { + c.Name = strings.TrimSpace(c.Name) + c.Summary = strings.TrimSpace(c.Summary) + c.Details = strings.TrimSpace(c.Details) + if c.Name == "" { + return trace.BadParameter("missing action name") + } + if !validName.MatchString(c.Name) { + return trace.BadParameter( + "action name %q does not match regex used for validation %q", + c.Name, validName.String(), + ) + } + if c.Summary == "" { + return trace.BadParameter("missing action summary") + } + if c.Details == "" { + return trace.BadParameter("missing action details") + } + if c.RunnerFn == nil { + return trace.BadParameter("missing action runner") + } + + return nil +} + +// Run writes the operation plan, optionally prompts for user confirmation, +// then executes the operation plan. +func Run(ctx context.Context, config OperationConfig) error { + if err := config.checkAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + + if err := writeOperationPlan(config); err != nil { + return trace.Wrap(err) + } + + if !config.AutoConfirm { + question := fmt.Sprintf("Do you want %q to perform these actions?", config.Name) + ok, err := prompt.Confirmation(ctx, config.Output, prompt.Stdin(), question) + if err != nil { + return trace.Wrap(err) + } + if !ok { + return trace.BadParameter("operation %q canceled", config.Name) + } + } + + enumerateSteps := len(config.Actions) > 1 + for i, action := range config.Actions { + if enumerateSteps { + slog.InfoContext(ctx, "Running", "step", i+1, "action", action.config.Name) + } else { + slog.InfoContext(ctx, "Running", "action", action.config.Name) + } + + if err := action.Run(ctx); err != nil { + if enumerateSteps { + return trace.Wrap(err, "step %d %q failed", i+1, action.config.Name) + } + return trace.Wrap(err, "%q failed", action.config.Name) + } + } + slog.InfoContext(ctx, "Success!", "operation", config.Name) + return nil +} + +// writeOperationPlan writes the operational plan to the given [io.Writer] as +// a structured summary of the operation and the actions that compose it. +func writeOperationPlan(config OperationConfig) error { + data := map[string]any{ + "config": config, + "showStepNumbers": len(config.Actions) > 1, + } + return trace.Wrap(operationPlanTemplate.Execute(config.Output, data)) +} + +var operationPlanTemplate = template.Must(template.New("plan"). + Funcs(template.FuncMap{ + // used to enumerate the action steps starting from 1 instead of 0. + "addOne": func(x int) int { return x + 1 }, + }). + Parse(` +{{- printf "%q" .config.Name }} will perform the following actions: + +{{ $global := . }} +{{- range $index, $action := .config.Actions }} +{{- if $global.showStepNumbers }}{{ $index | addOne }}. {{ end -}}{{$action.GetSummary}}. +{{$action.GetName}}: {{$action.GetDetails}} + +{{end -}} +`)) diff --git a/lib/cloud/provisioning/operations_test.go b/lib/cloud/provisioning/operations_test.go new file mode 100644 index 0000000000000..56f8af546c7b8 --- /dev/null +++ b/lib/cloud/provisioning/operations_test.go @@ -0,0 +1,372 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package provisioning + +import ( + "bytes" + "context" + "os" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" +) + +func withErrCheck(fn func(*testing.T, error)) require.ErrorAssertionFunc { + return func(t require.TestingT, err error, msgAndArgs ...interface{}) { + tt := t.(*testing.T) + tt.Helper() + fn(tt, err) + } +} + +var badParameterCheck = withErrCheck(func(t *testing.T, err error) { + t.Helper() + require.True(t, trace.IsBadParameter(err), `expected "bad parameter", but got %v`, err) +}) + +func TestOperationCheckAndSetDefaults(t *testing.T) { + stubAction := Action{} + baseCfg := OperationConfig{ + Name: "name", + Actions: []Action{stubAction}, + AutoConfirm: true, + Output: os.Stdout, + } + tests := []struct { + desc string + getConfig func() OperationConfig + want OperationConfig + wantErrContains string + }{ + { + desc: "valid config", + getConfig: func() OperationConfig { return baseCfg }, + want: baseCfg, + }, + { + desc: "defaults output to stdout", + getConfig: func() OperationConfig { + cfg := baseCfg + cfg.Output = nil + return cfg + }, + want: baseCfg, + }, + { + desc: "trims trailing and leading spaces", + getConfig: func() OperationConfig { + cfg := baseCfg + cfg.Name = "\t\n " + cfg.Name + "\t\n " + return cfg + }, + want: baseCfg, + }, + { + desc: "missing name is an error", + getConfig: func() OperationConfig { + cfg := baseCfg + cfg.Name = "" + return cfg + }, + wantErrContains: "missing operation name", + }, + { + desc: "spaces in name is an error", + getConfig: func() OperationConfig { + cfg := baseCfg + cfg.Name = "some thing" + return cfg + }, + wantErrContains: "does not match regex", + }, + { + desc: "missing actions is an error", + getConfig: func() OperationConfig { + cfg := baseCfg + cfg.Actions = nil + return cfg + }, + wantErrContains: "missing operation actions", + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + got := test.getConfig() + err := got.checkAndSetDefaults() + if test.wantErrContains != "" { + require.Error(t, err) + require.ErrorContains(t, err, test.wantErrContains) + return + } + require.Empty(t, cmp.Diff(test.want, got, + cmp.AllowUnexported(Action{}), + cmpopts.IgnoreFields(OperationConfig{}, "Output"), + )) + require.NotNil(t, got.Output) + }) + } +} + +func TestNewAction(t *testing.T) { + noOpRunnerFn := func(context.Context) error { return nil } + baseCfg := ActionConfig{ + Name: "name", + Summary: "summary", + Details: "details", + RunnerFn: noOpRunnerFn, + } + tests := []struct { + desc string + getConfig func() ActionConfig + errCheck require.ErrorAssertionFunc + want Action + }{ + { + desc: "valid config", + getConfig: func() ActionConfig { return baseCfg }, + errCheck: require.NoError, + want: Action{config: baseCfg}, + }, + { + desc: "trims trailing and leading spaces", + getConfig: func() ActionConfig { + cfg := baseCfg + cfg.Name = "\t\n " + cfg.Name + "\t\n " + return cfg + }, + errCheck: require.NoError, + want: Action{config: baseCfg}, + }, + { + desc: "missing name is an error", + getConfig: func() ActionConfig { + cfg := baseCfg + cfg.Name = "" + return cfg + }, + errCheck: badParameterCheck, + }, + { + desc: "spaces in name is an error", + getConfig: func() ActionConfig { + cfg := baseCfg + cfg.Name = "some thing" + return cfg + }, + errCheck: badParameterCheck, + }, + { + desc: "missing summary is an error", + getConfig: func() ActionConfig { + cfg := baseCfg + cfg.Summary = "" + return cfg + }, + errCheck: badParameterCheck, + }, + { + desc: "missing details is an error", + getConfig: func() ActionConfig { + cfg := baseCfg + cfg.Details = "" + return cfg + }, + errCheck: badParameterCheck, + }, + { + desc: "missing runner is an error", + getConfig: func() ActionConfig { + cfg := baseCfg + cfg.RunnerFn = nil + return cfg + }, + errCheck: badParameterCheck, + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + _, err := NewAction(test.getConfig()) + test.errCheck(t, err) + }) + } +} + +func TestRun(t *testing.T) { + actionA := Action{ + config: ActionConfig{ + Name: "actionA", + Summary: "actionA summary", + Details: "actionA details", + RunnerFn: func(context.Context) error { return nil }, + }, + } + actionB := Action{ + config: ActionConfig{ + Name: "actionB", + Summary: "actionB summary", + Details: "actionB details", + RunnerFn: func(context.Context) error { return nil }, + }, + } + failingAction := Action{ + config: ActionConfig{ + Name: "actionC", + Summary: "actionC summary", + Details: "actionC details", + RunnerFn: func(context.Context) error { return trace.AccessDenied("access denied") }, + }, + } + + tests := []struct { + desc string + config OperationConfig + errCheck require.ErrorAssertionFunc + }{ + { + desc: "success", + config: OperationConfig{ + Name: "op-name", + Actions: []Action{ + actionA, + actionB, + }, + AutoConfirm: true, + }, + errCheck: require.NoError, + }, + { + desc: "failed action does not enumerate the failed step", + config: OperationConfig{ + Name: "op-name", + Actions: []Action{ + failingAction, + }, + AutoConfirm: true, + }, + errCheck: withErrCheck(func(t *testing.T, err error) { + t.Helper() + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + // the error should indicate the failed action + require.Contains(t, err.Error(), `"actionC" failed`) + // but it should not number the step, since there's only + // one action. + require.NotContains(t, err.Error(), `1`) + }), + }, + { + desc: "failed action fails the entire operation", + config: OperationConfig{ + Name: "op-name", + Actions: []Action{ + actionA, + failingAction, + actionB, + }, + AutoConfirm: true, + }, + errCheck: withErrCheck(func(t *testing.T, err error) { + t.Helper() + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + // error should indicate step number and action name that failed + require.Contains(t, err.Error(), `step 2 "actionC" failed`) + }), + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + err := Run(ctx, test.config) + test.errCheck(t, err) + }) + } +} + +func TestWriteOperationPlan(t *testing.T) { + actionA := Action{ + config: ActionConfig{ + Name: "actionA", + Summary: "", + Details: "", + RunnerFn: func(context.Context) error { return trace.NotImplemented("not implemented") }, + }, + } + actionB := Action{ + config: ActionConfig{ + Name: "actionB", + Summary: "", + Details: "", + RunnerFn: func(context.Context) error { return trace.NotImplemented("not implemented") }, + }, + } + tests := []struct { + desc string + config OperationConfig + want string + }{ + { + desc: "operation with only one action does not enumerate the step", + config: OperationConfig{ + Name: "op-name", + Actions: []Action{actionA}, + AutoConfirm: true, + }, + want: strings.TrimLeft(` +"op-name" will perform the following actions: + +. +actionA: + +`, "\n"), + }, + { + desc: "operation with multiple actions enumerates the steps", + config: OperationConfig{ + Name: "op-name", + Actions: []Action{actionA, actionB}, + AutoConfirm: true, + }, + want: strings.TrimLeft(` +"op-name" will perform the following actions: + +1. . +actionA: + +2. . +actionB: + +`, "\n"), + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + var got bytes.Buffer + test.config.Output = &got + require.NoError(t, writeOperationPlan(test.config)) + require.Empty(t, cmp.Diff(test.want, got.String())) + }) + } +} diff --git a/lib/config/configuration.go b/lib/config/configuration.go index ac3e56ff59e8b..c1cc43fc2783f 100644 --- a/lib/config/configuration.go +++ b/lib/config/configuration.go @@ -380,6 +380,8 @@ type IntegrationConfAWSOIDCIdP struct { // ProxyPublicURL is the IdP Issuer URL (Teleport Proxy Public Address). // Eg, https://.teleport.sh ProxyPublicURL string + // AutoConfirm skips user confirmation of the operation plan if true. + AutoConfirm bool } // IntegrationConfListDatabasesIAM contains the arguments of diff --git a/lib/configurators/aws/aws.go b/lib/configurators/aws/aws.go index 4e31b202e4d88..cece7047dfc13 100644 --- a/lib/configurators/aws/aws.go +++ b/lib/configurators/aws/aws.go @@ -523,7 +523,7 @@ func buildCommonActions(config ConfiguratorConfig, targetCfg targetConfig) ([]co // If the policy has no statements means that the agent doesn't require // any IAM permission. In this case, return without errors and with empty // actions. - if len(policy.Document.Statements) == 0 { + if policy.Document.IsEmpty() { return []configurators.ConfiguratorAction{}, nil } diff --git a/lib/integrations/awsoidc/idp_iam_config.go b/lib/integrations/awsoidc/idp_iam_config.go index 24d14ffdcb2ee..6fff791c34b5f 100644 --- a/lib/integrations/awsoidc/idp_iam_config.go +++ b/lib/integrations/awsoidc/idp_iam_config.go @@ -20,7 +20,7 @@ package awsoidc import ( "context" - "log/slog" + "io" "net/http" "net/url" @@ -32,6 +32,8 @@ import ( "github.com/gravitational/teleport/api/types" awslib "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/provisioning" + "github.com/gravitational/teleport/lib/cloud/provisioning/awsactions" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/integrations/awsoidc/tags" ) @@ -59,6 +61,9 @@ type IdPIAMConfigureRequest struct { // Eg, https://.teleport.sh, https://proxy.example.org:443, https://teleport.ec2.aws:3080 ProxyPublicAddress string + // AutoConfirm skips user confirmation of the operation plan if true. + AutoConfirm bool + // issuer is the above value but only contains the host. // Eg, .teleport.sh, proxy.example.org issuer string @@ -70,6 +75,12 @@ type IdPIAMConfigureRequest struct { IntegrationRole string ownershipTags tags.AWSTags + + // stdout is used to override stdout output in tests. + stdout io.Writer + // fakeThumbprint is used to override thumbprint in output tests, to produce + // consistent output. + fakeThumbprint string } // CheckAndSetDefaults ensures the required fields are present. @@ -109,21 +120,11 @@ func (r *IdPIAMConfigureRequest) CheckAndSetDefaults() error { // There is no guarantee that the client is thread safe. type IdPIAMConfigureClient interface { CallerIdentityGetter - - // CreateOpenIDConnectProvider creates an IAM OIDC IdP. - CreateOpenIDConnectProvider(ctx context.Context, params *iam.CreateOpenIDConnectProviderInput, optFns ...func(*iam.Options)) (*iam.CreateOpenIDConnectProviderOutput, error) - - // CreateRole creates a new IAM Role. - CreateRole(ctx context.Context, params *iam.CreateRoleInput, optFns ...func(*iam.Options)) (*iam.CreateRoleOutput, error) - - // GetRole retrieves information about the specified role, including the role's path, - // GUID, ARN, and the role's trust policy that grants permission to assume the - // role. - GetRole(ctx context.Context, params *iam.GetRoleInput, optFns ...func(*iam.Options)) (*iam.GetRoleOutput, error) - - // UpdateAssumeRolePolicy updates the policy that grants an IAM entity permission to assume a role. - // This is typically referred to as the "role trust policy". - UpdateAssumeRolePolicy(ctx context.Context, params *iam.UpdateAssumeRolePolicyInput, optFns ...func(*iam.Options)) (*iam.UpdateAssumeRolePolicyOutput, error) + awsactions.AssumeRolePolicyUpdater + awsactions.OpenIDConnectProviderCreator + awsactions.RoleCreator + awsactions.RoleGetter + awsactions.RoleTagger } type defaultIdPIAMConfigureClient struct { @@ -183,99 +184,53 @@ func ConfigureIdPIAM(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMC req.AccountID = aws.ToString(callerIdentity.Account) } - slog.InfoContext(ctx, "Creating IAM OpenID Connect Provider", "url", req.issuerURL) - if err := ensureOIDCIdPIAM(ctx, clt, req); err != nil { - return trace.Wrap(err) - } - - slog.InfoContext(ctx, "Creating IAM Role", "role", req.IntegrationRole) - if err := upsertIdPIAMRole(ctx, clt, req); err != nil { - return trace.Wrap(err) - } - - return nil -} - -func ensureOIDCIdPIAM(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMConfigureRequest) error { - thumbprint, err := ThumbprintIdP(ctx, req.ProxyPublicAddress) - if err != nil { - return trace.Wrap(err) - } - - _, err = clt.CreateOpenIDConnectProvider(ctx, &iam.CreateOpenIDConnectProviderInput{ - ThumbprintList: []string{thumbprint}, - Url: &req.issuerURL, - ClientIDList: []string{types.IntegrationAWSOIDCAudience}, - Tags: req.ownershipTags.ToIAMTags(), - }) + createOIDCIdP, err := createOIDCIdPAction(ctx, clt, req) if err != nil { - awsErr := awslib.ConvertIAMv2Error(err) - if trace.IsAlreadyExists(awsErr) { - return nil - } - return trace.Wrap(err) } - return nil -} - -func createIdPIAMRole(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMConfigureRequest) error { - integrationRoleAssumeRoleDocument, err := awslib.NewPolicyDocument( - awslib.StatementForAWSOIDCRoleTrustRelationship(req.AccountID, req.issuer, []string{types.IntegrationAWSOIDCAudience}), - ).Marshal() + createIdPIAMRole, err := createIdPIAMRoleAction(clt, req) if err != nil { return trace.Wrap(err) } - _, err = clt.CreateRole(ctx, &iam.CreateRoleInput{ - RoleName: &req.IntegrationRole, - Description: aws.String(descriptionOIDCIdPRole), - AssumeRolePolicyDocument: &integrationRoleAssumeRoleDocument, - Tags: req.ownershipTags.ToIAMTags(), - }) - return trace.Wrap(err) + return trace.Wrap(provisioning.Run(ctx, provisioning.OperationConfig{ + Name: "awsoidc-idp", + Actions: []provisioning.Action{ + *createOIDCIdP, + *createIdPIAMRole, + }, + AutoConfirm: req.AutoConfirm, + Output: req.stdout, + })) } -func upsertIdPIAMRole(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMConfigureRequest) error { - getRoleOut, err := clt.GetRole(ctx, &iam.GetRoleInput{ - RoleName: &req.IntegrationRole, - }) - if err != nil { - convertedErr := awslib.ConvertIAMv2Error(err) - if !trace.IsNotFound(convertedErr) { - return trace.Wrap(convertedErr) - } - - return trace.Wrap(createIdPIAMRole(ctx, clt, req)) - } - - if !req.ownershipTags.MatchesIAMTags(getRoleOut.Role.Tags) { - return trace.BadParameter("IAM Role %q already exists but is not managed by Teleport. "+ - "Add the following tags to allow Teleport to manage this Role: %s", req.IntegrationRole, req.ownershipTags) - } - - trustRelationshipDoc, err := awslib.ParsePolicyDocument(aws.ToString(getRoleOut.Role.AssumeRolePolicyDocument)) - if err != nil { - return trace.Wrap(err) - } - - trustRelationshipForIdP := awslib.StatementForAWSOIDCRoleTrustRelationship(req.AccountID, req.issuer, []string{types.IntegrationAWSOIDCAudience}) - for _, existingStatement := range trustRelationshipDoc.Statements { - if existingStatement.EqualStatement(trustRelationshipForIdP) { - return nil +func createOIDCIdPAction(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMConfigureRequest) (*provisioning.Action, error) { + var thumbprint string + if req.fakeThumbprint != "" { + // only happens in tests. + thumbprint = req.fakeThumbprint + } else { + var err error + thumbprint, err = ThumbprintIdP(ctx, req.ProxyPublicAddress) + if err != nil { + return nil, trace.Wrap(err) } } - trustRelationshipDoc.Statements = append(trustRelationshipDoc.Statements, trustRelationshipForIdP) - trustRelationshipDocString, err := trustRelationshipDoc.Marshal() - if err != nil { - return trace.Wrap(err) - } + clientIDs := []string{types.IntegrationAWSOIDCAudience} + thumbprints := []string{thumbprint} + return awsactions.CreateOIDCProvider(clt, thumbprints, req.issuerURL, clientIDs, req.ownershipTags) +} - _, err = clt.UpdateAssumeRolePolicy(ctx, &iam.UpdateAssumeRolePolicyInput{ - RoleName: &req.IntegrationRole, - PolicyDocument: &trustRelationshipDocString, - }) - return trace.Wrap(err) +func createIdPIAMRoleAction(clt IdPIAMConfigureClient, req IdPIAMConfigureRequest) (*provisioning.Action, error) { + integrationRoleAssumeRoleDocument := awslib.NewPolicyDocument( + awslib.StatementForAWSOIDCRoleTrustRelationship(req.AccountID, req.issuer, []string{types.IntegrationAWSOIDCAudience}), + ) + return awsactions.CreateRole(clt, + req.IntegrationRole, + descriptionOIDCIdPRole, + integrationRoleAssumeRoleDocument, + req.ownershipTags, + ) } diff --git a/lib/integrations/awsoidc/idp_iam_config_test.go b/lib/integrations/awsoidc/idp_iam_config_test.go index 03bfcca1373c0..a7430d1bee953 100644 --- a/lib/integrations/awsoidc/idp_iam_config_test.go +++ b/lib/integrations/awsoidc/idp_iam_config_test.go @@ -19,6 +19,7 @@ package awsoidc import ( + "bytes" "context" "fmt" "net/http/httptest" @@ -35,6 +36,7 @@ import ( "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/integrations/awsoidc/tags" + "github.com/gravitational/teleport/lib/utils/golden" ) func TestIdPIAMConfigReqDefaults(t *testing.T) { @@ -44,6 +46,7 @@ func TestIdPIAMConfigReqDefaults(t *testing.T) { IntegrationName: "myintegration", IntegrationRole: "integrationrole", ProxyPublicAddress: "https://proxy.example.com", + AutoConfirm: true, } } @@ -69,6 +72,7 @@ func TestIdPIAMConfigReqDefaults(t *testing.T) { "teleport.dev/integration": "myintegration", "teleport.dev/origin": "integration_awsoidc", }, + AutoConfirm: true, }, }, { @@ -166,6 +170,7 @@ func TestConfigureIdPIAM(t *testing.T) { IntegrationName: "myintegration", IntegrationRole: "integrationrole", ProxyPublicAddress: tlsServer.URL, + AutoConfirm: true, } } @@ -195,15 +200,7 @@ func TestConfigureIdPIAM(t *testing.T) { errCheck: require.NoError, }, { - name: "role exists, no ownership tags", - mockAccountID: "123456789012", - mockExistingIdPUrl: []string{}, - mockExistingRoles: map[string]mockRole{"integrationrole": {}}, - req: baseIdPIAMConfigReqWithTLServer, - errCheck: badParameterCheck, - }, - { - name: "role exists, ownership tags, no assume role", + name: "role exists with empty trust policy", mockAccountID: "123456789012", mockExistingIdPUrl: []string{}, mockExistingRoles: map[string]mockRole{"integrationrole": { @@ -225,14 +222,12 @@ func TestConfigureIdPIAM(t *testing.T) { }, }, { - name: "role exists, ownership tags, with existing assume role", + name: "role exists with existing trust policy and without matching tags", mockAccountID: "123456789012", mockExistingIdPUrl: []string{}, mockExistingRoles: map[string]mockRole{"integrationrole": { tags: []iamtypes.Tag{ - {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, - {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, - {Key: aws.String("teleport.dev/integration"), Value: aws.String("myintegration")}, + {Key: aws.String("teleport.dev/origin"), Value: aws.String("should be overwritten")}, }, assumeRolePolicyDoc: policyDocWithStatementsJSON( assumeRoleStatementJSON("some-other-issuer"), @@ -247,10 +242,20 @@ func TestConfigureIdPIAM(t *testing.T) { assumeRoleStatementJSON(tlsServerIssuer), ) require.JSONEq(t, *expectedAssumeRolePolicyDoc, aws.ToString(role.assumeRolePolicyDoc)) + gotTags := map[string]string{} + for _, tag := range role.tags { + gotTags[aws.ToString(tag.Key)] = aws.ToString(tag.Value) + } + wantTags := map[string]string{ + "teleport.dev/origin": "integration_awsoidc", + "teleport.dev/cluster": "mycluster", + "teleport.dev/integration": "myintegration", + } + require.Equal(t, wantTags, gotTags) }, }, { - name: "role exists, ownership tags, assume role already exists", + name: "role exists with matching trust policy", mockAccountID: "123456789012", mockExistingIdPUrl: []string{}, mockExistingRoles: map[string]mockRole{"integrationrole": { @@ -291,6 +296,32 @@ func TestConfigureIdPIAM(t *testing.T) { } } +func TestConfigureIdPIAMOutput(t *testing.T) { + ctx := context.Background() + var buf bytes.Buffer + req := IdPIAMConfigureRequest{ + Cluster: "mycluster", + IntegrationName: "myintegration", + IntegrationRole: "integrationrole", + ProxyPublicAddress: "https://example.com", + AutoConfirm: true, + stdout: &buf, + fakeThumbprint: "15dbd260c7465ecca6de2c0b2181187f66ee0d1a", + } + + clt := mockIdPIAMConfigClient{ + CallerIdentityGetter: mockSTSClient{accountID: "123456789012"}, + existingRoles: map[string]mockRole{}, + existingIDPUrl: []string{}, + } + + require.NoError(t, ConfigureIdPIAM(ctx, &clt, req)) + if golden.ShouldSet() { + golden.Set(t, buf.Bytes()) + } + require.Equal(t, string(golden.Get(t)), buf.String()) +} + type mockRole struct { assumeRolePolicyDoc *string tags []iamtypes.Tag @@ -366,6 +397,25 @@ func (m *mockIdPIAMConfigClient) UpdateAssumeRolePolicy(ctx context.Context, par return &iam.UpdateAssumeRolePolicyOutput{}, nil } +func (m *mockIdPIAMConfigClient) TagRole(ctx context.Context, params *iam.TagRoleInput, _ ...func(*iam.Options)) (*iam.TagRoleOutput, error) { + roleName := aws.ToString(params.RoleName) + role, found := m.existingRoles[roleName] + if !found { + return nil, trace.NotFound("role not found") + } + + tags := tags.AWSTags{} + for _, existingTag := range role.tags { + tags[*existingTag.Key] = *existingTag.Value + } + for _, newTag := range params.Tags { + tags[*newTag.Key] = *newTag.Value + } + role.tags = tags.ToIAMTags() + m.existingRoles[roleName] = role + return &iam.TagRoleOutput{}, nil +} + func TestNewIdPIAMConfigureClient(t *testing.T) { t.Run("no aws_region env var, returns an error", func(t *testing.T) { _, err := NewIdPIAMConfigureClient(context.Background()) diff --git a/lib/integrations/awsoidc/tags/tags.go b/lib/integrations/awsoidc/tags/tags.go index 5a810f759a64b..fe4ab77991f28 100644 --- a/lib/integrations/awsoidc/tags/tags.go +++ b/lib/integrations/awsoidc/tags/tags.go @@ -19,8 +19,10 @@ package tags import ( + "cmp" "fmt" "maps" + "slices" "strings" athenatypes "github.com/aws/aws-sdk-go-v2/service/athena/types" @@ -67,6 +69,9 @@ func (d AWSTags) ToECSTags() []ecstypes.Tag { Value: &v, }) } + slices.SortFunc(ecsTags, func(a, b ecstypes.Tag) int { + return cmp.Compare(*a.Key, *b.Key) + }) return ecsTags } @@ -79,6 +84,9 @@ func (d AWSTags) ToEC2Tags() []ec2types.Tag { Value: &v, }) } + slices.SortFunc(ec2Tags, func(a, b ec2types.Tag) int { + return cmp.Compare(*a.Key, *b.Key) + }) return ec2Tags } @@ -125,6 +133,9 @@ func (d AWSTags) ToIAMTags() []iamtypes.Tag { Value: &v, }) } + slices.SortFunc(iamTags, func(a, b iamtypes.Tag) int { + return cmp.Compare(*a.Key, *b.Key) + }) return iamTags } @@ -137,6 +148,9 @@ func (d AWSTags) ToS3Tags() []s3types.Tag { Value: &v, }) } + slices.SortFunc(s3Tags, func(a, b s3types.Tag) int { + return cmp.Compare(*a.Key, *b.Key) + }) return s3Tags } @@ -149,6 +163,9 @@ func (d AWSTags) ToAthenaTags() []athenatypes.Tag { Value: &v, }) } + slices.SortFunc(athenaTags, func(a, b athenatypes.Tag) int { + return cmp.Compare(*a.Key, *b.Key) + }) return athenaTags } @@ -161,6 +178,9 @@ func (d AWSTags) ToSSMTags() []ssmtypes.Tag { Value: &v, }) } + slices.SortFunc(ssmTags, func(a, b ssmtypes.Tag) int { + return cmp.Compare(*a.Key, *b.Key) + }) return ssmTags } diff --git a/lib/integrations/awsoidc/testdata/TestConfigureIdPIAMOutput.golden b/lib/integrations/awsoidc/testdata/TestConfigureIdPIAMOutput.golden new file mode 100644 index 0000000000000..9cabf4b0d00d1 --- /dev/null +++ b/lib/integrations/awsoidc/testdata/TestConfigureIdPIAMOutput.golden @@ -0,0 +1,67 @@ +"awsoidc-idp" will perform the following actions: + +1. Create an OpenID Connect identity provider in AWS IAM for your Teleport cluster. +CreateOpenIDConnectProvider: { + "Url": "https://example.com", + "ClientIDList": [ + "discover.teleport" + ], + "Tags": [ + { + "Key": "teleport.dev/cluster", + "Value": "mycluster" + }, + { + "Key": "teleport.dev/integration", + "Value": "myintegration" + }, + { + "Key": "teleport.dev/origin", + "Value": "integration_awsoidc" + } + ], + "ThumbprintList": [ + "15dbd260c7465ecca6de2c0b2181187f66ee0d1a" + ] +} + +2. Create IAM role "integrationrole" with a custom trust policy. +CreateRole: { + "AssumeRolePolicyDocument": { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": "sts:AssumeRoleWithWebIdentity", + "Principal": { + "Federated": "arn:aws:iam::123456789012:oidc-provider/example.com" + }, + "Condition": { + "StringEquals": { + "example.com:aud": "discover.teleport" + } + } + } + ] + }, + "RoleName": "integrationrole", + "Description": "Used by Teleport to provide access to AWS resources.", + "MaxSessionDuration": null, + "Path": null, + "PermissionsBoundary": null, + "Tags": [ + { + "Key": "teleport.dev/cluster", + "Value": "mycluster" + }, + { + "Key": "teleport.dev/integration", + "Value": "myintegration" + }, + { + "Key": "teleport.dev/origin", + "Value": "integration_awsoidc" + } + ] +} + diff --git a/lib/srv/db/cloud/aws.go b/lib/srv/db/cloud/aws.go index 2b8a597029c76..5f69145230434 100644 --- a/lib/srv/db/cloud/aws.go +++ b/lib/srv/db/cloud/aws.go @@ -172,12 +172,12 @@ func (r *awsClient) ensureIAMPolicy(ctx context.Context) error { return trace.Wrap(err) } var changed bool - dbIAM.ForEach(func(effect, action, resource string) { - if policy.Ensure(effect, action, resource) { - r.log.Debugf("Permission %q for %q is already part of policy.", action, resource) - } else { + dbIAM.ForEach(func(effect, action, resource string, conditions awslib.Conditions) { + if policy.EnsureResourceAction(effect, action, resource, conditions) { r.log.Debugf("Adding permission %q for %q to policy.", action, resource) changed = true + } else { + r.log.Debugf("Permission %q for %q is already part of policy.", action, resource) } }) if !changed { @@ -206,11 +206,11 @@ func (r *awsClient) deleteIAMPolicy(ctx context.Context) error { if err != nil { return trace.Wrap(err) } - dbIAM.ForEach(func(effect, action, resource string) { - policy.Delete(effect, action, resource) + dbIAM.ForEach(func(effect, action, resource string, conditions awslib.Conditions) { + policy.DeleteResourceAction(effect, action, resource, conditions) }) // If policy is empty now, delete it as IAM policy can't be empty. - if len(policy.Statements) == 0 { + if policy.IsEmpty() { return r.detachIAMPolicy(ctx) } return r.updateIAMPolicy(ctx, policy) diff --git a/tool/teleport/common/integration_configure.go b/tool/teleport/common/integration_configure.go index dceb1bd2ca529..c542356226264 100644 --- a/tool/teleport/common/integration_configure.go +++ b/tool/teleport/common/integration_configure.go @@ -151,6 +151,7 @@ func onIntegrationConfAWSOIDCIdP(ctx context.Context, clf config.CommandLineFlag IntegrationName: clf.IntegrationConfAWSOIDCIdPArguments.Name, IntegrationRole: clf.IntegrationConfAWSOIDCIdPArguments.Role, ProxyPublicAddress: clf.IntegrationConfAWSOIDCIdPArguments.ProxyPublicURL, + AutoConfirm: clf.IntegrationConfAWSOIDCIdPArguments.AutoConfirm, } return trace.Wrap(awsoidc.ConfigureIdPIAM(ctx, iamClient, confReq)) } diff --git a/tool/teleport/common/teleport.go b/tool/teleport/common/teleport.go index 7047f1eb62759..45376e1ec8d41 100644 --- a/tool/teleport/common/teleport.go +++ b/tool/teleport/common/teleport.go @@ -515,6 +515,7 @@ func Run(options Options) (app *kingpin.Application, executedCommand string, con IntegrationConfAWSOIDCIdPArguments.Role) integrationConfAWSOIDCIdPCmd.Flag("proxy-public-url", "Proxy Public URL (eg https://mytenant.teleport.sh).").Required().StringVar(&ccf. IntegrationConfAWSOIDCIdPArguments.ProxyPublicURL) + integrationConfAWSOIDCIdPCmd.Flag("confirm", "Do not prompt user and auto-confirm all actions.").BoolVar(&ccf.IntegrationConfAWSOIDCIdPArguments.AutoConfirm) integrationConfAWSOIDCIdPCmd.Flag("insecure", "Insecure mode disables certificate validation.").BoolVar(&ccf.InsecureMode) integrationConfListDatabasesCmd := integrationConfigureCmd.Command("listdatabases-iam", "Adds required IAM permissions to List RDS Databases (Instances and Clusters).")