diff --git a/chains.go b/chains.go index 89f7505..8739694 100644 --- a/chains.go +++ b/chains.go @@ -24,17 +24,25 @@ import ( // AssembleServerChains takes in an array of certificates, finds all certificates with // x509.ExtKeyUsageAny or x509.ExtKeyUsageServerAuth and builds an array of leaf-first -// chains. -func AssembleServerChains(certs []*x509.Certificate) ([][]*x509.Certificate, error) { +// chains. Chains are built starting from server authentication certificates found in `certs` +// and the signer chains are built from `certs` and `cas`. Both slices are de-duped +// and the `cas` slice is filtered for certificates with the CA flag set. +func AssembleServerChains(certs []*x509.Certificate, cas []*x509.Certificate) ([][]*x509.Certificate, error) { if len(certs) == 0 { return nil, nil } + certs = getUniqueCerts(certs) + var chains [][]*x509.Certificate var serverCerts []*x509.Certificate for _, cert := range certs { + //if we find CAs add them to the known CAs + if cert.IsCA { + cas = append(cas, cert) + } //if not a CA and have any DNS/IP SANs //primary lookup method is "a leaf with SANs" if (!cert.IsCA) && (len(cert.DNSNames) != 0 || len(cert.IPAddresses) != 0) { @@ -50,8 +58,10 @@ func AssembleServerChains(certs []*x509.Certificate) ([][]*x509.Certificate, err } } + cas = getUniqueCas(cas) + for _, serverCert := range serverCerts { - chain := buildChain(serverCert, certs) + chain := buildChain(serverCert, cas) chains = append(chains, chain) } @@ -59,7 +69,7 @@ func AssembleServerChains(certs []*x509.Certificate) ([][]*x509.Certificate, err } // buildChain will build as much of a chain as possible from startingLeaf up using signature checking. -func buildChain(startingLeaf *x509.Certificate, certs []*x509.Certificate) []*x509.Certificate { +func buildChain(startingLeaf *x509.Certificate, cas []*x509.Certificate) []*x509.Certificate { var chain []*x509.Certificate current := startingLeaf @@ -77,7 +87,7 @@ func buildChain(startingLeaf *x509.Certificate, certs []*x509.Certificate) []*x5 parentFound := false //search by checking signature - for _, next := range certs { + for _, next := range cas { if next.IsCA { if err := current.CheckSignatureFrom(next); err == nil { current = next diff --git a/chains_test.go b/chains_test.go index 793910e..199f2d5 100644 --- a/chains_test.go +++ b/chains_test.go @@ -34,30 +34,39 @@ import ( ) func Test_Assemble(t *testing.T) { - t.Run("returns nil and no error on nil certs", func(t *testing.T) { + t.Run("returns nil and no error on nil certs and cas", func(t *testing.T) { req := require.New(t) - ret, err := AssembleServerChains(nil) + ret, err := AssembleServerChains(nil, nil) req.Nil(ret) req.NoError(err) }) - t.Run("returns nil and no error on 0 certs", func(t *testing.T) { + t.Run("returns nil and no error on 0 certs and nil cas", func(t *testing.T) { req := require.New(t) - ret, err := AssembleServerChains([]*x509.Certificate{}) + ret, err := AssembleServerChains([]*x509.Certificate{}, nil) req.Nil(ret) req.NoError(err) }) - t.Run("returns 1 chain of 1 cert with 1 root given only 1 leaf all with AKIDs", func(t *testing.T) { + t.Run("returns nil and no error on nil certs and 0 cas", func(t *testing.T) { + req := require.New(t) + + ret, err := AssembleServerChains(nil, []*x509.Certificate{}) + + req.Nil(ret) + req.NoError(err) + }) + + t.Run("returns 1 chain of 1 cert with 1 root given only 1 leaf all with AKIDs, nil cas", func(t *testing.T) { req := require.New(t) root := newRootCa() leaf := root.NewLeafWithAKID() - ret, err := AssembleServerChains([]*x509.Certificate{leaf.cert}) + ret, err := AssembleServerChains([]*x509.Certificate{leaf.cert}, nil) req.NoError(err) req.NotNil(ret) @@ -66,7 +75,7 @@ func Test_Assemble(t *testing.T) { req.Equal(ret[0][0], leaf.cert) }) - t.Run("returns 1 chain of 1 cert with 2 root given only 1 leaf with and without AKIDs", func(t *testing.T) { + t.Run("returns 1 chain of 1 cert with 2 root given only 1 leaf with and without AKIDs, nil cas", func(t *testing.T) { req := require.New(t) root := newRootCa() leaf := root.NewLeafWithAKID() @@ -74,7 +83,7 @@ func Test_Assemble(t *testing.T) { root2 := newRootCa() intermediate2 := root2.NewIntermediateWithoutAKID() - ret, err := AssembleServerChains([]*x509.Certificate{leaf.cert, root2.cert, intermediate2.cert}) + ret, err := AssembleServerChains([]*x509.Certificate{leaf.cert, root2.cert, intermediate2.cert}, nil) req.NoError(err) req.NotNil(ret) @@ -83,12 +92,12 @@ func Test_Assemble(t *testing.T) { req.Equal(ret[0][0], leaf.cert) }) - t.Run("returns 1 chain of 1 cert with 1 root given only 1 leaf all without AKIDs", func(t *testing.T) { + t.Run("returns 1 chain of 1 cert with 1 root given only 1 leaf all without AKIDs, nil cas", func(t *testing.T) { req := require.New(t) root := newRootCa() leaf := root.NewLeafWithoutAKID() - ret, err := AssembleServerChains([]*x509.Certificate{leaf.cert}) + ret, err := AssembleServerChains([]*x509.Certificate{leaf.cert}, nil) req.NoError(err) req.NotNil(ret) @@ -97,7 +106,7 @@ func Test_Assemble(t *testing.T) { req.Equal(ret[0][0], leaf.cert) }) - t.Run("returns 2 chains for two different CAs all with AKIDs", func(t *testing.T) { + t.Run("returns 2 chains for two different CAs all with AKIDs, nil cas", func(t *testing.T) { req := require.New(t) root1 := newRootCa() @@ -107,7 +116,7 @@ func Test_Assemble(t *testing.T) { intermediate2 := root2.NewIntermediateWithAKID() leaf2 := intermediate2.NewLeafWithAKID() - ret, err := AssembleServerChains([]*x509.Certificate{root1.cert, leaf1.cert, root2.cert, intermediate2.cert, leaf2.cert}) + ret, err := AssembleServerChains([]*x509.Certificate{root1.cert, leaf1.cert, root2.cert, intermediate2.cert, leaf2.cert}, nil) req.NoError(err) req.NotNil(ret) @@ -126,7 +135,7 @@ func Test_Assemble(t *testing.T) { req.Equal(ret[1][2], root2.cert) }) - t.Run("returns 2 chains for two different CAs all without AKIDs", func(t *testing.T) { + t.Run("returns 2 chains for two different CAs all without AKIDs, nil cas", func(t *testing.T) { req := require.New(t) root1 := newRootCa() @@ -136,7 +145,7 @@ func Test_Assemble(t *testing.T) { intermediate2 := root2.NewIntermediateWithoutAKID() leaf2 := intermediate2.NewLeafWithoutAKID() - ret, err := AssembleServerChains([]*x509.Certificate{root1.cert, leaf1.cert, root2.cert, intermediate2.cert, leaf2.cert}) + ret, err := AssembleServerChains([]*x509.Certificate{root1.cert, leaf1.cert, root2.cert, intermediate2.cert, leaf2.cert}, nil) req.NoError(err) req.NotNil(ret) @@ -155,7 +164,7 @@ func Test_Assemble(t *testing.T) { req.Equal(ret[1][2], root2.cert) }) - t.Run("returns 1 chain for ca>intermediate>leaf + random intermediates and CAs with AKIDs", func(t *testing.T) { + t.Run("returns 1 chain for ca>intermediate>leaf + random intermediates and CAs with AKIDs, nil cas", func(t *testing.T) { req := require.New(t) root1 := newRootCa() @@ -168,7 +177,7 @@ func Test_Assemble(t *testing.T) { root3 := newRootCa() intermediate3 := root3.NewIntermediateWithAKID() - ret, err := AssembleServerChains([]*x509.Certificate{root1.cert, intermediate1.cert, root2.cert, intermediate2.cert, leaf2.cert, intermediate3.cert}) + ret, err := AssembleServerChains([]*x509.Certificate{root1.cert, intermediate1.cert, root2.cert, intermediate2.cert, leaf2.cert, intermediate3.cert}, nil) req.NoError(err) req.NotNil(ret) @@ -181,14 +190,14 @@ func Test_Assemble(t *testing.T) { req.Equal(ret[0][2], root2.cert) }) - t.Run("returns 1 chain ca>intermediate>leaf with mixed AKID/no AKID", func(t *testing.T) { + t.Run("returns 1 chain ca>intermediate>leaf with mixed AKID/no AKID, nil cas", func(t *testing.T) { req := require.New(t) root1 := newRootCa() intermediate1 := root1.NewIntermediateWithoutAKID() leaf1 := intermediate1.NewLeafWithAKID() - ret, err := AssembleServerChains([]*x509.Certificate{root1.cert, intermediate1.cert, leaf1.cert}) + ret, err := AssembleServerChains([]*x509.Certificate{root1.cert, intermediate1.cert, leaf1.cert}, nil) req.NoError(err) req.NotNil(ret) @@ -202,7 +211,71 @@ func Test_Assemble(t *testing.T) { req.Equal(ret[0][2], root1.cert) }) - t.Run("can host a http server for 127.0.0.1 and localhost from two different certificates", func(t *testing.T) { + t.Run("returns 1 chain for ca>intermediate>leaf + random intermediates, leafs and CAs as cas", func(t *testing.T) { + req := require.New(t) + + root1 := newRootCa() + intermediate1 := root1.NewIntermediateWithAKID() + leaf1 := intermediate1.NewLeafWithoutAKID() + + root2 := newRootCa() + intermediate2 := root2.NewIntermediateWithAKID() + leaf2 := intermediate2.NewLeafWithAKID() + + root3 := newRootCa() + intermediate3 := root3.NewIntermediateWithAKID() + + ret, err := AssembleServerChains( + []*x509.Certificate{leaf1.cert}, + []*x509.Certificate{root1.cert, intermediate1.cert, root2.cert, intermediate2.cert, leaf2.cert, intermediate3.cert}) + + req.NoError(err) + req.NotNil(ret) + req.Len(ret, 1) + + // root1 -> leaf2 + req.Len(ret[0], 3) + req.Equal(ret[0][0], leaf1.cert) + req.Equal(ret[0][1], intermediate1.cert) + req.Equal(ret[0][2], root1.cert) + }) + + t.Run("returns 2 chains for ca>intermediate>leaf + random intermediates, and CAs as cas", func(t *testing.T) { + req := require.New(t) + + root1 := newRootCa() + intermediate1 := root1.NewIntermediateWithAKID() + leaf1 := intermediate1.NewLeafWithoutAKID() + + root2 := newRootCa() + intermediate2 := root2.NewIntermediateWithAKID() + leaf2 := intermediate2.NewLeafWithAKID() + + root3 := newRootCa() + intermediate3 := root3.NewIntermediateWithAKID() + + ret, err := AssembleServerChains( + []*x509.Certificate{leaf1.cert, leaf2.cert}, + []*x509.Certificate{root1.cert, intermediate1.cert, root2.cert, intermediate2.cert, intermediate3.cert}) + + req.NoError(err) + req.NotNil(ret) + req.Len(ret, 2) + + // root1 -> leaf1 + req.Len(ret[0], 3) + req.Equal(ret[0][0], leaf1.cert) + req.Equal(ret[0][1], intermediate1.cert) + req.Equal(ret[0][2], root1.cert) + + // root2 -> leaf2 + req.Len(ret[1], 3) + req.Equal(ret[1][0], leaf2.cert) + req.Equal(ret[1][1], intermediate2.cert) + req.Equal(ret[1][2], root2.cert) + }) + + t.Run("can host a http server for 127.0.0.1 and localhost from two different certificates, nil cas", func(t *testing.T) { req := require.New(t) root := newRootCa() diff --git a/identity.go b/identity.go index e20d0ce..8561536 100644 --- a/identity.go +++ b/identity.go @@ -430,8 +430,7 @@ func LoadIdentity(cfg Config) (Identity, error) { return nil, errors.New("no corresponding key specified for identity server_cert") } - svrCert = getUniqueCerts(svrCert, id.caPool) - chains, err := AssembleServerChains(svrCert) + chains, err := AssembleServerChains(svrCert, id.caPool.certs) if err != nil { return nil, err @@ -469,7 +468,7 @@ func LoadIdentity(cfg Config) (Identity, error) { return nil, errors.New("no key specified for identity alternate server cert") } - chains, err := AssembleServerChains(svrCert) + chains, err := AssembleServerChains(svrCert, nil) if err != nil { return nil, err @@ -483,11 +482,34 @@ func LoadIdentity(cfg Config) (Identity, error) { return id, nil } -func getUniqueCerts(certs []*x509.Certificate, pool *CaPool) []*x509.Certificate { +// getUniqueCerts will return a slice of unique certificates from the given slice +func getUniqueCerts(certs []*x509.Certificate) []*x509.Certificate { set := map[string]*x509.Certificate{} var keys []string // track insertion order so that server certs come before pool certs - addCerts := func(certs []*x509.Certificate) { - for _, cert := range certs { + + for _, cert := range certs { + hash := sha1.Sum(cert.Raw) + fp := string(hash[:]) + if _, exists := set[fp]; !exists { + set[fp] = cert + keys = append(keys, fp) + } + } + + var result []*x509.Certificate + for _, key := range keys { + result = append(result, set[key]) + } + return result +} + +// getUniqueCas will return a slice of unique certificates that are CAs from the given slice +func getUniqueCas(certs []*x509.Certificate) []*x509.Certificate { + set := map[string]*x509.Certificate{} + var keys []string // track insertion order so that server certs come before pool certs + + for _, cert := range certs { + if cert.IsCA { hash := sha1.Sum(cert.Raw) fp := string(hash[:]) if _, exists := set[fp]; !exists { @@ -497,12 +519,6 @@ func getUniqueCerts(certs []*x509.Certificate, pool *CaPool) []*x509.Certificate } } - addCerts(certs) - - if pool != nil { - addCerts(pool.certs) - } - var result []*x509.Certificate for _, key := range keys { result = append(result, set[key])