diff --git a/cmd/yopass-server/main.go b/cmd/yopass-server/main.go index 31a7ff3e9..a23ced580 100644 --- a/cmd/yopass-server/main.go +++ b/cmd/yopass-server/main.go @@ -30,6 +30,7 @@ func init() { pflag.String("address", "", "listen address (default 0.0.0.0)") pflag.Int("port", 1337, "listen port") pflag.String("database", "memcached", "database backend ('memcached' or 'redis')") + pflag.String("asset-path", "public", "path to the assets folder (default 'public')") pflag.Int("max-length", 10000, "max length of encrypted secret") pflag.String("memcached", "localhost:11211", "memcached address") pflag.Int("metrics-port", -1, "metrics server listen port") @@ -56,9 +57,10 @@ func main() { cert := viper.GetString("tls-cert") key := viper.GetString("tls-key") + assetPath := viper.GetString("asset-path") quit := make(chan os.Signal, 1) - y := server.New(db, viper.GetInt("max-length"), registry, viper.GetBool("force-onetime-secrets"), logger) + y := server.New(db, assetPath, viper.GetInt("max-length"), registry, viper.GetBool("force-onetime-secrets"), logger) yopassSrv := &http.Server{ Addr: fmt.Sprintf("%s:%d", viper.GetString("address"), viper.GetInt("port")), Handler: y.HTTPHandler(), @@ -66,6 +68,7 @@ func main() { } go func() { logger.Info("Starting yopass server", zap.String("address", yopassSrv.Addr)) + logger.Info("Loading assets from: ", zap.String("asset-path", assetPath)) err := listenAndServe(yopassSrv, cert, key) if !errors.Is(err, http.ErrServerClosed) { logger.Fatal("yopass stopped unexpectedly", zap.Error(err)) diff --git a/cmd/yopass/main_test.go b/cmd/yopass/main_test.go index a85cca229..29e2bfcac 100644 --- a/cmd/yopass/main_test.go +++ b/cmd/yopass/main_test.go @@ -2,7 +2,6 @@ package main import ( "bytes" - "io/ioutil" "net/http" "os" "strings" @@ -270,7 +269,7 @@ func pingDemoServer() bool { } func tempFile(s string) (*os.File, error) { - f, err := ioutil.TempFile("", "yopass-") + f, err := os.CreateTemp("", "yopass-") if err != nil { return nil, err } diff --git a/pkg/server/server.go b/pkg/server/server.go index 874c837dd..f4e912a34 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -19,6 +19,7 @@ import ( // This should be created with server.New type Server struct { db Database + assetPath string maxLength int registry *prometheus.Registry forceOneTimeSecrets bool @@ -26,12 +27,13 @@ type Server struct { } // New is the main way of creating the server. -func New(db Database, maxLength int, r *prometheus.Registry, forceOneTimeSecrets bool, logger *zap.Logger) Server { +func New(db Database, assetPath string, maxLength int, r *prometheus.Registry, forceOneTimeSecrets bool, logger *zap.Logger) Server { if logger == nil { logger = zap.NewNop() } return Server{ db: db, + assetPath: assetPath, maxLength: maxLength, registry: r, forceOneTimeSecrets: forceOneTimeSecrets, @@ -157,7 +159,7 @@ func (y *Server) HTTPHandler() http.Handler { mx.HandleFunc("/file/"+keyParameter, y.deleteSecret).Methods(http.MethodDelete) mx.HandleFunc("/file/"+keyParameter, y.optionsSecret).Methods(http.MethodOptions) - mx.PathPrefix("/").Handler(http.FileServer(http.Dir("public"))) + mx.PathPrefix("/").Handler(http.FileServer(http.Dir(y.assetPath))) return handlers.CustomLoggingHandler(nil, SecurityHeadersHandler(mx), httpLogFormatter(y.logger)) } diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 63b7e4323..e863d88bb 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -104,7 +104,7 @@ func TestCreateSecret(t *testing.T) { t.Run(fmt.Sprintf(tc.name), func(t *testing.T) { req, _ := http.NewRequest("POST", "/secret", tc.body) rr := httptest.NewRecorder() - y := New(tc.db, tc.maxLength, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) + y := New(tc.db, "public", tc.maxLength, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) y.createSecret(rr, req) var s yopass.Secret json.Unmarshal(rr.Body.Bytes(), &s) @@ -161,7 +161,7 @@ func TestOneTimeEnforcement(t *testing.T) { t.Run(fmt.Sprintf(tc.name), func(t *testing.T) { req, _ := http.NewRequest("POST", "/secret", tc.body) rr := httptest.NewRecorder() - y := New(&mockDB{}, 100, prometheus.NewRegistry(), tc.requireOneTime, zaptest.NewLogger(t)) + y := New(&mockDB{}, "public", 100, prometheus.NewRegistry(), tc.requireOneTime, zaptest.NewLogger(t)) y.createSecret(rr, req) var s yopass.Secret json.Unmarshal(rr.Body.Bytes(), &s) @@ -205,7 +205,7 @@ func TestGetSecret(t *testing.T) { t.Fatal(err) } rr := httptest.NewRecorder() - y := New(tc.db, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) + y := New(tc.db, "public", 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) y.getSecret(rr, req) cacheControl := rr.Header().Get("Cache-Control") if cacheControl != "private, no-cache" { @@ -256,7 +256,7 @@ func TestDeleteSecret(t *testing.T) { t.Fatal(err) } rr := httptest.NewRecorder() - y := New(tc.db, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) + y := New(tc.db, "public", 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) y.deleteSecret(rr, req) var s struct { Message string `json:"message"` @@ -286,7 +286,7 @@ func TestMetrics(t *testing.T) { path: "/secret/invalid-key-format", }, } - y := New(&mockDB{}, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) + y := New(&mockDB{}, "public", 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) h := y.HTTPHandler() for _, r := range requests { @@ -359,7 +359,7 @@ func TestSecurityHeaders(t *testing.T) { }, } - y := New(&mockDB{}, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) + y := New(&mockDB{}, "public", 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) h := y.HTTPHandler() t.Parallel() diff --git a/pkg/yopass/client_test.go b/pkg/yopass/client_test.go index 8bfa89564..613d4e197 100644 --- a/pkg/yopass/client_test.go +++ b/pkg/yopass/client_test.go @@ -14,7 +14,7 @@ import ( func TestFetch(t *testing.T) { db := testDB(map[string]string{}) - y := server.New(&db, 1024, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) + y := server.New(&db, "public", 1024, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) ts := httptest.NewServer(y.HTTPHandler()) defer ts.Close() @@ -46,7 +46,7 @@ func TestFetchInvalidServer(t *testing.T) { } func TestStore(t *testing.T) { db := testDB(map[string]string{}) - y := server.New(&db, 1024, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) + y := server.New(&db, "public", 1024, prometheus.NewRegistry(), false, zaptest.NewLogger(t)) ts := httptest.NewServer(y.HTTPHandler()) defer ts.Close()