Skip to content

Commit

Permalink
automatically reload TLS certificates when they change (#3598)
Browse files Browse the repository at this point in the history
* Dynamically refresh tls certs for all servers

* make sure that CertLoader is always closed

---------

Co-authored-by: aler9 <[email protected]>
  • Loading branch information
dbason and aler9 authored Aug 4, 2024
1 parent 972ffbf commit 1055be9
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 8 deletions.
108 changes: 108 additions & 0 deletions internal/certloader/certloader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Package certloader contains a certicate loader.
package certloader

import (
"crypto/tls"
"sync"

"github.com/bluenviron/mediamtx/internal/confwatcher"
"github.com/bluenviron/mediamtx/internal/logger"
)

// CertLoader is a certificate loader. It watches for changes to the certificate and key files.
type CertLoader struct {
log logger.Writer
certWatcher, keyWatcher *confwatcher.ConfWatcher
certPath, keyPath string
done chan struct{}

cert *tls.Certificate
certMu sync.RWMutex
}

// New allocates a CertLoader.
func New(certPath, keyPath string, log logger.Writer) (*CertLoader, error) {
cl := &CertLoader{
log: log,
certPath: certPath,
keyPath: keyPath,
done: make(chan struct{}),
}

var err error
cl.certWatcher, err = confwatcher.New(certPath)
if err != nil {
return nil, err
}

cl.keyWatcher, err = confwatcher.New(keyPath)
if err != nil {
cl.certWatcher.Close() //nolint:errcheck
return nil, err
}

cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, err
}

cl.certMu.Lock()
cl.cert = &cert
cl.certMu.Unlock()

go cl.watch()

return cl, nil
}

// Close closes a CertLoader and releases any underlying resources.
func (cl *CertLoader) Close() {
close(cl.done)
cl.certWatcher.Close() //nolint:errcheck
cl.keyWatcher.Close() //nolint:errcheck
cl.certMu.Lock()
defer cl.certMu.Unlock()
cl.cert = nil
}

// GetCertificate returns a function that returns the certificate for use in a tls.Config.
func (cl *CertLoader) GetCertificate() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
cl.certMu.RLock()
defer cl.certMu.RUnlock()
return cl.cert, nil
}
}

func (cl *CertLoader) watch() {
for {
select {
case <-cl.certWatcher.Watch():
cert, err := tls.LoadX509KeyPair(cl.certPath, cl.keyPath)
if err != nil {
cl.log.Log(logger.Error, "certloader failed to load after change to %s: %s", cl.certPath, err.Error())
continue
}

cl.certMu.Lock()
cl.cert = &cert
cl.certMu.Unlock()

cl.log.Log(logger.Info, "certificate reloaded after change to %s", cl.certPath)
case <-cl.keyWatcher.Watch():
cert, err := tls.LoadX509KeyPair(cl.certPath, cl.keyPath)
if err != nil {
cl.log.Log(logger.Error, "certloader failed to load after change to %s: %s", cl.keyPath, err.Error())
continue
}

cl.certMu.Lock()
cl.cert = &cert
cl.certMu.Unlock()

cl.log.Log(logger.Info, "certificate reloaded after change to %s", cl.keyPath)
case <-cl.done:
return
}
}
}
52 changes: 52 additions & 0 deletions internal/certloader/certloader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package certloader

import (
"crypto/tls"
"os"
"testing"
"time"

"github.com/bluenviron/mediamtx/internal/test"
"github.com/stretchr/testify/require"
)

