diff --git a/api/ssh.go b/api/ssh.go index fbaa8c5a0..9d0bbc14b 100644 --- a/api/ssh.go +++ b/api/ssh.go @@ -317,7 +317,7 @@ func SSHSign(w http.ResponseWriter, r *http.Request) { var identityCertificate []Certificate if cr := body.IdentityCSR.CertificateRequest; cr != nil { ctx := authority.NewContextWithSkipTokenReuse(r.Context()) - ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignMethod) + ctx = provisioner.NewContextWithMethod(ctx, provisioner.SignIdentityMethod) signOpts, err := a.Authorize(ctx, body.OTT) if err != nil { render.Error(w, errs.UnauthorizedErr(err)) diff --git a/authority/authorize.go b/authority/authorize.go index 1e35afe00..f14574a86 100644 --- a/authority/authorize.go +++ b/authority/authorize.go @@ -214,7 +214,7 @@ func (a *Authority) Authorize(ctx context.Context, token string) ([]provisioner. var opts = []interface{}{errs.WithKeyVal("token", token)} switch m := provisioner.MethodFromContext(ctx); m { - case provisioner.SignMethod: + case provisioner.SignMethod, provisioner.SignIdentityMethod: signOpts, err := a.authorizeSign(ctx, token) return signOpts, errs.Wrap(http.StatusInternalServerError, err, "authority.Authorize", opts...) case provisioner.RevokeMethod: diff --git a/authority/provisioner/aws.go b/authority/provisioner/aws.go index be6419738..e95feedd4 100644 --- a/authority/provisioner/aws.go +++ b/authority/provisioner/aws.go @@ -336,7 +336,7 @@ func (p *AWS) Init(config Config) (err error) { // AuthorizeSign validates the given token and returns the sign options that // will be used on certificate creation. -func (p *AWS) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { +func (p *AWS) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { payload, err := p.authorizeToken(token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "aws.AuthorizeSign") @@ -363,7 +363,7 @@ func (p *AWS) AuthorizeSign(_ context.Context, token string) ([]SignOption, erro net.ParseIP(doc.PrivateIP), }), emailAddressesValidator(nil), - urisValidator(nil), + newURIsValidator(ctx, nil), ) // Template options diff --git a/authority/provisioner/aws_test.go b/authority/provisioner/aws_test.go index 05f514565..02be1ba92 100644 --- a/authority/provisioner/aws_test.go +++ b/authority/provisioner/aws_test.go @@ -695,8 +695,9 @@ func TestAWS_AuthorizeSign(t *testing.T) { assert.Equals(t, []net.IP(v), []net.IP{net.ParseIP("127.0.0.1")}) case emailAddressesValidator: assert.Equals(t, v, nil) - case urisValidator: - assert.Equals(t, v, nil) + case *urisValidator: + assert.Equals(t, v.uris, nil) + assert.Equals(t, MethodFromContext(v.ctx), SignMethod) case dnsNamesValidator: assert.Equals(t, []string(v), []string{"ip-127-0-0-1.us-west-1.compute.internal"}) case *x509NamePolicyValidator: diff --git a/authority/provisioner/azure.go b/authority/provisioner/azure.go index 76bcebb66..a9d5d1fa3 100644 --- a/authority/provisioner/azure.go +++ b/authority/provisioner/azure.go @@ -316,7 +316,7 @@ func (p *Azure) authorizeToken(token string) (*azurePayload, string, string, str // AuthorizeSign validates the given token and returns the sign options that // will be used on certificate creation. -func (p *Azure) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { +func (p *Azure) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { _, name, group, subscription, identityObjectID, err := p.authorizeToken(token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "azure.AuthorizeSign") @@ -382,7 +382,7 @@ func (p *Azure) AuthorizeSign(_ context.Context, token string) ([]SignOption, er dnsNamesValidator([]string{name}), ipAddressesValidator(nil), emailAddressesValidator(nil), - urisValidator(nil), + newURIsValidator(ctx, nil), ) // Enforce SANs in the template. diff --git a/authority/provisioner/azure_test.go b/authority/provisioner/azure_test.go index 51d46c5ae..f262ffbca 100644 --- a/authority/provisioner/azure_test.go +++ b/authority/provisioner/azure_test.go @@ -560,8 +560,9 @@ func TestAzure_AuthorizeSign(t *testing.T) { assert.Equals(t, v, nil) case emailAddressesValidator: assert.Equals(t, v, nil) - case urisValidator: - assert.Equals(t, v, nil) + case *urisValidator: + assert.Equals(t, v.uris, nil) + assert.Equals(t, MethodFromContext(v.ctx), SignMethod) case dnsNamesValidator: assert.Equals(t, []string(v), []string{"virtualMachine"}) case *x509NamePolicyValidator: diff --git a/authority/provisioner/gcp.go b/authority/provisioner/gcp.go index b6274f8fa..2296b1b06 100644 --- a/authority/provisioner/gcp.go +++ b/authority/provisioner/gcp.go @@ -223,7 +223,7 @@ func (p *GCP) Init(config Config) (err error) { // AuthorizeSign validates the given token and returns the sign options that // will be used on certificate creation. -func (p *GCP) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { +func (p *GCP) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { claims, err := p.authorizeToken(token) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "gcp.AuthorizeSign") @@ -254,7 +254,7 @@ func (p *GCP) AuthorizeSign(_ context.Context, token string) ([]SignOption, erro }), ipAddressesValidator(nil), emailAddressesValidator(nil), - urisValidator(nil), + newURIsValidator(ctx, nil), ) // Template SANs diff --git a/authority/provisioner/gcp_test.go b/authority/provisioner/gcp_test.go index 7705b44a1..ef791614e 100644 --- a/authority/provisioner/gcp_test.go +++ b/authority/provisioner/gcp_test.go @@ -567,8 +567,9 @@ func TestGCP_AuthorizeSign(t *testing.T) { assert.Equals(t, v, nil) case emailAddressesValidator: assert.Equals(t, v, nil) - case urisValidator: - assert.Equals(t, v, nil) + case *urisValidator: + assert.Equals(t, v.uris, nil) + assert.Equals(t, MethodFromContext(v.ctx), SignMethod) case dnsNamesValidator: assert.Equals(t, []string(v), []string{"instance-name.c.project-id.internal", "instance-name.zone.c.project-id.internal"}) case *x509NamePolicyValidator: diff --git a/authority/provisioner/jwk.go b/authority/provisioner/jwk.go index 6c5ee6570..c98d78f22 100644 --- a/authority/provisioner/jwk.go +++ b/authority/provisioner/jwk.go @@ -150,7 +150,7 @@ func (p *JWK) AuthorizeRevoke(_ context.Context, token string) error { } // AuthorizeSign validates the given token. -func (p *JWK) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { +func (p *JWK) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "jwk.AuthorizeSign") @@ -192,7 +192,7 @@ func (p *JWK) AuthorizeSign(_ context.Context, token string) ([]SignOption, erro // validators commonNameValidator(claims.Subject), defaultPublicKeyValidator{}, - defaultSANsValidator(claims.SANs), + newDefaultSANsValidator(ctx, claims.SANs), newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), p.ctl.newWebhookController(data, linkedca.Webhook_X509), diff --git a/authority/provisioner/jwk_test.go b/authority/provisioner/jwk_test.go index 19cee4fb0..bffe11414 100644 --- a/authority/provisioner/jwk_test.go +++ b/authority/provisioner/jwk_test.go @@ -315,8 +315,9 @@ func TestJWK_AuthorizeSign(t *testing.T) { case *validityValidator: assert.Equals(t, v.min, tt.prov.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tt.prov.ctl.Claimer.MaxTLSCertDuration()) - case defaultSANsValidator: - assert.Equals(t, []string(v), tt.sans) + case *defaultSANsValidator: + assert.Equals(t, v.sans, tt.sans) + assert.Equals(t, MethodFromContext(v.ctx), SignMethod) case *x509NamePolicyValidator: assert.Equals(t, nil, v.policyEngine) case *WebhookController: diff --git a/authority/provisioner/method.go b/authority/provisioner/method.go index 01dda2ed2..19aa6224f 100644 --- a/authority/provisioner/method.go +++ b/authority/provisioner/method.go @@ -14,6 +14,8 @@ type methodKey struct{} const ( // SignMethod is the method used to sign X.509 certificates. SignMethod Method = iota + // SignIdentityMethod is the method used to sign X.509 identity certificates. + SignIdentityMethod // RevokeMethod is the method used to revoke X.509 certificates. RevokeMethod // RenewMethod is the method used to renew X.509 certificates. @@ -33,6 +35,8 @@ func (m Method) String() string { switch m { case SignMethod: return "sign-method" + case SignIdentityMethod: + return "sign-identity-method" case RevokeMethod: return "revoke-method" case RenewMethod: diff --git a/authority/provisioner/nebula.go b/authority/provisioner/nebula.go index 6c24bd008..66c523dc4 100644 --- a/authority/provisioner/nebula.go +++ b/authority/provisioner/nebula.go @@ -389,7 +389,7 @@ func (v nebulaSANsValidator) Valid(req *x509.CertificateRequest) error { } } if len(req.URIs) > 0 { - if err := urisValidator(uris).Valid(req); err != nil { + if err := newURIsValidator(context.Background(), uris).Valid(req); err != nil { return err } } diff --git a/authority/provisioner/sign_options.go b/authority/provisioner/sign_options.go index 782a3598a..fec9b9f6f 100644 --- a/authority/provisioner/sign_options.go +++ b/authority/provisioner/sign_options.go @@ -1,6 +1,7 @@ package provisioner import ( + "context" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" @@ -233,16 +234,28 @@ func (v emailAddressesValidator) Valid(req *x509.CertificateRequest) error { } // urisValidator validates the URI SANs of a certificate request. -type urisValidator []*url.URL +type urisValidator struct { + ctx context.Context + uris []*url.URL +} + +func newURIsValidator(ctx context.Context, uris []*url.URL) *urisValidator { + return &urisValidator{ctx, uris} +} // Valid checks that certificate request IP Addresses match those configured in // the bootstrap (token) flow. func (v urisValidator) Valid(req *x509.CertificateRequest) error { + // SignIdentityMethod does not need to validate URIs. + if MethodFromContext(v.ctx) == SignIdentityMethod { + return nil + } + if len(req.URIs) == 0 { return nil } want := make(map[string]bool) - for _, u := range v { + for _, u := range v.uris { want[u.String()] = true } got := make(map[string]bool) @@ -250,26 +263,33 @@ func (v urisValidator) Valid(req *x509.CertificateRequest) error { got[u.String()] = true } if !reflect.DeepEqual(want, got) { - return errs.Forbidden("certificate request does not contain the valid URIs - got %v, want %v", req.URIs, v) + return errs.Forbidden("certificate request does not contain the valid URIs - got %v, want %v", req.URIs, v.uris) } return nil } // defaultsSANsValidator stores a set of SANs to eventually validate 1:1 against // the SANs in an x509 certificate request. -type defaultSANsValidator []string +type defaultSANsValidator struct { + ctx context.Context + sans []string +} + +func newDefaultSANsValidator(ctx context.Context, sans []string) *defaultSANsValidator { + return &defaultSANsValidator{ctx, sans} +} // Valid verifies that the SANs stored in the validator match 1:1 with those // requested in the x509 certificate request. func (v defaultSANsValidator) Valid(req *x509.CertificateRequest) (err error) { - dnsNames, ips, emails, uris := x509util.SplitSANs(v) + dnsNames, ips, emails, uris := x509util.SplitSANs(v.sans) if err = dnsNamesValidator(dnsNames).Valid(req); err != nil { return } else if err = emailAddressesValidator(emails).Valid(req); err != nil { return } else if err = ipAddressesValidator(ips).Valid(req); err != nil { return - } else if err = urisValidator(uris).Valid(req); err != nil { + } else if err = newURIsValidator(v.ctx, uris).Valid(req); err != nil { return } return diff --git a/authority/provisioner/sign_options_test.go b/authority/provisioner/sign_options_test.go index e36d051f2..5a55aa86a 100644 --- a/authority/provisioner/sign_options_test.go +++ b/authority/provisioner/sign_options_test.go @@ -1,6 +1,7 @@ package provisioner import ( + "context" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" @@ -227,23 +228,26 @@ func Test_urisValidator_Valid(t *testing.T) { fu, err := url.Parse("https://unexpected.com") assert.FatalError(t, err) + signContext := NewContextWithMethod(context.Background(), SignMethod) + signIdentityContext := NewContextWithMethod(context.Background(), SignIdentityMethod) + type args struct { req *x509.CertificateRequest } tests := []struct { name string - v urisValidator + v *urisValidator args args wantErr bool }{ - {"ok0", []*url.URL{}, args{&x509.CertificateRequest{URIs: []*url.URL{}}}, false}, - {"ok1", []*url.URL{u1}, args{&x509.CertificateRequest{URIs: []*url.URL{u1}}}, false}, - {"ok2", []*url.URL{u1, u2}, args{&x509.CertificateRequest{URIs: []*url.URL{u2, u1}}}, false}, - {"ok3", []*url.URL{u2, u1, u3}, args{&x509.CertificateRequest{URIs: []*url.URL{u3, u2, u1}}}, false}, - {"ok3", []*url.URL{u2, u1, u3}, args{&x509.CertificateRequest{}}, false}, - {"fail1", []*url.URL{u1}, args{&x509.CertificateRequest{URIs: []*url.URL{u2}}}, true}, - {"fail2", []*url.URL{u1}, args{&x509.CertificateRequest{URIs: []*url.URL{u2, u1}}}, true}, - {"fail3", []*url.URL{u1, u2}, args{&x509.CertificateRequest{URIs: []*url.URL{u1, fu}}}, true}, + {"ok0", newURIsValidator(signContext, []*url.URL{}), args{&x509.CertificateRequest{URIs: []*url.URL{}}}, false}, + {"ok1", newURIsValidator(signContext, []*url.URL{u1}), args{&x509.CertificateRequest{URIs: []*url.URL{u1}}}, false}, + {"ok2", newURIsValidator(signContext, []*url.URL{u1, u2}), args{&x509.CertificateRequest{URIs: []*url.URL{u2, u1}}}, false}, + {"ok3", newURIsValidator(signContext, []*url.URL{u2, u1, u3}), args{&x509.CertificateRequest{URIs: []*url.URL{u3, u2, u1}}}, false}, + {"ok4", newURIsValidator(signIdentityContext, []*url.URL{u1, u2}), args{&x509.CertificateRequest{URIs: []*url.URL{u1, fu}}}, false}, + {"fail1", newURIsValidator(signContext, []*url.URL{u1}), args{&x509.CertificateRequest{URIs: []*url.URL{u2}}}, true}, + {"fail2", newURIsValidator(signContext, []*url.URL{u1}), args{&x509.CertificateRequest{URIs: []*url.URL{u2, u1}}}, true}, + {"fail3", newURIsValidator(signContext, []*url.URL{u1, u2}), args{&x509.CertificateRequest{URIs: []*url.URL{u1, fu}}}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -257,13 +261,19 @@ func Test_urisValidator_Valid(t *testing.T) { func Test_defaultSANsValidator_Valid(t *testing.T) { type test struct { csr *x509.CertificateRequest + ctx context.Context expectedSANs []string err error } + + signContext := NewContextWithMethod(context.Background(), SignMethod) + signIdentityContext := NewContextWithMethod(context.Background(), SignIdentityMethod) + tests := map[string]func() test{ "fail/dnsNamesValidator": func() test { return test{ csr: &x509.CertificateRequest{DNSNames: []string{"foo", "bar"}}, + ctx: signContext, expectedSANs: []string{"foo"}, err: errors.New("certificate request does not contain the valid DNS names"), } @@ -271,6 +281,7 @@ func Test_defaultSANsValidator_Valid(t *testing.T) { "fail/emailAddressesValidator": func() test { return test{ csr: &x509.CertificateRequest{EmailAddresses: []string{"max@fx.com", "mariano@fx.com"}}, + ctx: signContext, expectedSANs: []string{"dcow@fx.com"}, err: errors.New("certificate request does not contain the valid email addresses"), } @@ -278,6 +289,7 @@ func Test_defaultSANsValidator_Valid(t *testing.T) { "fail/ipAddressesValidator": func() test { return test{ csr: &x509.CertificateRequest{IPAddresses: []net.IP{net.ParseIP("1.1.1.1"), net.ParseIP("127.0.0.1")}}, + ctx: signContext, expectedSANs: []string{"127.0.0.1"}, err: errors.New("certificate request does not contain the valid IP addresses"), } @@ -289,16 +301,29 @@ func Test_defaultSANsValidator_Valid(t *testing.T) { assert.FatalError(t, err) return test{ csr: &x509.CertificateRequest{URIs: []*url.URL{u1, u2}}, + ctx: signContext, expectedSANs: []string{"urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959"}, err: errors.New("certificate request does not contain the valid URIs"), } }, + "ok/urisBadValidator-SignIdentity": func() test { + u1, err := url.Parse("https://google.com") + assert.FatalError(t, err) + u2, err := url.Parse("urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959") + assert.FatalError(t, err) + return test{ + csr: &x509.CertificateRequest{URIs: []*url.URL{u1, u2}}, + ctx: signIdentityContext, + expectedSANs: []string{"urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959"}, + } + }, "ok": func() test { u1, err := url.Parse("https://google.com") assert.FatalError(t, err) u2, err := url.Parse("urn:uuid:ddfe62ba-7e99-4bc1-83b3-8f57fe3e9959") assert.FatalError(t, err) return test{ + ctx: signContext, csr: &x509.CertificateRequest{ DNSNames: []string{"foo", "bar"}, EmailAddresses: []string{"max@fx.com", "mariano@fx.com"}, @@ -312,7 +337,7 @@ func Test_defaultSANsValidator_Valid(t *testing.T) { for name, run := range tests { t.Run(name, func(t *testing.T) { tt := run() - if err := defaultSANsValidator(tt.expectedSANs).Valid(tt.csr); err != nil { + if err := newDefaultSANsValidator(tt.ctx, tt.expectedSANs).Valid(tt.csr); err != nil { if assert.NotNil(t, tt.err, fmt.Sprintf("expected no error, but got err = %s", err.Error())) { assert.True(t, strings.Contains(err.Error(), tt.err.Error()), fmt.Sprintf("want err = %s, but got err = %s", tt.err.Error(), err.Error())) diff --git a/authority/provisioner/x5c.go b/authority/provisioner/x5c.go index b6e78697b..9b1f2b086 100644 --- a/authority/provisioner/x5c.go +++ b/authority/provisioner/x5c.go @@ -194,7 +194,7 @@ func (p *X5C) AuthorizeRevoke(_ context.Context, token string) error { } // AuthorizeSign validates the given token. -func (p *X5C) AuthorizeSign(_ context.Context, token string) ([]SignOption, error) { +func (p *X5C) AuthorizeSign(ctx context.Context, token string) ([]SignOption, error) { claims, err := p.authorizeToken(token, p.ctl.Audiences.Sign) if err != nil { return nil, errs.Wrap(http.StatusInternalServerError, err, "x5c.AuthorizeSign") @@ -244,7 +244,7 @@ func (p *X5C) AuthorizeSign(_ context.Context, token string) ([]SignOption, erro }, // validators commonNameValidator(claims.Subject), - defaultSANsValidator(claims.SANs), + newDefaultSANsValidator(ctx, claims.SANs), defaultPublicKeyValidator{}, newValidityValidator(p.ctl.Claimer.MinTLSCertDuration(), p.ctl.Claimer.MaxTLSCertDuration()), newX509NamePolicyValidator(p.ctl.getPolicy().getX509()), diff --git a/authority/provisioner/x5c_test.go b/authority/provisioner/x5c_test.go index f9a2604b3..22545446b 100644 --- a/authority/provisioner/x5c_test.go +++ b/authority/provisioner/x5c_test.go @@ -460,7 +460,8 @@ func TestX5C_AuthorizeSign(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { tc := tt(t) - if opts, err := tc.p.AuthorizeSign(context.Background(), tc.token); err != nil { + ctx := NewContextWithMethod(context.Background(), SignIdentityMethod) + if opts, err := tc.p.AuthorizeSign(ctx, tc.token); err != nil { if assert.NotNil(t, tc.err) { var sc render.StatusCodedError if assert.True(t, errors.As(err, &sc), "error does not implement StatusCodedError interface") { @@ -489,8 +490,9 @@ func TestX5C_AuthorizeSign(t *testing.T) { case commonNameValidator: assert.Equals(t, string(v), "foo") case defaultPublicKeyValidator: - case defaultSANsValidator: - assert.Equals(t, []string(v), tc.sans) + case *defaultSANsValidator: + assert.Equals(t, v.sans, tc.sans) + assert.Equals(t, MethodFromContext(v.ctx), SignIdentityMethod) case *validityValidator: assert.Equals(t, v.min, tc.p.ctl.Claimer.MinTLSCertDuration()) assert.Equals(t, v.max, tc.p.ctl.Claimer.MaxTLSCertDuration()) diff --git a/pki/helm.go b/pki/helm.go index 3c5cb5a9d..3de2c2ec0 100644 --- a/pki/helm.go +++ b/pki/helm.go @@ -1,6 +1,7 @@ package pki import ( + "fmt" "io" "text/template" @@ -49,21 +50,42 @@ func (p *PKI) WriteHelmTemplate(w io.Writer) error { // to what's in p.GenerateConfig(), but that codepath isn't taken when // writing the Helm template. The default JWK provisioner is added earlier in // the process and that's part of the provisioners above. + // + // To prevent name clashes for the default ACME provisioner, we append "-1" to + // the name if it already exists. See https://github.com/smallstep/cli/issues/1018 + // for the reason. + // // TODO(hs): consider refactoring the initialization, so that this becomes // easier to reason about and maintain. if p.options.enableACME { + acmeProvisionerName := "acme" + for _, prov := range provisioners { + if prov.GetName() == acmeProvisionerName { + acmeProvisionerName = fmt.Sprintf("%s-1", acmeProvisionerName) + break + } + } provisioners = append(provisioners, &provisioner.ACME{ Type: "ACME", - Name: "acme", + Name: acmeProvisionerName, }) } // Add default SSHPOP provisioner if enabled. Similar to the above, this is - // the same as what happens in p.GenerateConfig(). + // the same as what happens in p.GenerateConfig(). To prevent name clashes for the + // default SSHPOP provisioner, we append "-1" to it if it already exists. See + // https://github.com/smallstep/cli/issues/1018 for the reason. if p.options.enableSSH { + sshProvisionerName := "sshpop" + for _, prov := range provisioners { + if prov.GetName() == sshProvisionerName { + sshProvisionerName = fmt.Sprintf("%s-1", sshProvisionerName) + break + } + } provisioners = append(provisioners, &provisioner.SSHPOP{ Type: "SSHPOP", - Name: "sshpop", + Name: sshProvisionerName, Claims: &provisioner.Claims{ EnableSSHCA: &p.options.enableSSH, }, diff --git a/pki/helm_test.go b/pki/helm_test.go index 508f8c3e7..3aa0d2242 100644 --- a/pki/helm_test.go +++ b/pki/helm_test.go @@ -85,6 +85,13 @@ func TestPKI_WriteHelmTemplate(t *testing.T) { wantErr: false, } }, + "ok/with-acme-and-duplicate-provisioner-name": func(t *testing.T) test { + return test{ + pki: preparePKI(t, WithProvisioner("acme"), WithACME()), + testFile: "testdata/helm/with-acme-and-duplicate-provisioner-name.yml", + wantErr: false, + } + }, "ok/with-admin": func(t *testing.T) test { return test{ pki: preparePKI(t, WithAdmin()), @@ -99,6 +106,13 @@ func TestPKI_WriteHelmTemplate(t *testing.T) { wantErr: false, } }, + "ok/with-ssh-and-duplicate-provisioner-name": func(t *testing.T) test { + return test{ + pki: preparePKI(t, WithProvisioner("sshpop"), WithSSH()), + testFile: "testdata/helm/with-ssh-and-duplicate-provisioner-name.yml", + wantErr: false, + } + }, "ok/with-ssh-and-acme": func(t *testing.T) test { return test{ pki: preparePKI(t, WithSSH(), WithACME()), diff --git a/pki/pki.go b/pki/pki.go index 971c189b0..234bae5a2 100644 --- a/pki/pki.go +++ b/pki/pki.go @@ -319,7 +319,10 @@ type PKI struct { func New(o apiv1.Options, opts ...Option) (*PKI, error) { // TODO(hs): invoking `New` with a context active will use values from // that CA context while generating the context. Thay may or may not - // be fully expected and/or what we want. Check that. + // be fully expected and/or what we want. This specific behavior was + // changed after not relying on the `init` inside of `step`, resulting in + // the default context being active if `step.Init` isn't called explicitly. + // It can still result in surprising results, though. currentCtx := step.Contexts().GetCurrent() caService, err := cas.New(context.Background(), o) if err != nil { @@ -330,7 +333,7 @@ func New(o apiv1.Options, opts ...Option) (*PKI, error) { if o.IsCreator { creator, ok := caService.(apiv1.CertificateAuthorityCreator) if !ok { - return nil, errors.Errorf("cas type '%s' does not implements CertificateAuthorityCreator", o.Type) + return nil, errors.Errorf("cas type %q does not implement CertificateAuthorityCreator", o.Type) } caCreator = creator } @@ -850,9 +853,16 @@ func (p *PKI) GenerateConfig(opt ...ConfigOption) (*authconfig.Config, error) { // Add default ACME provisioner if enabled if p.options.enableACME { + // To prevent name clashes for the default ACME provisioner, we append "-1" to + // the name if it already exists. See https://github.com/smallstep/cli/issues/1018 + // for the reason. + acmeProvisionerName := "acme" + if p.options.provisioner == acmeProvisionerName { + acmeProvisionerName = fmt.Sprintf("%s-1", acmeProvisionerName) + } provisioners = append(provisioners, &provisioner.ACME{ Type: "ACME", - Name: "acme", + Name: acmeProvisionerName, }) } @@ -867,10 +877,16 @@ func (p *PKI) GenerateConfig(opt ...ConfigOption) (*authconfig.Config, error) { EnableSSHCA: &enableSSHCA, } - // Add default SSHPOP provisioner + // Add default SSHPOP provisioner. To prevent name clashes for the default + // SSHPOP provisioner, we append "-1" to the name if it already exists. + // See https://github.com/smallstep/cli/issues/1018 for the reason. + sshProvisionerName := "sshpop" + if p.options.provisioner == sshProvisionerName { + sshProvisionerName = fmt.Sprintf("%s-1", sshProvisionerName) + } provisioners = append(provisioners, &provisioner.SSHPOP{ Type: "SSHPOP", - Name: "sshpop", + Name: sshProvisionerName, Claims: &provisioner.Claims{ EnableSSHCA: &enableSSHCA, }, @@ -910,10 +926,13 @@ func (p *PKI) GenerateConfig(opt ...ConfigOption) (*authconfig.Config, error) { if err != nil { return nil, err } + defer _db.Shutdown() // free DB resources; unlock BadgerDB file + adminDB, err := admindb.New(_db.(nosql.DB), admin.DefaultAuthorityID) if err != nil { return nil, err } + // Add all the provisioners to the db. var adminID string for i, p := range provisioners { diff --git a/pki/pki_test.go b/pki/pki_test.go new file mode 100644 index 000000000..7d5bc8c50 --- /dev/null +++ b/pki/pki_test.go @@ -0,0 +1,313 @@ +package pki + +import ( + "context" + "path/filepath" + "testing" + + "github.com/smallstep/certificates/authority/admin" + admindb "github.com/smallstep/certificates/authority/admin/db/nosql" + authconfig "github.com/smallstep/certificates/authority/config" + "github.com/smallstep/certificates/authority/provisioner" + "github.com/smallstep/certificates/cas/apiv1" + "github.com/smallstep/certificates/db" + "github.com/smallstep/nosql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.step.sm/cli-utils/step" +) + +func withDBDataSource(t *testing.T, dataSource string) func(c *authconfig.Config) error { + return func(c *authconfig.Config) error { + if c == nil || c.DB == nil { + require.Fail(t, "withDBDataSource prerequisites not met") + } + c.DB.DataSource = dataSource + return nil + } +} + +func TestPKI_GenerateConfig(t *testing.T) { + var preparePKI = func(t *testing.T, opts ...Option) *PKI { + o := apiv1.Options{ + Type: "softcas", + IsCreator: true, + } + + // TODO(hs): invoking `New` doesn't perform all operations that are executed + // when `ca init` is executed. Ideally this logic should be handled in one + // place and probably inside of the PKI initialization. For testing purposes + // the missing operations are faked by `setKeyPair`. + p, err := New(o, opts...) + require.NoError(t, err) + + // setKeyPair sets a predefined JWK and a default JWK provisioner. This is one + // of the things performed in the `ca init` code that's not part of `New`, but + // performed after that in p.GenerateKeyPairs`. We're currently using the same + // JWK for every test to keep test variance small: we're not testing JWK generation + // here after all. It's a bit dangerous to redefine the function here, but it's + // the simplest way to make this fully testable without refactoring the init now. + // The password for the predefined encrypted key is \x01\x03\x03\x07. + setKeyPair(t, p) + + return p + } + type args struct { + opt []ConfigOption + } + type test struct { + pki *PKI + args args + want *authconfig.Config + wantErr bool + } + var tests = map[string]func(t *testing.T) test{ + "ok/simple": func(t *testing.T) test { + pki := preparePKI(t) + pki.options.deploymentType = StandaloneDeployment + pki.options.provisioner = "default-prov" + return test{ + pki: pki, + args: args{ + []ConfigOption{}, + }, + want: &authconfig.Config{ + Address: "127.0.0.1:9000", + InsecureAddress: "", + DNSNames: []string{"127.0.0.1"}, + AuthorityConfig: &authconfig.AuthConfig{ + DeploymentType: "", // TODO(hs): (why is) this is not set to standalone? + EnableAdmin: false, + Provisioners: provisioner.List{ + &provisioner.JWK{ + Type: "JWK", + Name: "default-prov", + }, + }, + }, + DB: &db.Config{ + Type: "badgerv2", + DataSource: filepath.Join(step.Path(), "db"), + }, + }, + wantErr: false, + } + }, + "ok/with-acme": func(t *testing.T) test { + pki := preparePKI(t) + pki.options.deploymentType = StandaloneDeployment + pki.options.provisioner = "default-prov" + pki.options.enableACME = true + return test{ + pki: pki, + args: args{ + []ConfigOption{}, + }, + want: &authconfig.Config{ + Address: "127.0.0.1:9000", + InsecureAddress: "", + DNSNames: []string{"127.0.0.1"}, + AuthorityConfig: &authconfig.AuthConfig{ + DeploymentType: "", // TODO(hs): (why is) this is not set to standalone? + EnableAdmin: false, + Provisioners: provisioner.List{ + &provisioner.JWK{ + Type: "JWK", + Name: "default-prov", + }, + &provisioner.ACME{ + Type: "ACME", + Name: "acme", + }, + }, + }, + DB: &db.Config{ + Type: "badgerv2", + DataSource: filepath.Join(step.Path(), "db"), + }, + }, + wantErr: false, + } + }, + "ok/with-acme-and-double-provisioner-name": func(t *testing.T) test { + pki := preparePKI(t) + pki.options.deploymentType = StandaloneDeployment + pki.options.provisioner = "acme" + pki.options.enableACME = true + return test{ + pki: pki, + args: args{ + []ConfigOption{}, + }, + want: &authconfig.Config{ + Address: "127.0.0.1:9000", + InsecureAddress: "", + DNSNames: []string{"127.0.0.1"}, + AuthorityConfig: &authconfig.AuthConfig{ + DeploymentType: "", // TODO(hs): (why is) this is not set to standalone? + EnableAdmin: false, + Provisioners: provisioner.List{ + &provisioner.JWK{ + Type: "JWK", + Name: "acme", + }, + &provisioner.ACME{ + Type: "ACME", + Name: "acme-1", + }, + }, + }, + DB: &db.Config{ + Type: "badgerv2", + DataSource: filepath.Join(step.Path(), "db"), + }, + }, + wantErr: false, + } + }, + "ok/with-ssh": func(t *testing.T) test { + pki := preparePKI(t) + pki.options.deploymentType = StandaloneDeployment + pki.options.provisioner = "default-prov" + pki.options.enableSSH = true + return test{ + pki: pki, + args: args{ + []ConfigOption{}, + }, + want: &authconfig.Config{ + Address: "127.0.0.1:9000", + InsecureAddress: "", + DNSNames: []string{"127.0.0.1"}, + AuthorityConfig: &authconfig.AuthConfig{ + DeploymentType: "", // TODO(hs): (why is) this is not set to standalone? + EnableAdmin: false, + Provisioners: provisioner.List{ + &provisioner.JWK{ + Type: "JWK", + Name: "default-prov", + }, + &provisioner.SSHPOP{ + Type: "SSHPOP", + Name: "sshpop", + }, + }, + }, + DB: &db.Config{ + Type: "badgerv2", + DataSource: filepath.Join(step.Path(), "db"), + }, + }, + wantErr: false, + } + }, + "ok/with-ssh-and-double-provisioner-name": func(t *testing.T) test { + pki := preparePKI(t) + pki.options.deploymentType = StandaloneDeployment + pki.options.provisioner = "sshpop" + pki.options.enableSSH = true + return test{ + pki: pki, + args: args{ + []ConfigOption{}, + }, + want: &authconfig.Config{ + Address: "127.0.0.1:9000", + InsecureAddress: "", + DNSNames: []string{"127.0.0.1"}, + AuthorityConfig: &authconfig.AuthConfig{ + DeploymentType: "", // TODO(hs): (why is) this is not set to standalone? + EnableAdmin: false, + Provisioners: provisioner.List{ + &provisioner.JWK{ + Type: "JWK", + Name: "sshpop", + }, + &provisioner.SSHPOP{ + Type: "SSHPOP", + Name: "sshpop-1", + }, + }, + }, + DB: &db.Config{ + Type: "badgerv2", + DataSource: filepath.Join(step.Path(), "db"), + }, + }, + wantErr: false, + } + }, + "ok/with-admin": func(t *testing.T) test { + pki := preparePKI(t) + pki.options.deploymentType = StandaloneDeployment + pki.options.provisioner = "default-prov" + pki.options.enableAdmin = true + tempDir := t.TempDir() + return test{ + pki: pki, + args: args{ + []ConfigOption{withDBDataSource(t, filepath.Join(tempDir, "db"))}, + }, + want: &authconfig.Config{ + Address: "127.0.0.1:9000", + InsecureAddress: "", + DNSNames: []string{"127.0.0.1"}, + AuthorityConfig: &authconfig.AuthConfig{ + DeploymentType: "", // TODO(hs): (why is) this is not set to standalone? + EnableAdmin: true, + Provisioners: provisioner.List{}, // when admin is enabled, provisioner list is expected to be empty + }, + DB: &db.Config{ + Type: "badgerv2", + DataSource: filepath.Join(tempDir, "db"), + }, + }, + wantErr: false, + } + }, + } + for name, run := range tests { + tc := run(t) + t.Run(name, func(t *testing.T) { + got, err := tc.pki.GenerateConfig(tc.args.opt...) + if tc.wantErr { + assert.NotNil(t, err) + assert.Nil(t, got) + return + } + + assert.Nil(t, err) + if assert.NotNil(t, got) { + assert.Equal(t, tc.want.Address, got.Address) + assert.Equal(t, tc.want.InsecureAddress, got.InsecureAddress) + assert.Equal(t, tc.want.DNSNames, got.DNSNames) + assert.Equal(t, tc.want.DB, got.DB) + if assert.NotNil(t, tc.want.AuthorityConfig) { + assert.Equal(t, tc.want.AuthorityConfig.DeploymentType, got.AuthorityConfig.DeploymentType) + assert.Equal(t, tc.want.AuthorityConfig.EnableAdmin, got.AuthorityConfig.EnableAdmin) + if numberOfProvisioners := len(tc.want.AuthorityConfig.Provisioners); numberOfProvisioners > 0 { + if assert.Len(t, got.AuthorityConfig.Provisioners, numberOfProvisioners) { + for i, p := range tc.want.AuthorityConfig.Provisioners { + assert.Equal(t, p.GetType(), got.AuthorityConfig.Provisioners[i].GetType()) + assert.Equal(t, p.GetName(), got.AuthorityConfig.Provisioners[i].GetName()) + } + } + } + if tc.want.AuthorityConfig.EnableAdmin { + _db, err := db.New(tc.want.DB) + require.NoError(t, err) + defer _db.Shutdown() + + adminDB, err := admindb.New(_db.(nosql.DB), admin.DefaultAuthorityID) + require.NoError(t, err) + + provs, err := adminDB.GetProvisioners(context.Background()) + require.NoError(t, err) + + assert.NotEmpty(t, provs) // currently about the best we can do in terms of checks + } + } + } + }) + } +} diff --git a/pki/testdata/helm/with-acme-and-duplicate-provisioner-name.yml b/pki/testdata/helm/with-acme-and-duplicate-provisioner-name.yml new file mode 100644 index 000000000..f45322076 --- /dev/null +++ b/pki/testdata/helm/with-acme-and-duplicate-provisioner-name.yml @@ -0,0 +1,82 @@ +# Helm template +inject: + enabled: true + # Config contains the configuration files ca.json and defaults.json + config: + files: + ca.json: + root: /home/step/certs/root_ca.crt + federateRoots: [] + crt: /home/step/certs/intermediate_ca.crt + key: /home/step/secrets/intermediate_ca_key + address: 127.0.0.1:9000 + dnsNames: + - 127.0.0.1 + logger: + format: json + db: + type: badgerv2 + dataSource: /home/step/db + authority: + enableAdmin: false + provisioners: + - {"type":"JWK","name":"acme","key":{"use":"sig","kty":"EC","kid":"zsUmysmDVoGJ71YoPHyZ-68tNihDaDaO5Mu7xX3M-_I","crv":"P-256","alg":"ES256","x":"Pqnua4CzqKz6ua41J3yeWZ1sRkGt0UlCkbHv8H2DGuY","y":"UhoZ_2ItDen9KQTcjay-ph-SBXH0mwqhHyvrrqIFDOI"},"encryptedKey":"eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiZjVvdGVRS2hvOXl4MmQtSGlMZi05QSJ9.eYA6tt3fNuUpoxKWDT7P0Lbn2juxhEbTxEnwEMbjlYLLQ3sxL-dYTA.ven-FhmdjlC9itH0.a2jRTarN9vPd6F_mWnNBlOn6KbfMjCApmci2t65XbAsLzYFzhI_79Ykm5ueMYTupWLTjBJctl-g51ZHmsSB55pStbpoyyLNAsUX2E1fTmHe-Ni8bRrspwLv15FoN1Xo1g0mpR-ufWIFxOsW-QIfnMmMIIkygVuHFXmg2tFpzTNNG5aS29K3dN2nyk0WJrdIq79hZSTqVkkBU25Yu3A46sgjcM86XcIJJ2XUEih_KWEa6T1YrkixGu96pebjVqbO0R6dbDckfPF7FqNnwPHVtb1ACFpEYoOJVIbUCMaARBpWsxYhjJZlEM__XA46l8snFQDkNY3CdN0p1_gF3ckA.JLmq9nmu1h9oUi1S8ZxYjA","options":{"x509":{},"ssh":{}}} + - {"type":"ACME","name":"acme-1"} + tls: + cipherSuites: + - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 + - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 + minVersion: 1.2 + maxVersion: 1.3 + renegotiation: false + + defaults.json: + ca-url: https://127.0.0.1 + ca-config: /home/step/config/ca.json + fingerprint: e543cad8e9f6417076bb5aed3471c588152118aac1e0ca7984a43ee7f76da5e3 + root: /home/step/certs/root_ca.crt + + # Certificates contains the root and intermediate certificate and + # optionally the SSH host and user public keys + certificates: + # intermediate_ca contains the text of the intermediate CA Certificate + intermediate_ca: | + -----BEGIN CERTIFICATE----- + dGhlc2UgYXJlIGp1c3Qgc29tZSBmYWtlIGludGVybWVkaWF0ZSBDQSBjZXJ0IGJ5 + dGVz + -----END CERTIFICATE----- + + + # root_ca contains the text of the root CA Certificate + root_ca: | + -----BEGIN CERTIFICATE----- + dGhlc2UgYXJlIGp1c3Qgc29tZSBmYWtlIHJvb3QgQ0EgY2VydCBieXRlcw== + -----END CERTIFICATE----- + + + # Secrets contains the root and intermediate keys and optionally the SSH + # private keys + secrets: + # ca_password contains the password used to encrypt x509.intermediate_ca_key, ssh.host_ca_key and ssh.user_ca_key + # This value must be base64 encoded. + ca_password: + provisioner_password: + + x509: + # intermediate_ca_key contains the contents of your encrypted intermediate CA key + intermediate_ca_key: | + -----BEGIN EC PRIVATE KEY----- + dGhlc2UgYXJlIGp1c3Qgc29tZSBmYWtlIGludGVybWVkaWF0ZSBDQSBrZXkgYnl0 + ZXM= + -----END EC PRIVATE KEY----- + + + # root_ca_key contains the contents of your encrypted root CA key + # Note that this value can be omitted without impacting the functionality of step-certificates + # If supplied, this should be encrypted using a unique password that is not used for encrypting + # the intermediate_ca_key, ssh.host_ca_key or ssh.user_ca_key. + root_ca_key: | + -----BEGIN EC PRIVATE KEY----- + dGhlc2UgYXJlIGp1c3Qgc29tZSBmYWtlIHJvb3QgQ0Ega2V5IGJ5dGVz + -----END EC PRIVATE KEY----- + diff --git a/pki/testdata/helm/with-ssh-and-duplicate-provisioner-name.yml b/pki/testdata/helm/with-ssh-and-duplicate-provisioner-name.yml new file mode 100644 index 000000000..a499b1553 --- /dev/null +++ b/pki/testdata/helm/with-ssh-and-duplicate-provisioner-name.yml @@ -0,0 +1,104 @@ +# Helm template +inject: + enabled: true + # Config contains the configuration files ca.json and defaults.json + config: + files: + ca.json: + root: /home/step/certs/root_ca.crt + federateRoots: [] + crt: /home/step/certs/intermediate_ca.crt + key: /home/step/secrets/intermediate_ca_key + ssh: + hostKey: /home/step/secrets/ssh_host_ca_key + userKey: /home/step/secrets/ssh_user_ca_key + address: 127.0.0.1:9000 + dnsNames: + - 127.0.0.1 + logger: + format: json + db: + type: badgerv2 + dataSource: /home/step/db + authority: + enableAdmin: false + provisioners: + - {"type":"JWK","name":"sshpop","key":{"use":"sig","kty":"EC","kid":"zsUmysmDVoGJ71YoPHyZ-68tNihDaDaO5Mu7xX3M-_I","crv":"P-256","alg":"ES256","x":"Pqnua4CzqKz6ua41J3yeWZ1sRkGt0UlCkbHv8H2DGuY","y":"UhoZ_2ItDen9KQTcjay-ph-SBXH0mwqhHyvrrqIFDOI"},"encryptedKey":"eyJhbGciOiJQQkVTMi1IUzI1NitBMTI4S1ciLCJjdHkiOiJqd2sranNvbiIsImVuYyI6IkEyNTZHQ00iLCJwMmMiOjEwMDAwMCwicDJzIjoiZjVvdGVRS2hvOXl4MmQtSGlMZi05QSJ9.eYA6tt3fNuUpoxKWDT7P0Lbn2juxhEbTxEnwEMbjlYLLQ3sxL-dYTA.ven-FhmdjlC9itH0.a2jRTarN9vPd6F_mWnNBlOn6KbfMjCApmci2t65XbAsLzYFzhI_79Ykm5ueMYTupWLTjBJctl-g51ZHmsSB55pStbpoyyLNAsUX2E1fTmHe-Ni8bRrspwLv15FoN1Xo1g0mpR-ufWIFxOsW-QIfnMmMIIkygVuHFXmg2tFpzTNNG5aS29K3dN2nyk0WJrdIq79hZSTqVkkBU25Yu3A46sgjcM86XcIJJ2XUEih_KWEa6T1YrkixGu96pebjVqbO0R6dbDckfPF7FqNnwPHVtb1ACFpEYoOJVIbUCMaARBpWsxYhjJZlEM__XA46l8snFQDkNY3CdN0p1_gF3ckA.JLmq9nmu1h9oUi1S8ZxYjA","claims":{"enableSSHCA":true,"disableRenewal":false,"allowRenewalAfterExpiry":false,"disableSmallstepExtensions":false},"options":{"x509":{},"ssh":{}}} + - {"type":"SSHPOP","name":"sshpop-1","claims":{"enableSSHCA":true}} + tls: + cipherSuites: + - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 + - TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 + minVersion: 1.2 + maxVersion: 1.3 + renegotiation: false + + defaults.json: + ca-url: https://127.0.0.1 + ca-config: /home/step/config/ca.json + fingerprint: e543cad8e9f6417076bb5aed3471c588152118aac1e0ca7984a43ee7f76da5e3 + root: /home/step/certs/root_ca.crt + + # Certificates contains the root and intermediate certificate and + # optionally the SSH host and user public keys + certificates: + # intermediate_ca contains the text of the intermediate CA Certificate + intermediate_ca: | + -----BEGIN CERTIFICATE----- + dGhlc2UgYXJlIGp1c3Qgc29tZSBmYWtlIGludGVybWVkaWF0ZSBDQSBjZXJ0IGJ5 + dGVz + -----END CERTIFICATE----- + + + # root_ca contains the text of the root CA Certificate + root_ca: | + -----BEGIN CERTIFICATE----- + dGhlc2UgYXJlIGp1c3Qgc29tZSBmYWtlIHJvb3QgQ0EgY2VydCBieXRlcw== + -----END CERTIFICATE----- + + # ssh_host_ca contains the text of the public ssh key for the SSH root CA + ssh_host_ca: ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBJ0IdS5sZm6KITBMZLEJD6b5ROVraYHcAOr3feFel8r1Wp4DRPR1oU0W00J/zjNBRBbANlJoYN4x/8WNNVZ49Ms= + + # ssh_user_ca contains the text of the public ssh key for the SSH root CA + ssh_user_ca: ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBEWA1qUxaGwVNErsvEOGe2d6TvLMF+aiVpuOiIEvpMJ3JeJmecLQctjWqeIbpSvy6/gRa7c82Ge5rLlapYmOChs= + + # Secrets contains the root and intermediate keys and optionally the SSH + # private keys + secrets: + # ca_password contains the password used to encrypt x509.intermediate_ca_key, ssh.host_ca_key and ssh.user_ca_key + # This value must be base64 encoded. + ca_password: + provisioner_password: + + x509: + # intermediate_ca_key contains the contents of your encrypted intermediate CA key + intermediate_ca_key: | + -----BEGIN EC PRIVATE KEY----- + dGhlc2UgYXJlIGp1c3Qgc29tZSBmYWtlIGludGVybWVkaWF0ZSBDQSBrZXkgYnl0 + ZXM= + -----END EC PRIVATE KEY----- + + + # root_ca_key contains the contents of your encrypted root CA key + # Note that this value can be omitted without impacting the functionality of step-certificates + # If supplied, this should be encrypted using a unique password that is not used for encrypting + # the intermediate_ca_key, ssh.host_ca_key or ssh.user_ca_key. + root_ca_key: | + -----BEGIN EC PRIVATE KEY----- + dGhlc2UgYXJlIGp1c3Qgc29tZSBmYWtlIHJvb3QgQ0Ega2V5IGJ5dGVz + -----END EC PRIVATE KEY----- + + ssh: + # ssh_host_ca_key contains the contents of your encrypted SSH Host CA key + host_ca_key: | + -----BEGIN EC PRIVATE KEY----- + ZmFrZSBzc2ggaG9zdCBrZXkgYnl0ZXM= + -----END EC PRIVATE KEY----- + + + # ssh_user_ca_key contains the contents of your encrypted SSH User CA key + user_ca_key: | + -----BEGIN EC PRIVATE KEY----- + ZmFrZSBzc2ggdXNlciBrZXkgYnl0ZXM= + -----END EC PRIVATE KEY----- + diff --git a/scep/api/api.go b/scep/api/api.go index 614b51841..cb0a73b1e 100644 --- a/scep/api/api.go +++ b/scep/api/api.go @@ -151,11 +151,14 @@ func decodeRequest(r *http.Request) (request, error) { defer r.Body.Close() method := r.Method - query := r.URL.Query() + query, err := url.ParseQuery(r.URL.RawQuery) + if err != nil { + return request{}, fmt.Errorf("failed parsing URL query: %w", err) + } - var operation string - if _, ok := query["operation"]; ok { - operation = query.Get("operation") + operation := query.Get("operation") + if operation == "" { + return request{}, errors.New("no operation provided") } switch method { @@ -167,14 +170,10 @@ func decodeRequest(r *http.Request) (request, error) { Message: []byte{}, }, nil case opnPKIOperation: - var message string - if _, ok := query["message"]; ok { - message = query.Get("message") - } - // TODO: verify this; right type of encoding? Needs additional transformations? - decodedMessage, err := base64.StdEncoding.DecodeString(message) + message := query.Get("message") + decodedMessage, err := decodeMessage(message, r) if err != nil { - return request{}, err + return request{}, fmt.Errorf("failed decoding message: %w", err) } return request{ Operation: operation, @@ -186,7 +185,7 @@ func decodeRequest(r *http.Request) (request, error) { case http.MethodPost: body, err := io.ReadAll(io.LimitReader(r.Body, maxPayloadSize)) if err != nil { - return request{}, err + return request{}, fmt.Errorf("failed reading request body: %w", err) } return request{ Operation: operation, @@ -197,6 +196,77 @@ func decodeRequest(r *http.Request) (request, error) { } } +func decodeMessage(message string, r *http.Request) ([]byte, error) { + if message == "" { + return nil, errors.New("message must not be empty") + } + + // decode the message, which should be base64 standard encoded. Any characters that + // were escaped in the original query, were unescaped as part of url.ParseQuery, so + // that doesn't need to be performed here. Return early if successful. + decodedMessage, err := base64.StdEncoding.DecodeString(message) + if err == nil { + return decodedMessage, nil + } + + // only interested in corrupt input errors below this. This type of error is the + // most likely to return, but better safe than sorry. + var cie base64.CorruptInputError + if !errors.As(err, &cie) { + return nil, fmt.Errorf("failed base64 decoding message: %w", err) + } + + // the below code is a workaround for macOS when it sends a GET PKIOperation, which seems to result + // in a query with the '+' and '/' not being percent encoded; only the padding ('=') is encoded. + // When that is unescaped in the code before this, this results in invalid base64. The workaround + // is to obtain the original query, extract the message, apply transformation(s) to make it valid + // base64 and try decoding it again. If it succeeds, the happy path can be followed with the patched + // message. Otherwise we still return an error. + rawQuery, err := parseRawQuery(r.URL.RawQuery) + if err != nil { + return nil, fmt.Errorf("failed to parse raw query: %w", err) + } + + rawMessage := rawQuery.Get("message") + if rawMessage == "" { + return nil, errors.New("no message in raw query") + } + + rawMessage = strings.ReplaceAll(rawMessage, "%3D", "=") // apparently the padding arrives encoded; the others (+, /) not? + decodedMessage, err = base64.StdEncoding.DecodeString(rawMessage) + if err != nil { + return nil, fmt.Errorf("failed base64 decoding raw message: %w", err) + } + + return decodedMessage, nil +} + +// parseRawQuery parses a URL query into url.Values. It skips +// unescaping keys and values. This code is based on url.ParseQuery. +func parseRawQuery(query string) (url.Values, error) { + m := make(url.Values) + err := parseRawQueryWithoutUnescaping(m, query) + return m, err +} + +// parseRawQueryWithoutUnescaping parses the raw query into url.Values, skipping +// unescaping of the parts. This code is based on url.parseQuery. +func parseRawQueryWithoutUnescaping(m url.Values, query string) (err error) { + for query != "" { + var key string + key, query, _ = strings.Cut(query, "&") + if strings.Contains(key, ";") { + return errors.New("invalid semicolon separator in query") + } + if key == "" { + continue + } + key, value, _ := strings.Cut(key, "=") + m[key] = append(m[key], value) + } + return err +} + // lookupProvisioner loads the provisioner associated with the request. // Responds 404 if the provisioner does not exist. func lookupProvisioner(next http.HandlerFunc) http.HandlerFunc { diff --git a/scep/api/api_test.go b/scep/api/api_test.go index ef3e57ab8..a1782933e 100644 --- a/scep/api/api_test.go +++ b/scep/api/api_test.go @@ -3,15 +3,27 @@ package api import ( "bytes" + "encoding/base64" "errors" + "fmt" "net/http" "net/http/httptest" - "reflect" + "net/url" + "strings" "testing" "testing/iotest" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_decodeRequest(t *testing.T) { + randomB64 := "wx/1mQ49TpdLRfvVjQhXNSe8RB3hjZEarqYp5XVIxpSbvOhQSs8hP2TgucID1IputbA8JC6CbsUpcVae3+8hRNqs5pTsSHP2aNxsw8AHGSX9dZVymSclkUV8irk+ztfEfs7aLA==" + expectedRandom, err := base64.StdEncoding.DecodeString(randomB64) + require.NoError(t, err) + weirdMacOSCase := "wx/1mQ49TpdLRfvVjQhXNSe8RB3hjZEarqYp5XVIxpSbvOhQSs8hP2TgucID1IputbA8JC6CbsUpcVae3+8hRNqs5pTsSHP2aNxsw8AHGSX9dZVymSclkUV8irk+ztfEfs7aLA%3D%3D" + expectedWeirdMacOSCase, err := base64.StdEncoding.DecodeString(strings.ReplaceAll(weirdMacOSCase, "%3D", "=")) + require.NoError(t, err) type args struct { r *http.Request } @@ -21,6 +33,22 @@ func Test_decodeRequest(t *testing.T) { want request wantErr bool }{ + { + name: "fail/invalid-query", + args: args{ + r: httptest.NewRequest(http.MethodGet, "http://scep:8080/?operation=bla;message=invalid-separator", http.NoBody), + }, + want: request{}, + wantErr: true, + }, + { + name: "fail/empty-operation", + args: args{ + r: httptest.NewRequest(http.MethodGet, "http://scep:8080/?operation=", http.NoBody), + }, + want: request{}, + wantErr: true, + }, { name: "fail/unsupported-method", args: args{ @@ -37,6 +65,14 @@ func Test_decodeRequest(t *testing.T) { want: request{}, wantErr: true, }, + { + name: "fail/get-PKIOperation-empty-message", + args: args{ + r: httptest.NewRequest(http.MethodGet, "http://scep:8080/?operation=PKIOperation&message=", http.NoBody), + }, + want: request{}, + wantErr: true, + }, { name: "fail/get-PKIOperation", args: args{ @@ -86,6 +122,39 @@ func Test_decodeRequest(t *testing.T) { }, wantErr: false, }, + { + name: "ok/get-PKIOperation-escaped", + args: args{ + r: httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://scep:8080/?operation=PKIOperation&message=%s", url.QueryEscape(randomB64)), http.NoBody), + }, + want: request{ + Operation: "PKIOperation", + Message: expectedRandom, + }, + wantErr: false, + }, + { + name: "ok/get-PKIOperation-not-escaped", // bit of a special case, but this is supported because of the macOS case now + args: args{ + r: httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://scep:8080/?operation=PKIOperation&message=%s", randomB64), http.NoBody), + }, + want: request{ + Operation: "PKIOperation", + Message: expectedRandom, + }, + wantErr: false, + }, + { + name: "ok/get-PKIOperation-weird-macos-case", // a special case for macOS, which seems to result in the message not arriving fully percent-encoded + args: args{ + r: httptest.NewRequest(http.MethodGet, fmt.Sprintf("http://scep:8080/?operation=PKIOperation&message=%s", weirdMacOSCase), http.NoBody), + }, + want: request{ + Operation: "PKIOperation", + Message: expectedWeirdMacOSCase, + }, + wantErr: false, + }, { name: "ok/post-PKIOperation", args: args{ @@ -101,13 +170,14 @@ func Test_decodeRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := decodeRequest(tt.args.r) - if (err != nil) != tt.wantErr { - t.Errorf("decodeRequest() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) + assert.Equal(t, tt.want, got) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("decodeRequest() = %v, want %v", got, tt.want) - } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) }) } }