From 3f1d182d2c6ad18ec711940bebb8bf6d60de61cb Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Sat, 6 Jul 2024 21:45:15 +0200 Subject: [PATCH] fix support for HTTP preflight requests (#1792) (#3535) --- internal/api/api.go | 9 +++++ internal/api/api_test.go | 37 +++++++++++++++++++ internal/metrics/metrics.go | 9 +++++ internal/metrics/metrics_test.go | 49 ++++++++++++++++++++++++++ internal/playback/server.go | 9 +++++ internal/playback/server_test.go | 48 +++++++++++++++++++++++++ internal/pprof/pprof.go | 9 +++++ internal/pprof/pprof_test.go | 48 +++++++++++++++++++++++++ internal/servers/hls/http_server.go | 8 +++-- internal/servers/hls/server_test.go | 37 +++++++++++++++++++ internal/servers/webrtc/server_test.go | 20 +++++++---- 11 files changed, 273 insertions(+), 10 deletions(-) create mode 100644 internal/metrics/metrics_test.go create mode 100644 internal/playback/server_test.go create mode 100644 internal/pprof/pprof_test.go diff --git a/internal/api/api.go b/internal/api/api.go index 072ce0a584b..c0324624902 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -251,6 +251,15 @@ func (a *API) writeError(ctx *gin.Context, status int, err error) { func (a *API) middlewareOrigin(ctx *gin.Context) { ctx.Writer.Header().Set("Access-Control-Allow-Origin", a.AllowOrigin) ctx.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + + // preflight requests + if ctx.Request.Method == http.MethodOptions && + ctx.Request.Header.Get("Access-Control-Request-Method") != "" { + ctx.Writer.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET, POST, PATCH, DELETE") + ctx.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") + ctx.AbortWithStatus(http.StatusNoContent) + return + } } func (a *API) middlewareAuth(ctx *gin.Context) { diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 9cddc2576fc..a2a6a5ecaf9 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -74,6 +74,43 @@ func checkError(t *testing.T, msg string, body io.Reader) { require.Equal(t, map[string]interface{}{"error": msg}, resErr) } +func TestPreflightRequest(t *testing.T) { + api := API{ + Address: "localhost:9997", + AllowOrigin: "*", + ReadTimeout: conf.StringDuration(10 * time.Second), + AuthManager: test.NilAuthManager, + Parent: &testParent{}, + } + err := api.Initialize() + require.NoError(t, err) + defer api.Close() + + tr := &http.Transport{} + defer tr.CloseIdleConnections() + hc := &http.Client{Transport: tr} + + req, err := http.NewRequest(http.MethodOptions, "http://localhost:9997", nil) + require.NoError(t, err) + + req.Header.Add("Access-Control-Request-Method", "GET") + + res, err := hc.Do(req) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusNoContent, res.StatusCode) + + byts, err := io.ReadAll(res.Body) + require.NoError(t, err) + + require.Equal(t, "*", res.Header.Get("Access-Control-Allow-Origin")) + require.Equal(t, "true", res.Header.Get("Access-Control-Allow-Credentials")) + require.Equal(t, "OPTIONS, GET, POST, PATCH, DELETE", res.Header.Get("Access-Control-Allow-Methods")) + require.Equal(t, "Authorization, Content-Type", res.Header.Get("Access-Control-Allow-Headers")) + require.Equal(t, byts, []byte{}) +} + func TestConfigAuth(t *testing.T) { cnf := tempConf(t, "api: yes\n") diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 96c1a64523f..11a5af72ec6 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -107,6 +107,15 @@ func (m *Metrics) onRequest(ctx *gin.Context) { ctx.Writer.Header().Set("Access-Control-Allow-Origin", m.AllowOrigin) ctx.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + // preflight requests + if ctx.Request.Method == http.MethodOptions && + ctx.Request.Header.Get("Access-Control-Request-Method") != "" { + ctx.Writer.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET") + ctx.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization") + ctx.Writer.WriteHeader(http.StatusNoContent) + return + } + if ctx.Request.URL.Path != "/metrics" || ctx.Request.Method != http.MethodGet { return } diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go new file mode 100644 index 00000000000..7bd3099d980 --- /dev/null +++ b/internal/metrics/metrics_test.go @@ -0,0 +1,49 @@ +package metrics + +import ( + "io" + "net/http" + "testing" + "time" + + "github.com/bluenviron/mediamtx/internal/conf" + "github.com/bluenviron/mediamtx/internal/test" + "github.com/stretchr/testify/require" +) + +func TestPreflightRequest(t *testing.T) { + api := Metrics{ + Address: "localhost:9998", + AllowOrigin: "*", + ReadTimeout: conf.StringDuration(10 * time.Second), + AuthManager: test.NilAuthManager, + Parent: test.NilLogger, + } + err := api.Initialize() + require.NoError(t, err) + defer api.Close() + + tr := &http.Transport{} + defer tr.CloseIdleConnections() + hc := &http.Client{Transport: tr} + + req, err := http.NewRequest(http.MethodOptions, "http://localhost:9998", nil) + require.NoError(t, err) + + req.Header.Add("Access-Control-Request-Method", "GET") + + res, err := hc.Do(req) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusNoContent, res.StatusCode) + + byts, err := io.ReadAll(res.Body) + require.NoError(t, err) + + require.Equal(t, "*", res.Header.Get("Access-Control-Allow-Origin")) + require.Equal(t, "true", res.Header.Get("Access-Control-Allow-Credentials")) + require.Equal(t, "OPTIONS, GET", res.Header.Get("Access-Control-Allow-Methods")) + require.Equal(t, "Authorization", res.Header.Get("Access-Control-Allow-Headers")) + require.Equal(t, byts, []byte{}) +} diff --git a/internal/playback/server.go b/internal/playback/server.go index abea96e7a6e..52086169708 100644 --- a/internal/playback/server.go +++ b/internal/playback/server.go @@ -109,6 +109,15 @@ func (s *Server) safeFindPathConf(name string) (*conf.Path, error) { func (s *Server) middlewareOrigin(ctx *gin.Context) { ctx.Writer.Header().Set("Access-Control-Allow-Origin", s.AllowOrigin) ctx.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + + // preflight requests + if ctx.Request.Method == http.MethodOptions && + ctx.Request.Header.Get("Access-Control-Request-Method") != "" { + ctx.Writer.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET") + ctx.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization") + ctx.AbortWithStatus(http.StatusNoContent) + return + } } func (s *Server) doAuth(ctx *gin.Context, pathName string) bool { diff --git a/internal/playback/server_test.go b/internal/playback/server_test.go new file mode 100644 index 00000000000..9fb905c657e --- /dev/null +++ b/internal/playback/server_test.go @@ -0,0 +1,48 @@ +package playback + +import ( + "io" + "net/http" + "testing" + "time" + + "github.com/bluenviron/mediamtx/internal/conf" + "github.com/bluenviron/mediamtx/internal/test" + "github.com/stretchr/testify/require" +) + +func TestPreflightRequest(t *testing.T) { + s := &Server{ + Address: "127.0.0.1:9996", + AllowOrigin: "*", + ReadTimeout: conf.StringDuration(10 * time.Second), + Parent: test.NilLogger, + } + err := s.Initialize() + require.NoError(t, err) + defer s.Close() + + tr := &http.Transport{} + defer tr.CloseIdleConnections() + hc := &http.Client{Transport: tr} + + req, err := http.NewRequest(http.MethodOptions, "http://localhost:9996", nil) + require.NoError(t, err) + + req.Header.Add("Access-Control-Request-Method", "GET") + + res, err := hc.Do(req) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusNoContent, res.StatusCode) + + byts, err := io.ReadAll(res.Body) + require.NoError(t, err) + + require.Equal(t, "*", res.Header.Get("Access-Control-Allow-Origin")) + require.Equal(t, "true", res.Header.Get("Access-Control-Allow-Credentials")) + require.Equal(t, "OPTIONS, GET", res.Header.Get("Access-Control-Allow-Methods")) + require.Equal(t, "Authorization", res.Header.Get("Access-Control-Allow-Headers")) + require.Equal(t, byts, []byte{}) +} diff --git a/internal/pprof/pprof.go b/internal/pprof/pprof.go index ff6f2dc2499..2cb0164f85e 100644 --- a/internal/pprof/pprof.go +++ b/internal/pprof/pprof.go @@ -83,6 +83,15 @@ func (pp *PPROF) onRequest(ctx *gin.Context) { ctx.Writer.Header().Set("Access-Control-Allow-Origin", pp.AllowOrigin) ctx.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + // preflight requests + if ctx.Request.Method == http.MethodOptions && + ctx.Request.Header.Get("Access-Control-Request-Method") != "" { + ctx.Writer.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET") + ctx.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization") + ctx.Writer.WriteHeader(http.StatusNoContent) + return + } + user, pass, hasCredentials := ctx.Request.BasicAuth() err := pp.AuthManager.Authenticate(&auth.Request{ diff --git a/internal/pprof/pprof_test.go b/internal/pprof/pprof_test.go new file mode 100644 index 00000000000..d26de96514b --- /dev/null +++ b/internal/pprof/pprof_test.go @@ -0,0 +1,48 @@ +package pprof + +import ( + "io" + "net/http" + "testing" + "time" + + "github.com/bluenviron/mediamtx/internal/conf" + "github.com/bluenviron/mediamtx/internal/test" + "github.com/stretchr/testify/require" +) + +func TestPreflightRequest(t *testing.T) { + s := &PPROF{ + Address: "127.0.0.1:9999", + AllowOrigin: "*", + ReadTimeout: conf.StringDuration(10 * time.Second), + Parent: test.NilLogger, + } + err := s.Initialize() + require.NoError(t, err) + defer s.Close() + + tr := &http.Transport{} + defer tr.CloseIdleConnections() + hc := &http.Client{Transport: tr} + + req, err := http.NewRequest(http.MethodOptions, "http://localhost:9999", nil) + require.NoError(t, err) + + req.Header.Add("Access-Control-Request-Method", "GET") + + res, err := hc.Do(req) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusNoContent, res.StatusCode) + + byts, err := io.ReadAll(res.Body) + require.NoError(t, err) + + require.Equal(t, "*", res.Header.Get("Access-Control-Allow-Origin")) + require.Equal(t, "true", res.Header.Get("Access-Control-Allow-Credentials")) + require.Equal(t, "OPTIONS, GET", res.Header.Get("Access-Control-Allow-Methods")) + require.Equal(t, "Authorization", res.Header.Get("Access-Control-Allow-Headers")) + require.Equal(t, byts, []byte{}) +} diff --git a/internal/servers/hls/http_server.go b/internal/servers/hls/http_server.go index 752b5c2a9dc..882e3443307 100644 --- a/internal/servers/hls/http_server.go +++ b/internal/servers/hls/http_server.go @@ -102,9 +102,11 @@ func (s *httpServer) onRequest(ctx *gin.Context) { switch ctx.Request.Method { case http.MethodOptions: - ctx.Writer.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET") - ctx.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization, Range") - ctx.Writer.WriteHeader(http.StatusNoContent) + if ctx.Request.Header.Get("Access-Control-Request-Method") != "" { + ctx.Writer.Header().Set("Access-Control-Allow-Methods", "OPTIONS, GET") + ctx.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization, Range") + ctx.Writer.WriteHeader(http.StatusNoContent) + } return case http.MethodGet: diff --git a/internal/servers/hls/server_test.go b/internal/servers/hls/server_test.go index 76ef901d7a2..b014702b038 100644 --- a/internal/servers/hls/server_test.go +++ b/internal/servers/hls/server_test.go @@ -2,6 +2,7 @@ package hls import ( "fmt" + "io" "net/http" "os" "path/filepath" @@ -60,6 +61,42 @@ func (pm *dummyPathManager) AddReader(req defs.PathAddReaderReq) (defs.Path, *st return pm.addReader(req) } +func TestPreflightRequest(t *testing.T) { + s := &Server{ + Address: "127.0.0.1:8888", + AllowOrigin: "*", + ReadTimeout: conf.StringDuration(10 * time.Second), + Parent: test.NilLogger, + } + err := s.Initialize() + require.NoError(t, err) + defer s.Close() + + tr := &http.Transport{} + defer tr.CloseIdleConnections() + hc := &http.Client{Transport: tr} + + req, err := http.NewRequest(http.MethodOptions, "http://localhost:8888", nil) + require.NoError(t, err) + + req.Header.Add("Access-Control-Request-Method", "GET") + + res, err := hc.Do(req) + require.NoError(t, err) + defer res.Body.Close() + + require.Equal(t, http.StatusNoContent, res.StatusCode) + + byts, err := io.ReadAll(res.Body) + require.NoError(t, err) + + require.Equal(t, "*", res.Header.Get("Access-Control-Allow-Origin")) + require.Equal(t, "true", res.Header.Get("Access-Control-Allow-Credentials")) + require.Equal(t, "OPTIONS, GET", res.Header.Get("Access-Control-Allow-Methods")) + require.Equal(t, "Authorization, Range", res.Header.Get("Access-Control-Allow-Headers")) + require.Equal(t, byts, []byte{}) +} + func TestServerNotFound(t *testing.T) { for _, ca := range []string{ "always remux off", diff --git a/internal/servers/webrtc/server_test.go b/internal/servers/webrtc/server_test.go index 0fcbf124f49..6a05433c1ff 100644 --- a/internal/servers/webrtc/server_test.go +++ b/internal/servers/webrtc/server_test.go @@ -3,6 +3,7 @@ package webrtc import ( "bytes" "context" + "io" "net/http" "net/url" "reflect" @@ -102,7 +103,7 @@ func initializeTestServer(t *testing.T) *Server { Encryption: false, ServerKey: "", ServerCert: "", - AllowOrigin: "", + AllowOrigin: "*", TrustedProxies: conf.IPNetworks{}, ReadTimeout: conf.StringDuration(10 * time.Second), WriteQueueSize: 512, @@ -146,7 +147,7 @@ func TestServerStaticPages(t *testing.T) { } } -func TestServerOptionsPreflight(t *testing.T) { +func TestPreflightRequest(t *testing.T) { s := initializeTestServer(t) defer s.Close() @@ -154,11 +155,10 @@ func TestServerOptionsPreflight(t *testing.T) { defer tr.CloseIdleConnections() hc := &http.Client{Transport: tr} - // preflight requests must always work, without authentication - req, err := http.NewRequest(http.MethodOptions, "http://localhost:8886/teststream/whip", nil) + req, err := http.NewRequest(http.MethodOptions, "http://localhost:8886", nil) require.NoError(t, err) - req.Header.Set("Access-Control-Request-Method", "OPTIONS") + req.Header.Add("Access-Control-Request-Method", "GET") res, err := hc.Do(req) require.NoError(t, err) @@ -166,8 +166,14 @@ func TestServerOptionsPreflight(t *testing.T) { require.Equal(t, http.StatusNoContent, res.StatusCode) - _, ok := res.Header["Link"] - require.Equal(t, false, ok) + byts, err := io.ReadAll(res.Body) + require.NoError(t, err) + + require.Equal(t, "*", res.Header.Get("Access-Control-Allow-Origin")) + require.Equal(t, "true", res.Header.Get("Access-Control-Allow-Credentials")) + require.Equal(t, "OPTIONS, GET, POST, PATCH, DELETE", res.Header.Get("Access-Control-Allow-Methods")) + require.Equal(t, "Authorization, Content-Type, If-Match", res.Header.Get("Access-Control-Allow-Headers")) + require.Equal(t, byts, []byte{}) } func TestServerOptionsICEServer(t *testing.T) {