Skip to content

Commit

Permalink
Allow configuration of asset path.
Browse files Browse the repository at this point in the history
  • Loading branch information
fraggerfox committed Dec 25, 2024
1 parent 27775a5 commit c6a688f
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 13 deletions.
5 changes: 4 additions & 1 deletion cmd/yopass-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func init() {
pflag.String("address", "", "listen address (default 0.0.0.0)")
pflag.Int("port", 1337, "listen port")
pflag.String("database", "memcached", "database backend ('memcached' or 'redis')")
pflag.String("asset-path", "public", "path to the assets folder (default 'public')")
pflag.Int("max-length", 10000, "max length of encrypted secret")
pflag.String("memcached", "localhost:11211", "memcached address")
pflag.Int("metrics-port", -1, "metrics server listen port")
Expand All @@ -56,16 +57,18 @@ func main() {

cert := viper.GetString("tls-cert")
key := viper.GetString("tls-key")
assetPath := viper.GetString("asset-path")
quit := make(chan os.Signal, 1)

y := server.New(db, viper.GetInt("max-length"), registry, viper.GetBool("force-onetime-secrets"), logger)
y := server.New(db, assetPath, viper.GetInt("max-length"), registry, viper.GetBool("force-onetime-secrets"), logger)
yopassSrv := &http.Server{
Addr: fmt.Sprintf("%s:%d", viper.GetString("address"), viper.GetInt("port")),
Handler: y.HTTPHandler(),
TLSConfig: &tls.Config{MinVersion: tls.VersionTLS12},
}
go func() {
logger.Info("Starting yopass server", zap.String("address", yopassSrv.Addr))
logger.Info("Loading assets from: ", zap.String("asset-path", assetPath))
err := listenAndServe(yopassSrv, cert, key)
if !errors.Is(err, http.ErrServerClosed) {
logger.Fatal("yopass stopped unexpectedly", zap.Error(err))
Expand Down
3 changes: 1 addition & 2 deletions cmd/yopass/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"bytes"
"io/ioutil"
"net/http"
"os"
"strings"
Expand Down Expand Up @@ -270,7 +269,7 @@ func pingDemoServer() bool {
}

func tempFile(s string) (*os.File, error) {
f, err := ioutil.TempFile("", "yopass-")
f, err := os.CreateTemp("", "yopass-")
if err != nil {
return nil, err
}
Expand Down
6 changes: 4 additions & 2 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,21 @@ import (
// This should be created with server.New
type Server struct {
db Database
assetPath string
maxLength int
registry *prometheus.Registry
forceOneTimeSecrets bool
logger *zap.Logger
}

// New is the main way of creating the server.
func New(db Database, maxLength int, r *prometheus.Registry, forceOneTimeSecrets bool, logger *zap.Logger) Server {
func New(db Database, assetPath string, maxLength int, r *prometheus.Registry, forceOneTimeSecrets bool, logger *zap.Logger) Server {
if logger == nil {
logger = zap.NewNop()
}
return Server{
db: db,
assetPath: assetPath,
maxLength: maxLength,
registry: r,
forceOneTimeSecrets: forceOneTimeSecrets,
Expand Down Expand Up @@ -157,7 +159,7 @@ func (y *Server) HTTPHandler() http.Handler {
mx.HandleFunc("/file/"+keyParameter, y.deleteSecret).Methods(http.MethodDelete)
mx.HandleFunc("/file/"+keyParameter, y.optionsSecret).Methods(http.MethodOptions)

mx.PathPrefix("/").Handler(http.FileServer(http.Dir("public")))
mx.PathPrefix("/").Handler(http.FileServer(http.Dir(y.assetPath)))
return handlers.CustomLoggingHandler(nil, SecurityHeadersHandler(mx), httpLogFormatter(y.logger))
}

Expand Down
12 changes: 6 additions & 6 deletions pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func TestCreateSecret(t *testing.T) {
t.Run(fmt.Sprintf(tc.name), func(t *testing.T) {
req, _ := http.NewRequest("POST", "/secret", tc.body)
rr := httptest.NewRecorder()
y := New(tc.db, tc.maxLength, prometheus.NewRegistry(), false, zaptest.NewLogger(t))
y := New(tc.db, "public", tc.maxLength, prometheus.NewRegistry(), false, zaptest.NewLogger(t))
y.createSecret(rr, req)
var s yopass.Secret
json.Unmarshal(rr.Body.Bytes(), &s)
Expand Down Expand Up @@ -161,7 +161,7 @@ func TestOneTimeEnforcement(t *testing.T) {
t.Run(fmt.Sprintf(tc.name), func(t *testing.T) {
req, _ := http.NewRequest("POST", "/secret", tc.body)
rr := httptest.NewRecorder()
y := New(&mockDB{}, 100, prometheus.NewRegistry(), tc.requireOneTime, zaptest.NewLogger(t))
y := New(&mockDB{}, "public", 100, prometheus.NewRegistry(), tc.requireOneTime, zaptest.NewLogger(t))
y.createSecret(rr, req)
var s yopass.Secret
json.Unmarshal(rr.Body.Bytes(), &s)
Expand Down Expand Up @@ -205,7 +205,7 @@ func TestGetSecret(t *testing.T) {
t.Fatal(err)
}
rr := httptest.NewRecorder()
y := New(tc.db, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t))
y := New(tc.db, "public", 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t))
y.getSecret(rr, req)
cacheControl := rr.Header().Get("Cache-Control")
if cacheControl != "private, no-cache" {
Expand Down Expand Up @@ -256,7 +256,7 @@ func TestDeleteSecret(t *testing.T) {
t.Fatal(err)
}
rr := httptest.NewRecorder()
y := New(tc.db, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t))
y := New(tc.db, "public", 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t))
y.deleteSecret(rr, req)
var s struct {
Message string `json:"message"`
Expand Down Expand Up @@ -286,7 +286,7 @@ func TestMetrics(t *testing.T) {
path: "/secret/invalid-key-format",
},
}
y := New(&mockDB{}, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t))
y := New(&mockDB{}, "public", 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t))
h := y.HTTPHandler()

for _, r := range requests {
Expand Down Expand Up @@ -359,7 +359,7 @@ func TestSecurityHeaders(t *testing.T) {
},
}

y := New(&mockDB{}, 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t))
y := New(&mockDB{}, "public", 1, prometheus.NewRegistry(), false, zaptest.NewLogger(t))
h := y.HTTPHandler()

t.Parallel()
Expand Down
4 changes: 2 additions & 2 deletions pkg/yopass/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (

func TestFetch(t *testing.T) {
db := testDB(map[string]string{})
y := server.New(&db, 1024, prometheus.NewRegistry(), false, zaptest.NewLogger(t))
y := server.New(&db, "public", 1024, prometheus.NewRegistry(), false, zaptest.NewLogger(t))
ts := httptest.NewServer(y.HTTPHandler())
defer ts.Close()

Expand Down Expand Up @@ -46,7 +46,7 @@ func TestFetchInvalidServer(t *testing.T) {
}
func TestStore(t *testing.T) {
db := testDB(map[string]string{})
y := server.New(&db, 1024, prometheus.NewRegistry(), false, zaptest.NewLogger(t))
y := server.New(&db, "public", 1024, prometheus.NewRegistry(), false, zaptest.NewLogger(t))
ts := httptest.NewServer(y.HTTPHandler())
defer ts.Close()

Expand Down

0 comments on commit c6a688f

Please sign in to comment.