diff --git a/internal/certloader/certloader.go b/internal/certloader/certloader.go new file mode 100644 index 00000000000..8aade6e4f5b --- /dev/null +++ b/internal/certloader/certloader.go @@ -0,0 +1,108 @@ +// Package certloader contains a certicate loader. +package certloader + +import ( + "crypto/tls" + "sync" + + "github.com/bluenviron/mediamtx/internal/confwatcher" + "github.com/bluenviron/mediamtx/internal/logger" +) + +// CertLoader is a certificate loader. It watches for changes to the certificate and key files. +type CertLoader struct { + log logger.Writer + certWatcher, keyWatcher *confwatcher.ConfWatcher + certPath, keyPath string + done chan struct{} + + cert *tls.Certificate + certMu sync.RWMutex +} + +// New allocates a CertLoader. +func New(certPath, keyPath string, log logger.Writer) (*CertLoader, error) { + cl := &CertLoader{ + log: log, + certPath: certPath, + keyPath: keyPath, + done: make(chan struct{}), + } + + var err error + cl.certWatcher, err = confwatcher.New(certPath) + if err != nil { + return nil, err + } + + cl.keyWatcher, err = confwatcher.New(keyPath) + if err != nil { + cl.certWatcher.Close() //nolint:errcheck + return nil, err + } + + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, err + } + + cl.certMu.Lock() + cl.cert = &cert + cl.certMu.Unlock() + + go cl.watch() + + return cl, nil +} + +// Close closes a CertLoader and releases any underlying resources. +func (cl *CertLoader) Close() { + close(cl.done) + cl.certWatcher.Close() //nolint:errcheck + cl.keyWatcher.Close() //nolint:errcheck + cl.certMu.Lock() + defer cl.certMu.Unlock() + cl.cert = nil +} + +// GetCertificate returns a function that returns the certificate for use in a tls.Config. +func (cl *CertLoader) GetCertificate() func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + cl.certMu.RLock() + defer cl.certMu.RUnlock() + return cl.cert, nil + } +} + +func (cl *CertLoader) watch() { + for { + select { + case <-cl.certWatcher.Watch(): + cert, err := tls.LoadX509KeyPair(cl.certPath, cl.keyPath) + if err != nil { + cl.log.Log(logger.Error, "certloader failed to load after change to %s: %s", cl.certPath, err.Error()) + continue + } + + cl.certMu.Lock() + cl.cert = &cert + cl.certMu.Unlock() + + cl.log.Log(logger.Info, "certificate reloaded after change to %s", cl.certPath) + case <-cl.keyWatcher.Watch(): + cert, err := tls.LoadX509KeyPair(cl.certPath, cl.keyPath) + if err != nil { + cl.log.Log(logger.Error, "certloader failed to load after change to %s: %s", cl.keyPath, err.Error()) + continue + } + + cl.certMu.Lock() + cl.cert = &cert + cl.certMu.Unlock() + + cl.log.Log(logger.Info, "certificate reloaded after change to %s", cl.keyPath) + case <-cl.done: + return + } + } +} diff --git a/internal/certloader/certloader_test.go b/internal/certloader/certloader_test.go new file mode 100644 index 00000000000..57819a4aa6e --- /dev/null +++ b/internal/certloader/certloader_test.go @@ -0,0 +1,52 @@ +package certloader + +import ( + "crypto/tls" + "os" + "testing" + "time" + + "github.com/bluenviron/mediamtx/internal/test" + "github.com/stretchr/testify/require" +) + +func TestCertReload(t *testing.T) { + testData, err := tls.X509KeyPair(test.TLSCertPub, test.TLSCertKey) + require.NoError(t, err) + + serverCertPath, err := test.CreateTempFile(test.TLSCertPub) + require.NoError(t, err) + defer os.Remove(serverCertPath) + + serverKeyPath, err := test.CreateTempFile(test.TLSCertKey) + require.NoError(t, err) + defer os.Remove(serverKeyPath) + + loader, err := New(serverCertPath, serverKeyPath, test.NilLogger) + require.NoError(t, err) + defer loader.Close() + + getCert := loader.GetCertificate() + require.NotNil(t, getCert) + + cert, err := getCert(nil) + require.NoError(t, err) + require.NotNil(t, cert) + require.Equal(t, &testData, cert) + + testData, err = tls.X509KeyPair(test.TLSCertPubAlt, test.TLSCertKeyAlt) + require.NoError(t, err) + + err = os.WriteFile(serverCertPath, test.TLSCertPubAlt, 0o644) + require.NoError(t, err) + + err = os.WriteFile(serverKeyPath, test.TLSCertKeyAlt, 0o644) + require.NoError(t, err) + + time.Sleep(1 * time.Second) + + cert, err = getCert(nil) + require.NoError(t, err) + require.NotNil(t, cert) + require.Equal(t, &testData, cert) +} diff --git a/internal/protocols/httpp/wrapped_server.go b/internal/protocols/httpp/wrapped_server.go index 86579ee5abb..71a15f004c3 100644 --- a/internal/protocols/httpp/wrapped_server.go +++ b/internal/protocols/httpp/wrapped_server.go @@ -10,6 +10,7 @@ import ( "net/http" "time" + "github.com/bluenviron/mediamtx/internal/certloader" "github.com/bluenviron/mediamtx/internal/logger" ) @@ -36,8 +37,9 @@ type WrappedServer struct { Handler http.Handler Parent logger.Writer - ln net.Listener - inner *http.Server + ln net.Listener + inner *http.Server + loader *certloader.CertLoader } // Initialize initializes a WrappedServer. @@ -47,13 +49,15 @@ func (s *WrappedServer) Initialize() error { if s.ServerCert == "" { return fmt.Errorf("server cert is missing") } - crt, err := tls.LoadX509KeyPair(s.ServerCert, s.ServerKey) + + var err error + s.loader, err = certloader.New(s.ServerCert, s.ServerKey, s.Parent) if err != nil { return err } tlsConfig = &tls.Config{ - Certificates: []tls.Certificate{crt}, + GetCertificate: s.loader.GetCertificate(), } } @@ -92,4 +96,7 @@ func (s *WrappedServer) Close() { ctxCancel() s.inner.Shutdown(ctx) s.ln.Close() // in case Shutdown() is called before Serve() + if s.loader != nil { + s.loader.Close() + } } diff --git a/internal/servers/rtmp/server.go b/internal/servers/rtmp/server.go index 8f492805618..574f4ded61e 100644 --- a/internal/servers/rtmp/server.go +++ b/internal/servers/rtmp/server.go @@ -12,6 +12,7 @@ import ( "github.com/google/uuid" + "github.com/bluenviron/mediamtx/internal/certloader" "github.com/bluenviron/mediamtx/internal/conf" "github.com/bluenviron/mediamtx/internal/defs" "github.com/bluenviron/mediamtx/internal/externalcmd" @@ -82,6 +83,7 @@ type Server struct { wg sync.WaitGroup ln net.Listener conns map[*conn]struct{} + loader *certloader.CertLoader // in chNewConn chan net.Conn @@ -99,13 +101,14 @@ func (s *Server) Initialize() error { return net.Listen(restrictnetwork.Restrict("tcp", s.Address)) } - cert, err := tls.LoadX509KeyPair(s.ServerCert, s.ServerKey) + var err error + s.loader, err = certloader.New(s.ServerCert, s.ServerKey, s.Parent) if err != nil { return nil, err } network, address := restrictnetwork.Restrict("tcp", s.Address) - return tls.Listen(network, address, &tls.Config{Certificates: []tls.Certificate{cert}}) + return tls.Listen(network, address, &tls.Config{GetCertificate: s.loader.GetCertificate()}) }() if err != nil { return err @@ -153,6 +156,9 @@ func (s *Server) Close() { s.Log(logger.Info, "listener is closing") s.ctxCancel() s.wg.Wait() + if s.loader != nil { + s.loader.Close() + } } func (s *Server) run() { diff --git a/internal/servers/rtsp/server.go b/internal/servers/rtsp/server.go index db0f1b4612c..a36ce12b4aa 100644 --- a/internal/servers/rtsp/server.go +++ b/internal/servers/rtsp/server.go @@ -17,6 +17,7 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/liberrors" "github.com/google/uuid" + "github.com/bluenviron/mediamtx/internal/certloader" "github.com/bluenviron/mediamtx/internal/conf" "github.com/bluenviron/mediamtx/internal/defs" "github.com/bluenviron/mediamtx/internal/externalcmd" @@ -89,6 +90,7 @@ type Server struct { mutex sync.RWMutex conns map[*gortsplib.ServerConn]*conn sessions map[*gortsplib.ServerSession]*session + loader *certloader.CertLoader } // Initialize initializes the server. @@ -118,12 +120,13 @@ func (s *Server) Initialize() error { } if s.IsTLS { - cert, err := tls.LoadX509KeyPair(s.ServerCert, s.ServerKey) + var err error + s.loader, err = certloader.New(s.ServerCert, s.ServerKey, s.Parent) if err != nil { return err } - s.srv.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}} + s.srv.TLSConfig = &tls.Config{GetCertificate: s.loader.GetCertificate()} } err := s.srv.Start() @@ -155,6 +158,9 @@ func (s *Server) Close() { s.Log(logger.Info, "listener is closing") s.ctxCancel() s.wg.Wait() + if s.loader != nil { + s.loader.Close() + } } func (s *Server) run() { diff --git a/internal/test/tls_cert.go b/internal/test/tls_cert.go index e110dc4d4ae..a1740dd3304 100644 --- a/internal/test/tls_cert.go +++ b/internal/test/tls_cert.go @@ -53,3 +53,55 @@ y++U32uuSFiXDcSLarfIsE992MEJLSAynbF1Rsgsr3gXbGiuToJRyxbIeVy7gwzD +3K6cnKEyg+0ekYmLertRFIY6SwWmY1fyKgTvxudMcsBY7dC4xs= -----END RSA PRIVATE KEY----- `) + +// TLSCertPubAlt is the public key of an alternative test certificate. +var TLSCertPubAlt = []byte(`-----BEGIN CERTIFICATE----- +MIIDSTCCAjECFEut6ZxIOnbxi3bhrPLfPQZCLReNMA0GCSqGSIb3DQEBCwUAMGEx +CzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRl +cm5ldCBXaWRnaXRzIFB0eSBMdGQxGjAYBgNVBAMMEW1lZGlhbXR4LnRlc3QuY29t +MB4XDTI0MDgwMTIzNDY0MloXDTM0MDczMDIzNDY0MlowYTELMAkGA1UEBhMCQVUx +EzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMg +UHR5IEx0ZDEaMBgGA1UEAwwRbWVkaWFtdHgudGVzdC5jb20wggEiMA0GCSqGSIb3 +DQEBAQUAA4IBDwAwggEKAoIBAQCzfvG9eLXKSTDBoM+cgV/ThiNRI2JY6dpQV8rK +QFQ5bkkDUDP+2Ae/IWylgLLXmozsMwjz1Pu42awmGymBuo5HDbI4bxPJNQR9qRrR +2+MvfDgmZxyhw5NfZDlVl+enxhb3FRgbHsLBy4oSoHbRUdLApVdM0Kg6r3bXzkih +EEs63boFJOkPhs5H0NX7AzXyBp2WnvB71j+7avnMwAsjJHOiTs8wkp5wvRcIZpJl +MCandUkcZShMirug7QOcR9fAr5CVKxsO/DjqEjwkslJHFfizOl3yRx6nsxvW8JUd +dforpSRj84dkHTi7k37YTiji90GsOvh0qc0MfAmeE181HIb/AgMBAAEwDQYJKoZI +hvcNAQELBQADggEBAEWkLL/7nvt3iD7BVJNHLvAS6GwuTH99vCil6TFYwVl4goht +Dur7YfzN43vUq+lAwS3Ry4ka7tH72pAMkpNFRvHOikWGmWUSDo2DcLd8iu3ruLF7 +yUg2ASQuekK0sUv4YKpAqV8gS2R4Jh4vLU+8L5iJ1XWGELbQ+H5wm4l7l+r2X6cD +/opmdV8Slfi0FlNQtflLsGoSlfZF5jHxqi3zyt8QdEf9WZt8e6JPxcx2Fq7Op51u +Qx9nosr5fLwhkx46+B/cotsbI/xPDjLF6RQ1OUpcHwg1HI6czoW4hHn33S0zstCf +BWt5Q1Mb2tGInbmbUgw3wUu/4nWoY+Mq4DKPlKs= +-----END CERTIFICATE-----`) + +// TLSCertKeyAlt is the private key of an alternative test certificate. +var TLSCertKeyAlt = []byte(`-----BEGIN RSA PRIVATE KEY----- +MIIEoQIBAAKCAQEAs37xvXi1ykkwwaDPnIFf04YjUSNiWOnaUFfKykBUOW5JA1Az +/tgHvyFspYCy15qM7DMI89T7uNmsJhspgbqORw2yOG8TyTUEfaka0dvjL3w4Jmcc +ocOTX2Q5VZfnp8YW9xUYGx7CwcuKEqB20VHSwKVXTNCoOq92185IoRBLOt26BSTp +D4bOR9DV+wM18gadlp7we9Y/u2r5zMALIyRzok7PMJKecL0XCGaSZTAmp3VJHGUo +TIq7oO0DnEfXwK+QlSsbDvw46hI8JLJSRxX4szpd8kcep7Mb1vCVHXX6K6UkY/OH +ZB04u5N+2E4o4vdBrDr4dKnNDHwJnhNfNRyG/wIDAQABAoH/WmCqV6Lv5dEnofCj +ZUO/Fdv0hf/LBS0g2SAoFRSCIM8aJ3dUUH0PaXoeINDGCMlIxT7tKXJg5jJNYhWx +g7oegw6vLe5ZiA+p5miL/uue+Jas4kLVp9DrfQLgQevt0gw4g/00pgy9adbFlTUD +a2HhPB7RIvXs8gYA6nVAT9jK1ST2pbeUgQNO4Ji4EjpPUkR2O7ISOlu5EV8Cj0eV +1Vs5B92Z7ORh7P2fFV2YBu+igd04+uYvei6slQl+F9cETvJv2Z9r37Yashvnn1in +uy/u1U4B1t4oOz81nHz6kxTixPpBOdJ6x8jLDgNGSsauJQfXT9xmB/rAr/NFq+7I +tbTNAoGBAMOgm3XXHWokmJnX9pfNj6ixNlrMuuez/yXMVwuxa2WFwAFN16tjJhBi +XOjestcvu/SRhOAMmYac5QdopJpLjO/FxO165r73eZhW/SJefyOHtfD29kHagA1u +JjcznU6tiA0O1owy6nuuaTfyVbDQj32PhVBx9ZwSI4778GFbjWl7AoGBAOrj4WCC +gTMaExpwNo+L+3VkM79YD1Obl13FcgtVoxjcoWjQeMx9D0k7adTV3xlchHFAjiD5 +Gs/MZl8+seq+GDX3mODsmJkdRQbYId4g6IesiOnQ3Ug/Y282WZRnpB5h/BMnrcCZ +VoohnATA7f96c7XtPUgZyROmh24T7UIVwVdNAoGAbeeGT276TI6g2RWWqXRIOFrP +EbYhb1kViFPDt4MGtjOtSk5EUzpRwTSxw/aRfQmJS/6RKxqJCjKNDVuB1lmJpY9z +coPwrOr1+lssvalfPkPZOLZWZWrvNBxlBfBOeUxOuh9S89MLH08+N7tC3yJc6wq9 +uBM+DF+4cHUkeF3qFY8CgYBzS+IwBj82/0CLRLNzaKnIqKPB846qYoA9NhLRv3ps +VLgiA9qXvXdIYhKDt2toPoKAOMjLJJtljpZdgB/C8wZdTyjKlzgcSEK+pk6RgyPA +nQ8jfjNwKDU9vLbh4rGrfDtIh7yBAoN5ECBOMQlh0xCDJ21iO834iFCH1t4qBxW9 +LQKBgQC36adC2Gu+FJRvx4Mkm73fLmVdFbP6Do7qNwyVVyaG80PDVrFQrlWm4Dt7 +AO9IwzaS1Lx+qmU1Fj1WfCtXuQa5nc9AzZ36TmM6+pAn8AC7PdNqc0qSdefVrIjj +zRGhUPaJV3A+sfO+xedBsAFnqNuX9oODYVGbTjuc2OWC30MGaw== +-----END RSA PRIVATE KEY----- +`)