diff --git a/README.md b/README.md index c6a1e388f..1bec0c0a1 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,7 @@ $ yopass-server -h --redis string Redis URL (default "redis://localhost:6379/0") --tls-cert string path to TLS certificate --tls-key string path to TLS key + --cors allow an origin to interact with the API (defaults to self) ``` Encrypted secrets can be stored either in Memcached or Redis by changing the `--database` flag. diff --git a/cmd/yopass-server/main.go b/cmd/yopass-server/main.go index 0853b35e8..5095bf4bd 100644 --- a/cmd/yopass-server/main.go +++ b/cmd/yopass-server/main.go @@ -36,6 +36,7 @@ func init() { pflag.String("tls-cert", "", "path to TLS certificate") pflag.String("tls-key", "", "path to TLS key") pflag.Bool("force-onetime-secrets", false, "reject non onetime secrets from being created") + pflag.String("cors", "", "allow an origin to interact with the API (defaults to self)") pflag.CommandLine.AddGoFlag(&flag.Flag{Name: "log-level", Usage: "Log level", Value: &logLevel}) viper.AutomaticEnv() @@ -74,7 +75,7 @@ func main() { key := viper.GetString("tls-key") quit := make(chan os.Signal, 1) - y := server.New(db, viper.GetInt("max-length"), registry, viper.GetBool("force-onetime-secrets"), logger) + y := server.New(db, viper.GetInt("max-length"), registry, viper.GetBool("force-onetime-secrets"), logger, viper.GetString("cors")) yopassSrv := &http.Server{ Addr: fmt.Sprintf("%s:%d", viper.GetString("address"), viper.GetInt("port")), Handler: y.HTTPHandler(), diff --git a/pkg/server/server.go b/pkg/server/server.go index 874c837dd..e62324fe1 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -23,10 +23,11 @@ type Server struct { registry *prometheus.Registry forceOneTimeSecrets bool logger *zap.Logger + cors string } // New is the main way of creating the server. -func New(db Database, maxLength int, r *prometheus.Registry, forceOneTimeSecrets bool, logger *zap.Logger) Server { +func New(db Database, maxLength int, r *prometheus.Registry, forceOneTimeSecrets bool, logger *zap.Logger, cors string) Server { if logger == nil { logger = zap.NewNop() } @@ -36,12 +37,13 @@ func New(db Database, maxLength int, r *prometheus.Registry, forceOneTimeSecrets registry: r, forceOneTimeSecrets: forceOneTimeSecrets, logger: logger, + cors: cors, } } // createSecret creates secret func (y *Server) createSecret(w http.ResponseWriter, request *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Origin", y.cors) decoder := json.NewDecoder(request.Body) var s yopass.Secret @@ -95,7 +97,7 @@ func (y *Server) createSecret(w http.ResponseWriter, request *http.Request) { // getSecret from database func (y *Server) getSecret(w http.ResponseWriter, request *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Origin", y.cors) w.Header().Set("Cache-Control", "private, no-cache") secretKey := mux.Vars(request)["key"] @@ -120,7 +122,7 @@ func (y *Server) getSecret(w http.ResponseWriter, request *http.Request) { // deleteSecret from database func (y *Server) deleteSecret(w http.ResponseWriter, request *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Origin", y.cors) deleted, err := y.db.Delete(mux.Vars(request)["key"]) if err != nil { @@ -138,8 +140,9 @@ func (y *Server) deleteSecret(w http.ResponseWriter, request *http.Request) { // optionsSecret handle the Options http method by returning the correct CORS headers func (y *Server) optionsSecret(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", strings.Join([]string{http.MethodGet, http.MethodDelete, http.MethodOptions}, ",")) + w.Header().Set("Access-Control-Allow-Origin", y.cors) + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + w.Header().Set("Access-Control-Allow-Methods", strings.Join([]string{http.MethodGet, http.MethodPost, http.MethodDelete, http.MethodOptions}, ",")) } // HTTPHandler containing all routes @@ -148,15 +151,19 @@ func (y *Server) HTTPHandler() http.Handler { mx.Use(newMetricsMiddleware(y.registry)) mx.HandleFunc("/secret", y.createSecret).Methods(http.MethodPost) + mx.HandleFunc("/secret", y.optionsSecret).Methods(http.MethodOptions) mx.HandleFunc("/secret/"+keyParameter, y.getSecret).Methods(http.MethodGet) mx.HandleFunc("/secret/"+keyParameter, y.deleteSecret).Methods(http.MethodDelete) mx.HandleFunc("/secret/"+keyParameter, y.optionsSecret).Methods(http.MethodOptions) mx.HandleFunc("/file", y.createSecret).Methods(http.MethodPost) + mx.HandleFunc("/file", y.optionsSecret).Methods(http.MethodOptions) mx.HandleFunc("/file/"+keyParameter, y.getSecret).Methods(http.MethodGet) mx.HandleFunc("/file/"+keyParameter, y.deleteSecret).Methods(http.MethodDelete) mx.HandleFunc("/file/"+keyParameter, y.optionsSecret).Methods(http.MethodOptions) + mx.Use(mux.CORSMethodMiddleware(mx)) + mx.PathPrefix("/").Handler(http.FileServer(http.Dir("public"))) return handlers.CustomLoggingHandler(nil, SecurityHeadersHandler(mx), httpLogFormatter(y.logger)) } diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 63b7e4323..203fd035e 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -104,7 +104,7 @@ func TestCreateSecret(t *testing.T) { t.Run(fmt.Sprintf(tc.name), func(t *testing.T) { req, _ := http.NewRequest("POST", "/secret", tc.body) rr := httptest.NewRecorder() - y := New(tc.db, tc.maxLength, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) + y := New(tc.db, tc.maxLength, prometheus.NewRegistry(), false, zaptest.NewLogger(t), "") y.createSecret(rr, req) var s yopass.Secret json.Unmarshal(rr.Body.Bytes(), &s) @@ -161,7 +161,7 @@ func TestOneTimeEnforcement(t *testing.T) { t.Run(fmt.Sprintf(tc.name), func(t *testing.T) { req, _ := http.NewRequest("POST", "/secret", tc.body) rr := httptest.NewRecorder() - y := New(&mockDB{}, 100, prometheus.NewRegistry(), tc.requireOneTime, zaptest.NewLogger(t)) + y := New(&mockDB{}, 100, prometheus.NewRegistry(), tc.requireOneTime, zaptest.NewLogger(t), "") y.createSecret(rr, req) var s yopass.Secret json.Unmarshal(rr.Body.Bytes(), &s) @@ -205,7 +205,7 @@ func TestGetSecret(t *testing.T) { t.Fatal(err) } rr := httptest.NewRecorder() - y := New(tc.db, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) + y := New(tc.db, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t), "") y.getSecret(rr, req) cacheControl := rr.Header().Get("Cache-Control") if cacheControl != "private, no-cache" { @@ -256,7 +256,7 @@ func TestDeleteSecret(t *testing.T) { t.Fatal(err) } rr := httptest.NewRecorder() - y := New(tc.db, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) + y := New(tc.db, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t), "") y.deleteSecret(rr, req) var s struct { Message string `json:"message"` @@ -286,7 +286,7 @@ func TestMetrics(t *testing.T) { path: "/secret/invalid-key-format", }, } - y := New(&mockDB{}, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) + y := New(&mockDB{}, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t), "") h := y.HTTPHandler() for _, r := range requests { @@ -359,7 +359,7 @@ func TestSecurityHeaders(t *testing.T) { }, } - y := New(&mockDB{}, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) + y := New(&mockDB{}, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t), "") h := y.HTTPHandler() t.Parallel() diff --git a/pkg/yopass/client_test.go b/pkg/yopass/client_test.go index 8bfa89564..5dc85d8b0 100644 --- a/pkg/yopass/client_test.go +++ b/pkg/yopass/client_test.go @@ -3,10 +3,11 @@ package yopass_test import ( "errors" "fmt" - "go.uber.org/zap/zaptest" "net/http/httptest" "testing" + "go.uber.org/zap/zaptest" + "github.com/jhaals/yopass/pkg/server" "github.com/jhaals/yopass/pkg/yopass" "github.com/prometheus/client_golang/prometheus" @@ -14,7 +15,7 @@ import ( func TestFetch(t *testing.T) { db := testDB(map[string]string{}) - y := server.New(&db, 1024, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) + y := server.New(&db, 1024, prometheus.NewRegistry(), false, zaptest.NewLogger(t), "") ts := httptest.NewServer(y.HTTPHandler()) defer ts.Close() @@ -46,7 +47,7 @@ func TestFetchInvalidServer(t *testing.T) { } func TestStore(t *testing.T) { db := testDB(map[string]string{}) - y := server.New(&db, 1024, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) + y := server.New(&db, 1024, prometheus.NewRegistry(), false, zaptest.NewLogger(t), "") ts := httptest.NewServer(y.HTTPHandler()) defer ts.Close()