func TestCertReload(t *testing.T) {
testData, err := tls.X509KeyPair(test.TLSCertPub, test.TLSCertKey)
require.NoError(t, err)

serverCertPath, err := test.CreateTempFile(test.TLSCertPub)
require.NoError(t, err)
defer os.Remove(serverCertPath)

serverKeyPath, err := test.CreateTempFile(test.TLSCertKey)
require.NoError(t, err)
defer os.Remove(serverKeyPath)

loader, err := New(serverCertPath, serverKeyPath, test.NilLogger)
require.NoError(t, err)
defer loader.Close()

getCert := loader.GetCertificate()
require.NotNil(t, getCert)

cert, err := getCert(nil)
require.NoError(t, err)
require.NotNil(t, cert)
require.Equal(t, &testData, cert)

testData, err = tls.X509KeyPair(test.TLSCertPubAlt, test.TLSCertKeyAlt)
require.NoError(t, err)

err = os.WriteFile(serverCertPath, test.TLSCertPubAlt, 0o644)
require.NoError(t, err)

err = os.WriteFile(serverKeyPath, test.TLSCertKeyAlt, 0o644)
require.NoError(t, err)

time.Sleep(1 * time.Second)

cert, err = getCert(nil)
require.NoError(t, err)
require.NotNil(t, cert)
require.Equal(t, &testData, cert)
}
15 changes: 11 additions & 4 deletions internal/protocols/httpp/wrapped_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/http"
"time"

"github.com/bluenviron/mediamtx/internal/certloader"
"github.com/bluenviron/mediamtx/internal/logger"
)

Expand All @@ -36,8 +37,9 @@ type WrappedServer struct {
Handler http.Handler
Parent logger.Writer

ln net.Listener
inner *http.Server
ln net.Listener
inner *http.Server
loader *certloader.CertLoader
}

// Initialize initializes a WrappedServer.
Expand All @@ -47,13 +49,15 @@ func (s *WrappedServer) Initialize() error {
if s.ServerCert == "" {
return fmt.Errorf("server cert is missing")
}
crt, err := tls.LoadX509KeyPair(s.ServerCert, s.ServerKey)

var err error
s.loader, err = certloader.New(s.ServerCert, s.ServerKey, s.Parent)
if err != nil {
return err
}

tlsConfig = &tls.Config{
Certificates: []tls.Certificate{crt},
GetCertificate: s.loader.GetCertificate(),
}
}

Expand Down Expand Up @@ -92,4 +96,7 @@ func (s *WrappedServer) Close() {
ctxCancel()
s.inner.Shutdown(ctx)
s.ln.Close() // in case Shutdown() is called before Serve()
if s.loader != nil {
s.loader.Close()
}
}
10 changes: 8 additions & 2 deletions internal/servers/rtmp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/google/uuid"

