Skip to content

Commit

Permalink
fixes openziti/ziti#1422 server cert chain assembly
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewpmartinez committed Oct 23, 2023
1 parent 4858c7f commit f080cb9
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 36 deletions.
20 changes: 15 additions & 5 deletions chains.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,25 @@ import (

// AssembleServerChains takes in an array of certificates, finds all certificates with
// x509.ExtKeyUsageAny or x509.ExtKeyUsageServerAuth and builds an array of leaf-first
// chains.
func AssembleServerChains(certs []*x509.Certificate) ([][]*x509.Certificate, error) {
// chains. Chains are built starting from server authentication certificates found in `certs`
// and the signer chains are built from `certs` and `cas`. Both slices are de-duped
// and the `cas` slice is filtered for certificates with the CA flag set.
func AssembleServerChains(certs []*x509.Certificate, cas []*x509.Certificate) ([][]*x509.Certificate, error) {
if len(certs) == 0 {
return nil, nil
}

certs = getUniqueCerts(certs)

var chains [][]*x509.Certificate

var serverCerts []*x509.Certificate

for _, cert := range certs {
//if we find CAs add them to the known CAs
if cert.IsCA {
cas = append(cas, cert)
}
//if not a CA and have any DNS/IP SANs
//primary lookup method is "a leaf with SANs"
if (!cert.IsCA) && (len(cert.DNSNames) != 0 || len(cert.IPAddresses) != 0) {
Expand All @@ -50,16 +58,18 @@ func AssembleServerChains(certs []*x509.Certificate) ([][]*x509.Certificate, err
}
}

cas = getUniqueCas(cas)

for _, serverCert := range serverCerts {
chain := buildChain(serverCert, certs)
chain := buildChain(serverCert, cas)
chains = append(chains, chain)
}

return chains, nil
}

// buildChain will build as much of a chain as possible from startingLeaf up using signature checking.
func buildChain(startingLeaf *x509.Certificate, certs []*x509.Certificate) []*x509.Certificate {
func buildChain(startingLeaf *x509.Certificate, cas []*x509.Certificate) []*x509.Certificate {
var chain []*x509.Certificate

current := startingLeaf
Expand All @@ -77,7 +87,7 @@ func buildChain(startingLeaf *x509.Certificate, certs []*x509.Certificate) []*x5
parentFound := false

//search by checking signature
for _, next := range certs {
for _, next := range cas {
if next.IsCA {
if err := current.CheckSignatureFrom(next); err == nil {
current = next
Expand Down
111 changes: 92 additions & 19 deletions chains_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,39 @@ import (
)

func Test_Assemble(t *testing.T) {
t.Run("returns nil and no error on nil certs", func(t *testing.T) {
t.Run("returns nil and no error on nil certs and cas", func(t *testing.T) {
req := require.New(t)

ret, err := AssembleServerChains(nil)
ret, err := AssembleServerChains(nil, nil)

req.Nil(ret)
req.NoError(err)
})

t.Run("returns nil and no error on 0 certs", func(t *testing.T) {
t.Run("returns nil and no error on 0 certs and nil cas", func(t *testing.T) {
req := require.New(t)

ret, err := AssembleServerChains([]*x509.Certificate{})
ret, err := AssembleServerChains([]*x509.Certificate{}, nil)

req.Nil(ret)
req.NoError(err)
})

t.Run("returns 1 chain of 1 cert with 1 root given only 1 leaf all with AKIDs", func(t *testing.T) {
t.Run("returns nil and no error on nil certs and 0 cas", func(t *testing.T) {
req := require.New(t)

ret, err := AssembleServerChains(nil, []*x509.Certificate{})

req.Nil(ret)
req.NoError(err)
})

t.Run("returns 1 chain of 1 cert with 1 root given only 1 leaf all with AKIDs, nil cas", func(t *testing.T) {
req := require.New(t)
root := newRootCa()
leaf := root.NewLeafWithAKID()

ret, err := AssembleServerChains([]*x509.Certificate{leaf.cert})
ret, err := AssembleServerChains([]*x509.Certificate{leaf.cert}, nil)

req.NoError(err)
req.NotNil(ret)
Expand All @@ -66,15 +75,15 @@ func Test_Assemble(t *testing.T) {
req.Equal(ret[0][0], leaf.cert)
})

t.Run("returns 1 chain of 1 cert with 2 root given only 1 leaf with and without AKIDs", func(t *testing.T) {
t.Run("returns 1 chain of 1 cert with 2 root given only 1 leaf with and without AKIDs, nil cas", func(t *testing.T) {
req := require.New(t)
root := newRootCa()
leaf := root.NewLeafWithAKID()

root2 := newRootCa()
intermediate2 := root2.NewIntermediateWithoutAKID()

ret, err := AssembleServerChains([]*x509.Certificate{leaf.cert, root2.cert, intermediate2.cert})
ret, err := AssembleServerChains([]*x509.Certificate{leaf.cert, root2.cert, intermediate2.cert}, nil)

req.NoError(err)
req.NotNil(ret)
Expand All @@ -83,12 +92,12 @@ func Test_Assemble(t *testing.T) {
req.Equal(ret[0][0], leaf.cert)
})

t.Run("returns 1 chain of 1 cert with 1 root given only 1 leaf all without AKIDs", func(t *testing.T) {
t.Run("returns 1 chain of 1 cert with 1 root given only 1 leaf all without AKIDs, nil cas", func(t *testing.T) {
req := require.New(t)
root := newRootCa()
leaf := root.NewLeafWithoutAKID()

ret, err := AssembleServerChains([]*x509.Certificate{leaf.cert})
ret, err := AssembleServerChains([]*x509.Certificate{leaf.cert}, nil)

req.NoError(err)
req.NotNil(ret)
Expand All @@ -97,7 +106,7 @@ func Test_Assemble(t *testing.T) {
req.Equal(ret[0][0], leaf.cert)
})

t.Run("returns 2 chains for two different CAs all with AKIDs", func(t *testing.T) {
t.Run("returns 2 chains for two different CAs all with AKIDs, nil cas", func(t *testing.T) {
req := require.New(t)

root1 := newRootCa()
Expand All @@ -107,7 +116,7 @@ func Test_Assemble(t *testing.T) {
intermediate2 := root2.NewIntermediateWithAKID()
leaf2 := intermediate2.NewLeafWithAKID()

ret, err := AssembleServerChains([]*x509.Certificate{root1.cert, leaf1.cert, root2.cert, intermediate2.cert, leaf2.cert})
ret, err := AssembleServerChains([]*x509.Certificate{root1.cert, leaf1.cert, root2.cert, intermediate2.cert, leaf2.cert}, nil)

req.NoError(err)
req.NotNil(ret)
Expand All @@ -126,7 +135,7 @@ func Test_Assemble(t *testing.T) {
req.Equal(ret[1][2], root2.cert)
})

t.Run("returns 2 chains for two different CAs all without AKIDs", func(t *testing.T) {
t.Run("returns 2 chains for two different CAs all without AKIDs, nil cas", func(t *testing.T) {
req := require.New(t)

root1 := newRootCa()
Expand All @@ -136,7 +145,7 @@ func Test_Assemble(t *testing.T) {
intermediate2 := root2.NewIntermediateWithoutAKID()
leaf2 := intermediate2.NewLeafWithoutAKID()

ret, err := AssembleServerChains([]*x509.Certificate{root1.cert, leaf1.cert, root2.cert, intermediate2.cert, leaf2.cert})
ret, err := AssembleServerChains([]*x509.Certificate{root1.cert, leaf1.cert, root2.cert, intermediate2.cert, leaf2.cert}, nil)

req.NoError(err)
req.NotNil(ret)
Expand All @@ -155,7 +164,7 @@ func Test_Assemble(t *testing.T) {
req.Equal(ret[1][2], root2.cert)
})

t.Run("returns 1 chain for ca>intermediate>leaf + random intermediates and CAs with AKIDs", func(t *testing.T) {
t.Run("returns 1 chain for ca>intermediate>leaf + random intermediates and CAs with AKIDs, nil cas", func(t *testing.T) {
req := require.New(t)

root1 := newRootCa()
Expand All @@ -168,7 +177,7 @@ func Test_Assemble(t *testing.T) {
root3 := newRootCa()
intermediate3 := root3.NewIntermediateWithAKID()

ret, err := AssembleServerChains([]*x509.Certificate{root1.cert, intermediate1.cert, root2.cert, intermediate2.cert, leaf2.cert, intermediate3.cert})
ret, err := AssembleServerChains([]*x509.Certificate{root1.cert, intermediate1.cert, root2.cert, intermediate2.cert, leaf2.cert, intermediate3.cert}, nil)

req.NoError(err)
req.NotNil(ret)
Expand All @@ -181,14 +190,14 @@ func Test_Assemble(t *testing.T) {
req.Equal(ret[0][2], root2.cert)
})

t.Run("returns 1 chain ca>intermediate>leaf with mixed AKID/no AKID", func(t *testing.T) {
t.Run("returns 1 chain ca>intermediate>leaf with mixed AKID/no AKID, nil cas", func(t *testing.T) {
req := require.New(t)

root1 := newRootCa()
intermediate1 := root1.NewIntermediateWithoutAKID()
leaf1 := intermediate1.NewLeafWithAKID()

ret, err := AssembleServerChains([]*x509.Certificate{root1.cert, intermediate1.cert, leaf1.cert})
ret, err := AssembleServerChains([]*x509.Certificate{root1.cert, intermediate1.cert, leaf1.cert}, nil)

req.NoError(err)
req.NotNil(ret)
Expand All @@ -202,7 +211,71 @@ func Test_Assemble(t *testing.T) {
req.Equal(ret[0][2], root1.cert)
})

t.Run("can host a http server for 127.0.0.1 and localhost from two different certificates", func(t *testing.T) {
t.Run("returns 1 chain for ca>intermediate>leaf + random intermediates, leafs and CAs as cas", func(t *testing.T) {
req := require.New(t)

root1 := newRootCa()
intermediate1 := root1.NewIntermediateWithAKID()
leaf1 := intermediate1.NewLeafWithoutAKID()

root2 := newRootCa()
intermediate2 := root2.NewIntermediateWithAKID()
leaf2 := intermediate2.NewLeafWithAKID()

root3 := newRootCa()
intermediate3 := root3.NewIntermediateWithAKID()

ret, err := AssembleServerChains(
[]*x509.Certificate{leaf1.cert},
[]*x509.Certificate{root1.cert, intermediate1.cert, root2.cert, intermediate2.cert, leaf2.cert, intermediate3.cert})

req.NoError(err)
req.NotNil(ret)
req.Len(ret, 1)

// root1 -> leaf2
req.Len(ret[0], 3)
req.Equal(ret[0][0], leaf1.cert)
req.Equal(ret[0][1], intermediate1.cert)
req.Equal(ret[0][2], root1.cert)
})

t.Run("returns 2 chains for ca>intermediate>leaf + random intermediates, and CAs as cas", func(t *testing.T) {
req := require.New(t)

root1 := newRootCa()
intermediate1 := root1.NewIntermediateWithAKID()
leaf1 := intermediate1.NewLeafWithoutAKID()

root2 := newRootCa()
intermediate2 := root2.NewIntermediateWithAKID()
leaf2 := intermediate2.NewLeafWithAKID()

root3 := newRootCa()
intermediate3 := root3.NewIntermediateWithAKID()

ret, err := AssembleServerChains(
[]*x509.Certificate{leaf1.cert, leaf2.cert},
[]*x509.Certificate{root1.cert, intermediate1.cert, root2.cert, intermediate2.cert, intermediate3.cert})

req.NoError(err)
req.NotNil(ret)
req.Len(ret, 2)

// root1 -> leaf1
req.Len(ret[0], 3)
req.Equal(ret[0][0], leaf1.cert)
req.Equal(ret[0][1], intermediate1.cert)
req.Equal(ret[0][2], root1.cert)

// root2 -> leaf2
req.Len(ret[1], 3)
req.Equal(ret[1][0], leaf2.cert)
req.Equal(ret[1][1], intermediate2.cert)
req.Equal(ret[1][2], root2.cert)
})

t.Run("can host a http server for 127.0.0.1 and localhost from two different certificates, nil cas", func(t *testing.T) {
req := require.New(t)
root := newRootCa()

Expand Down
40 changes: 28 additions & 12 deletions identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,8 +430,7 @@ func LoadIdentity(cfg Config) (Identity, error) {
return nil, errors.New("no corresponding key specified for identity server_cert")
}

svrCert = getUniqueCerts(svrCert, id.caPool)
chains, err := AssembleServerChains(svrCert)
chains, err := AssembleServerChains(svrCert, id.caPool.certs)

if err != nil {
return nil, err
Expand Down Expand Up @@ -469,7 +468,7 @@ func LoadIdentity(cfg Config) (Identity, error) {
return nil, errors.New("no key specified for identity alternate server cert")
}

chains, err := AssembleServerChains(svrCert)
chains, err := AssembleServerChains(svrCert, nil)

if err != nil {
return nil, err
Expand All @@ -483,11 +482,34 @@ func LoadIdentity(cfg Config) (Identity, error) {
return id, nil
}

func getUniqueCerts(certs []*x509.Certificate, pool *CaPool) []*x509.Certificate {
// getUniqueCerts will return a slice of unique certificates from the given slice
func getUniqueCerts(certs []*x509.Certificate) []*x509.Certificate {
set := map[string]*x509.Certificate{}
var keys []string // track insertion order so that server certs come before pool certs
addCerts := func(certs []*x509.Certificate) {
for _, cert := range certs {

for _, cert := range certs {
hash := sha1.Sum(cert.Raw)
fp := string(hash[:])
if _, exists := set[fp]; !exists {
set[fp] = cert
keys = append(keys, fp)
}
}

var result []*x509.Certificate
for _, key := range keys {
result = append(result, set[key])
}
return result
}

// getUniqueCas will return a slice of unique certificates that are CAs from the given slice
func getUniqueCas(certs []*x509.Certificate) []*x509.Certificate {
set := map[string]*x509.Certificate{}
var keys []string // track insertion order so that server certs come before pool certs

for _, cert := range certs {
if cert.IsCA {
hash := sha1.Sum(cert.Raw)
fp := string(hash[:])
if _, exists := set[fp]; !exists {
Expand All @@ -497,12 +519,6 @@ func getUniqueCerts(certs []*x509.Certificate, pool *CaPool) []*x509.Certificate
}
}

addCerts(certs)

if pool != nil {
addCerts(pool.certs)
}

var result []*x509.Certificate
for _, key := range keys {
result = append(result, set[key])
Expand Down

0 comments on commit f080cb9

Please sign in to comment.