"github.com/bluenviron/mediamtx/internal/certloader"
"github.com/bluenviron/mediamtx/internal/conf"
"github.com/bluenviron/mediamtx/internal/defs"
"github.com/bluenviron/mediamtx/internal/externalcmd"
Expand Down Expand Up @@ -82,6 +83,7 @@ type Server struct {
wg sync.WaitGroup
ln net.Listener
conns map[*conn]struct{}
loader *certloader.CertLoader

// in
chNewConn chan net.Conn
Expand All @@ -99,13 +101,14 @@ func (s *Server) Initialize() error {
return net.Listen(restrictnetwork.Restrict("tcp", s.Address))
}

cert, err := tls.LoadX509KeyPair(s.ServerCert, s.ServerKey)
var err error
s.loader, err = certloader.New(s.ServerCert, s.ServerKey, s.Parent)
if err != nil {
return nil, err
}

network, address := restrictnetwork.Restrict("tcp", s.Address)
return tls.Listen(network, address, &tls.Config{Certificates: []tls.Certificate{cert}})
return tls.Listen(network, address, &tls.Config{GetCertificate: s.loader.GetCertificate()})
}()
if err != nil {
return err
Expand Down Expand Up @@ -153,6 +156,9 @@ func (s *Server) Close() {
s.Log(logger.Info, "listener is closing")
s.ctxCancel()
s.wg.Wait()
if s.loader != nil {
s.loader.Close()
}
}

func (s *Server) run() {
Expand Down
10 changes: 8 additions & 2 deletions internal/servers/rtsp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
"github.com/google/uuid"

"github.com/bluenviron/mediamtx/internal/certloader"
"github.com/bluenviron/mediamtx/internal/conf"
"github.com/bluenviron/mediamtx/internal/defs"
"github.com/bluenviron/mediamtx/internal/externalcmd"
Expand Down Expand Up @@ -89,6 +90,7 @@ type Server struct {
mutex sync.RWMutex
conns map[*gortsplib.ServerConn]*conn
sessions map[*gortsplib.ServerSession]*session
loader *certloader.CertLoader
}

// Initialize initializes the server.
Expand Down Expand Up @@ -118,12 +120,13 @@ func (s *Server) Initialize() error {
}

if s.IsTLS {
cert, err := tls.LoadX509KeyPair(s.ServerCert, s.ServerKey)
var err error
s.loader, err = certloader.New(s.ServerCert, s.ServerKey, s.Parent)
if err != nil {
return err
}

s.srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
s.srv.TLSConfig = &tls.Config{GetCertificate: s.loader.GetCertificate()}
}

err := s.srv.Start()
Expand Down Expand Up @@ -155,6 +158,9 @@ func (s *Server) Close() {
s.Log(logger.Info, "listener is closing")
s.ctxCancel()
s.wg.Wait()
if s.loader != nil {
s.loader.Close()
}
}

func (s *Server) run() {
Expand Down
52 changes: 52 additions & 0 deletions internal/test/tls_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,55 @@ y++U32uuSFiXDcSLarfIsE992MEJLSAynbF1Rsgsr3gXbGiuToJRyxbIeVy7gwzD
+3K6cnKEyg+0ekYmLertRFIY6SwWmY1fyKgTvxudMcsBY7dC4xs=
-----END RSA PRIVATE KEY-----
`)

// TLSCertPubAlt is the public key of an alternative test certificate.
var TLSCertPubAlt = []byte(`-----BEGIN CERTIFICATE-----
MIIDSTCCAjECFEut6ZxIOnbxi3bhrPLfPQZCLReNMA0GCSqGSIb3DQEBCwUAMGEx
CzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRl
cm5ldCBXaWRnaXRzIFB0eSBMdGQxGjAYBgNVBAMMEW1lZGlhbXR4LnRlc3QuY29t
MB4XDTI0MDgwMTIzNDY0MloXDTM0MDczMDIzNDY0MlowYTELMAkGA1UEBhMCQVUx
EzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMg
UHR5IEx0ZDEaMBgGA1UEAwwRbWVkaWFtdHgudGVzdC5jb20wggEiMA0GCSqGSIb3
DQEBAQUAA4IBDwAwggEKAoIBAQCzfvG9eLXKSTDBoM+cgV/ThiNRI2JY6dpQV8rK
QFQ5bkkDUDP+2Ae/IWylgLLXmozsMwjz1Pu42awmGymBuo5HDbI4bxPJNQR9qRrR
2+MvfDgmZxyhw5NfZDlVl+enxhb3FRgbHsLBy4oSoHbRUdLApVdM0Kg6r3bXzkih
EEs63boFJOkPhs5H0NX7AzXyBp2WnvB71j+7avnMwAsjJHOiTs8wkp5wvRcIZpJl
MCandUkcZShMirug7QOcR9fAr5CVKxsO/DjqEjwkslJHFfizOl3yRx6nsxvW8JUd
dforpSRj84dkHTi7k37YTiji90GsOvh0qc0MfAmeE181HIb/AgMBAAEwDQYJKoZI
hvcNAQELBQADggEBAEWkLL/7nvt3iD7BVJNHLvAS6GwuTH99vCil6TFYwVl4goht
Dur7YfzN43vUq+lAwS3Ry4ka7tH72pAMkpNFRvHOikWGmWUSDo2DcLd8iu3ruLF7
yUg2ASQuekK0sUv4YKpAqV8gS2R4Jh4vLU+8L5iJ1XWGELbQ+H5wm4l7l+r2X6cD
/opmdV8Slfi0FlNQtflLsGoSlfZF5jHxqi3zyt8QdEf9WZt8e6JPxcx2Fq7Op51u
Qx9nosr5fLwhkx46+B/cotsbI/xPDjLF6RQ1OUpcHwg1HI6czoW4hHn33S0zstCf
BWt5Q1Mb2tGInbmbUgw3wUu/4nWoY+Mq4DKPlKs=
-----END CERTIFICATE-----`)

// TLSCertKeyAlt is the private key of an alternative test certificate.
var TLSCertKeyAlt = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEoQIBAAKCAQEAs37xvXi1ykkwwaDPnIFf04YjUSNiWOnaUFfKykBUOW5JA1Az
/tgHvyFspYCy15qM7DMI89T7uNmsJhspgbqORw2yOG8TyTUEfaka0dvjL3w4Jmcc
ocOTX2Q5VZfnp8YW9xUYGx7CwcuKEqB20VHSwKVXTNCoOq92185IoRBLOt26BSTp
D4bOR9DV+wM18gadlp7we9Y/u2r5zMALIyRzok7PMJKecL0XCGaSZTAmp3VJHGUo
TIq7oO0DnEfXwK+QlSsbDvw46hI8JLJSRxX4szpd8kcep7Mb1vCVHXX6K6UkY/OH
ZB04u5N+2E4o4vdBrDr4dKnNDHwJnhNfNRyG/wIDAQABAoH/WmCqV6Lv5dEnofCj
ZUO/Fdv0hf/LBS0g2SAoFRSCIM8aJ3dUUH0PaXoeINDGCMlIxT7tKXJg5jJNYhWx
g7oegw6vLe5ZiA+p5miL/uue+Jas4kLVp9DrfQLgQevt0gw4g/00pgy9adbFlTUD
a2HhPB7RIvXs8gYA6nVAT9jK1ST2pbeUgQNO4Ji4EjpPUkR2O7ISOlu5EV8Cj0eV
1Vs5B92Z7ORh7P2fFV2YBu+igd04+uYvei6slQl+F9cETvJv2Z9r37Yashvnn1in
uy/u1U4B1t4oOz81nHz6kxTixPpBOdJ6x8jLDgNGSsauJQfXT9xmB/rAr/NFq+7I
tbTNAoGBAMOgm3XXHWokmJnX9pfNj6ixNlrMuuez/yXMVwuxa2WFwAFN16tjJhBi
XOjestcvu/SRhOAMmYac5QdopJpLjO/FxO165r73eZhW/SJefyOHtfD29kHagA1u
JjcznU6tiA0O1owy6nuuaTfyVbDQj32PhVBx9ZwSI4778GFbjWl7AoGBAOrj4WCC
gTMaExpwNo+L+3VkM79YD1Obl13FcgtVoxjcoWjQeMx9D0k7adTV3xlchHFAjiD5
Gs/MZl8+seq+GDX3mODsmJkdRQbYId4g6IesiOnQ3Ug/Y282WZRnpB5h/BMnrcCZ
VoohnATA7f96c7XtPUgZyROmh24T7UIVwVdNAoGAbeeGT276TI6g2RWWqXRIOFrP
EbYhb1kViFPDt4MGtjOtSk5EUzpRwTSxw/aRfQmJS/6RKxqJCjKNDVuB1lmJpY9z
coPwrOr1+lssvalfPkPZOLZWZWrvNBxlBfBOeUxOuh9S89MLH08+N7tC3yJc6wq9
uBM+DF+4cHUkeF3qFY8CgYBzS+IwBj82/0CLRLNzaKnIqKPB846qYoA9NhLRv3ps
VLgiA9qXvXdIYhKDt2toPoKAOMjLJJtljpZdgB/C8wZdTyjKlzgcSEK+pk6RgyPA
nQ8jfjNwKDU9vLbh4rGrfDtIh7yBAoN5ECBOMQlh0xCDJ21iO834iFCH1t4qBxW9
LQKBgQC36adC2Gu+FJRvx4Mkm73fLmVdFbP6Do7qNwyVVyaG80PDVrFQrlWm4Dt7
AO9IwzaS1Lx+qmU1Fj1WfCtXuQa5nc9AzZ36TmM6+pAn8AC7PdNqc0qSdefVrIjj
zRGhUPaJV3A+sfO+xedBsAFnqNuX9oODYVGbTjuc2OWC30MGaw==
-----END RSA PRIVATE KEY-----
`)

0 comments on commit 1055be9

Please sign in to comment.