From f67c2df9f0c1a0014cf1216a8f24ff777ed9dfbe Mon Sep 17 00:00:00 2001 From: subtle-byte Date: Sat, 8 Jul 2023 15:17:07 +0200 Subject: [PATCH] Limit concurrency, add contexts --- Makefile | 12 ++- cmd/server/db.go | 51 +++++++++ cmd/server/debug_middleware.go | 36 +++++++ cmd/server/main.go | 101 +++--------------- go.mod | 1 + go.sum | 2 + .../github_files_provider/github.go | 31 ++++-- .../github_files_provider/temp_file.go | 9 +- .../postgres_loc_cacher.go | 9 +- internal/server/github_handler/handler.go | 11 +- internal/server/rest/sorted_stat.go | 2 +- internal/service/github_stat/github_stat.go | 33 ++++-- 12 files changed, 177 insertions(+), 121 deletions(-) create mode 100644 cmd/server/db.go create mode 100644 cmd/server/debug_middleware.go diff --git a/Makefile b/Makefile index e50eb3f..5ab4d66 100644 --- a/Makefile +++ b/Makefile @@ -7,12 +7,18 @@ stop-db: # DB is optional, if not provided, the service will be run without cache run: DB_CONN="postgres://postgres:password@localhost:54329/?sslmode=disable" \ - DEBUG_TOKEN="" \ - go run cmd/server/main.go + DEBUG_TOKEN="dt" \ + MAX_REPO_SIZE_MB=100 \ + MAX_CONCURRENT_WORK=2 \ + go run ./cmd/server/main.go run-in-docker: docker build -t ghloc . - docker run --rm -p 8080:8080 -e DEBUG_TOKEN="" ghloc + docker run --rm -p 8080:8080 \ + -e DEBUG_TOKEN="dt" \ + -e MAX_REPO_SIZE_MB=100 \ + -e MAX_CONCURRENT_WORK=2 \ + ghloc test: go build -v ./... diff --git a/cmd/server/db.go b/cmd/server/db.go new file mode 100644 index 0000000..a1139bb --- /dev/null +++ b/cmd/server/db.go @@ -0,0 +1,51 @@ +package main + +import ( + "database/sql" + "errors" + "fmt" + "log" + + "github.com/golang-migrate/migrate/v4" +) + +type MigrationLogger struct { + Prefix string +} + +func (m MigrationLogger) Printf(format string, v ...interface{}) { + log.Print(m.Prefix, fmt.Sprintf(format, v...)) +} + +func (m MigrationLogger) Verbose() bool { + return false +} + +func connectAndMigrateDB(dbConn string) (_ *sql.DB, close func() error, err error) { + if dbConn == "" { + return nil, nil, fmt.Errorf("env var DB_CONN is not provided") + } + + m, err := migrate.New("file://migrations", dbConn) + if err != nil { + return nil, nil, fmt.Errorf("create migrator: %w", err) + } + m.Log = MigrationLogger{Prefix: "migration: "} + err = m.Up() + if err != nil && !errors.Is(err, migrate.ErrNoChange) { + return nil, nil, fmt.Errorf("migrate up: %w", err) + } + + close = func() error { return nil } + db, err := sql.Open("postgres", dbConn) + if err == nil { + close = db.Close + err = db.Ping() + } + + if err != nil { + close() + return nil, nil, fmt.Errorf("connect to db: %w", err) + } + return db, close, nil +} diff --git a/cmd/server/debug_middleware.go b/cmd/server/debug_middleware.go new file mode 100644 index 0000000..0f7c6ea --- /dev/null +++ b/cmd/server/debug_middleware.go @@ -0,0 +1,36 @@ +package main + +import ( + "fmt" + "net/http" + "net/http/httptest" + "runtime/pprof" +) + +func NewDebugMiddleware(debugToken string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + if debugToken == "" { + return http.HandlerFunc(http.NotFound) + } + fn := func(w http.ResponseWriter, r *http.Request) { + if r.FormValue("debug_token") == debugToken { + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Disposition", `attachment; filename="profile"`) + if err := pprof.StartCPUProfile(w); err != nil { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Del("Content-Disposition") + w.Header().Set("X-Go-Pprof", "1") + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, "Could not enable CPU profiling: %s\n", err) + return + } + rr := httptest.ResponseRecorder{} + next.ServeHTTP(&rr, r) + pprof.StopCPUProfile() + } else { + http.NotFound(w, r) + } + } + return http.HandlerFunc(fn) + } +} diff --git a/cmd/server/main.go b/cmd/server/main.go index ea39217..c640a3f 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -1,98 +1,30 @@ package main import ( - "database/sql" - "errors" "fmt" "log" "net/http" - "net/http/httptest" - "os" - "runtime/pprof" // _ "net/http/pprof" + "github.com/caarlos0/env/v9" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" - "github.com/golang-migrate/migrate/v4" _ "github.com/golang-migrate/migrate/v4/database/postgres" _ "github.com/golang-migrate/migrate/v4/source/file" _ "github.com/lib/pq" "github.com/subtle-byte/ghloc/internal/infrastructure/github_files_provider" "github.com/subtle-byte/ghloc/internal/infrastructure/postgres_loc_cacher" "github.com/subtle-byte/ghloc/internal/server/github_handler" - "github.com/subtle-byte/ghloc/internal/service/github_stat" + github_stat_service "github.com/subtle-byte/ghloc/internal/service/github_stat" ) -var debugToken *string - -func DebugMiddleware(next http.Handler) http.Handler { - if debugToken == nil { - return http.HandlerFunc(http.NotFound) - } - fn := func(w http.ResponseWriter, r *http.Request) { - if r.FormValue("debug_token") == *debugToken { - w.Header().Set("Content-Type", "application/octet-stream") - w.Header().Set("Content-Disposition", `attachment; filename="profile"`) - if err := pprof.StartCPUProfile(w); err != nil { - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Del("Content-Disposition") - w.Header().Set("X-Go-Pprof", "1") - w.WriteHeader(http.StatusInternalServerError) - fmt.Fprintf(w, "Could not enable CPU profiling: %s\n", err) - return - } - rr := httptest.ResponseRecorder{} - next.ServeHTTP(&rr, r) - pprof.StopCPUProfile() - } else { - http.NotFound(w, r) - } - } - return http.HandlerFunc(fn) -} - -type MigrationLogger struct { - Prefix string -} - -func (m MigrationLogger) Printf(format string, v ...interface{}) { - log.Print(m.Prefix, fmt.Sprintf(format, v...)) -} - -func (m MigrationLogger) Verbose() bool { - return false -} - -func connectDB() (_ *sql.DB, close func() error, err error) { - dbConn := os.Getenv("DB_CONN") - if dbConn == "" { - return nil, nil, fmt.Errorf("env var DB_CONN is not provided") - } - - m, err := migrate.New("file://migrations", dbConn) - if err != nil { - return nil, nil, fmt.Errorf("create migrator: %w", err) - } - m.Log = MigrationLogger{Prefix: "migration: "} - err = m.Up() - if err != nil && !errors.Is(err, migrate.ErrNoChange) { - return nil, nil, fmt.Errorf("migrate up: %w", err) - } - - close = func() error { return nil } - db, err := sql.Open("postgres", dbConn) - if err == nil { - close = db.Close - err = db.Ping() - } - - if err != nil { - close() - return nil, nil, fmt.Errorf("connect to db: %w", err) - } - return db, close, nil +type Config struct { + DebugToken string `env:"DEBUG_TOKEN"` + MaxRepoSizeMB int `env:"MAX_REPO_SIZE_MB,notEmpty"` + MaxConcurrentWork int `env:"MAX_CONCURRENT_WORK,notEmpty"` + DbConnStr string `env:"DB_CONN"` } var buildTime = "unknown" // will be replaced during building the docker image @@ -100,14 +32,15 @@ var buildTime = "unknown" // will be replaced during building the docker image func main() { log.Printf("Starting up the app (build time: %v)\n", buildTime) - if token, ok := os.LookupEnv("DEBUG_TOKEN"); ok { - debugToken = &token - log.Println("Debug token is set") + cfg := &Config{} + if err := env.Parse(cfg); err != nil { + log.Fatalf("Parsing config: %v", err) } + log.Printf("Debug token is set: %v", cfg.DebugToken != "") - github := github_files_provider.Github{} - db, closeDB, err := connectDB() - pg := github_stat.LOCCacher(nil) + github := github_files_provider.New(cfg.MaxRepoSizeMB) + db, closeDB, err := connectAndMigrateDB(cfg.DbConnStr) + pg := github_stat_service.LOCCacher(nil) if err == nil { defer closeDB() pg = postgres_loc_cacher.NewPostgres(db) @@ -116,7 +49,7 @@ func main() { log.Printf("Error connecting to DB: %v", err) log.Println("Warning: continue without DB") } - service := github_stat.Service{pg, &github} + service := github_stat_service.New(pg, github, cfg.MaxConcurrentWork) router := chi.NewRouter() router.Use(middleware.RealIP) @@ -131,7 +64,7 @@ func main() { fmt.Fprintf(w, "Docs") }) - getStatHandler := &github_handler.GetStatHandler{&service, debugToken} + getStatHandler := &github_handler.GetStatHandler{service, cfg.DebugToken} getStatHandler.RegisterOn(router) redirectHandler := &github_handler.RedirectHandler{} @@ -139,7 +72,7 @@ func main() { // router.Mount("/debug", http.DefaultServeMux) // router.With(DebugMiddleware).Mount("/debug", http.DefaultServeMux) - router.With(DebugMiddleware).Route("/debug", func(r chi.Router) { + router.With(NewDebugMiddleware(cfg.DebugToken)).Route("/debug", func(r chi.Router) { getStatHandler.RegisterOn(r) }) fmt.Println("Listening on http://localhost:8080") diff --git a/go.mod b/go.mod index f7e8bba..82b6b15 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/subtle-byte/ghloc go 1.19 require ( + github.com/caarlos0/env/v9 v9.0.0 github.com/go-chi/chi/v5 v5.0.8 github.com/go-chi/cors v1.2.1 github.com/golang-migrate/migrate/v4 v4.16.2 diff --git a/go.sum b/go.sum index db399f9..bfb115e 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= +github.com/caarlos0/env/v9 v9.0.0 h1:SI6JNsOA+y5gj9njpgybykATIylrRMklbs5ch6wO6pc= +github.com/caarlos0/env/v9 v9.0.0/go.mod h1:ye5mlCVMYh6tZ+vCgrs/B95sj88cg5Tlnc0XIzgZ020= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dhui/dktest v0.3.16 h1:i6gq2YQEtcrjKbeJpBkWjE8MmLZPYllcjOFbTZuPDnw= diff --git a/internal/infrastructure/github_files_provider/github.go b/internal/infrastructure/github_files_provider/github.go index 62d9e84..9369662 100644 --- a/internal/infrastructure/github_files_provider/github.go +++ b/internal/infrastructure/github_files_provider/github.go @@ -3,6 +3,7 @@ package github_files_provider import ( "archive/zip" "bytes" + "context" "fmt" "io" "log" @@ -16,18 +17,23 @@ import ( ) type Github struct { + maxZipSizeBytes int } -const maxZipSize = 100 * 1024 * 1024 // 100 MiB +func New(maxZipSizeMB int) *Github { + return &Github{ + maxZipSizeBytes: maxZipSizeMB * 1024 * 1024, + } +} -func BuildGithubUrl(user, repo, branch string) string { +func buildGithubUrl(user, repo, branch string) string { return fmt.Sprintf("https://github.com/%v/%v/archive/refs/heads/%v.zip", user, repo, branch) } -func ReadIntoMemory(r io.Reader) (*bytes.Reader, error) { +func (g *Github) readIntoMemory(r io.Reader) (*bytes.Reader, error) { buf := &bytes.Buffer{} - lr := &LimitedReader{Reader: r, Remaining: maxZipSize} + lr := &LimitedReader{Reader: r, Remaining: g.maxZipSizeBytes} _, err := io.Copy(buf, lr) if err != nil { return nil, err @@ -36,15 +42,18 @@ func ReadIntoMemory(r io.Reader) (*bytes.Reader, error) { return bytes.NewReader(buf.Bytes()), nil } -func (r Github) GetContent(user, repo, branch string, tempStorage github_stat.TempStorage) (_ []github_stat.FileForPath, close func() error, _ error) { - url := BuildGithubUrl(user, repo, branch) +func (g *Github) GetContent(ctx context.Context, user, repo, branch string, tempStorage github_stat.TempStorage) (_ []github_stat.FileForPath, close func() error, _ error) { + url := buildGithubUrl(user, repo, branch) start := time.Now() - resp, err := http.Get(url) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { - log.Println(url, err) - return nil, nil, err + return nil, nil, fmt.Errorf("create request: %w", err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, nil, fmt.Errorf("do request: %w", err) } defer resp.Body.Close() @@ -60,7 +69,7 @@ func (r Github) GetContent(user, repo, branch string, tempStorage github_stat.Te readerAt := io.ReaderAt(nil) readerLen := 0 if tempStorage == github_stat.TempStorageFile { - tempFile, err := NewTempFile(resp.Body) + tempFile, err := NewTempFile(resp.Body, g.maxZipSizeBytes) if err != nil { return nil, nil, err } @@ -68,7 +77,7 @@ func (r Github) GetContent(user, repo, branch string, tempStorage github_stat.Te readerAt = tempFile readerLen = tempFile.Len() } else { - r, err := ReadIntoMemory(resp.Body) + r, err := g.readIntoMemory(resp.Body) if err != nil { return nil, nil, err } diff --git a/internal/infrastructure/github_files_provider/temp_file.go b/internal/infrastructure/github_files_provider/temp_file.go index 9ed3f0d..4000516 100644 --- a/internal/infrastructure/github_files_provider/temp_file.go +++ b/internal/infrastructure/github_files_provider/temp_file.go @@ -2,7 +2,6 @@ package github_files_provider import ( "io" - "io/ioutil" "log" "os" ) @@ -12,23 +11,23 @@ type TempFile struct { len int } -func NewTempFile(r io.Reader) (_ *TempFile, err error) { +func NewTempFile(r io.Reader, maxSizeBytes int) (_ *TempFile, err error) { tf := &TempFile{} - tf.File, err = ioutil.TempFile("", "") + tf.File, err = os.CreateTemp("", "") if err != nil { return nil, err } log.Print("temp file: ", tf.File.Name()) - lr := &LimitedReader{Reader: r, Remaining: maxZipSize} + lr := &LimitedReader{Reader: r, Remaining: maxSizeBytes} _, err = io.Copy(tf.File, lr) if err != nil { tf.Close() return nil, err } - tf.len = maxZipSize - lr.Remaining + tf.len = maxSizeBytes - lr.Remaining return tf, nil } diff --git a/internal/infrastructure/postgres_loc_cacher/postgres_loc_cacher.go b/internal/infrastructure/postgres_loc_cacher/postgres_loc_cacher.go index 2ccdfaf..f183dfb 100644 --- a/internal/infrastructure/postgres_loc_cacher/postgres_loc_cacher.go +++ b/internal/infrastructure/postgres_loc_cacher/postgres_loc_cacher.go @@ -1,6 +1,7 @@ package postgres_loc_cacher import ( + "context" "database/sql" "encoding/json" "log" @@ -34,7 +35,7 @@ func repoName(user, repo, branch string) string { return user + "/" + repo + "/" + branch } -func (p Postgres) SetLOCs(user, repo, branch string, locs []loc_count.LOCForPath) error { +func (p Postgres) SetLOCs(ctx context.Context, user, repo, branch string, locs []loc_count.LOCForPath) error { repoName := repoName(user, repo, branch) bytes, err := json.Marshal(locs) @@ -44,7 +45,7 @@ func (p Postgres) SetLOCs(user, repo, branch string, locs []loc_count.LOCForPath start := time.Now() - _, err = p.db.Exec("INSERT INTO repos VALUES ($1, $2, $3)", repoName, bytes, time.Now().Unix()) + _, err = p.db.ExecContext(ctx, "INSERT INTO repos VALUES ($1, $2, $3)", repoName, bytes, time.Now().Unix()) if err != nil { return err } @@ -53,14 +54,14 @@ func (p Postgres) SetLOCs(user, repo, branch string, locs []loc_count.LOCForPath return nil } -func (p Postgres) GetLOCs(user, repo, branch string) (locs []loc_count.LOCForPath, _ error) { +func (p Postgres) GetLOCs(ctx context.Context, user, repo, branch string) (locs []loc_count.LOCForPath, _ error) { repoName := repoName(user, repo, branch) bytes := []byte(nil) start := time.Now() - err := p.db.QueryRow("SELECT locs FROM repos WHERE name = $1", repoName).Scan(&bytes) + err := p.db.QueryRowContext(ctx, "SELECT locs FROM repos WHERE name = $1", repoName).Scan(&bytes) if err != nil { if err == sql.ErrNoRows { return nil, github_stat.ErrNoData diff --git a/internal/server/github_handler/handler.go b/internal/server/github_handler/handler.go index ef4e101..2569d65 100644 --- a/internal/server/github_handler/handler.go +++ b/internal/server/github_handler/handler.go @@ -1,6 +1,7 @@ package github_handler import ( + "context" "net/http" "github.com/go-chi/chi/v5" @@ -10,12 +11,12 @@ import ( ) type Service interface { - GetStat(user, repo, branch string, filter, matcher *string, noLOCProvider bool, tempStorage github_stat.TempStorage) (*loc_count.StatTree, error) + GetStat(ctx context.Context, user, repo, branch string, filter, matcher *string, noLOCProvider bool, tempStorage github_stat.TempStorage) (*loc_count.StatTree, error) } type GetStatHandler struct { Service Service - DebugToken *string + DebugToken string } func (h *GetStatHandler) RegisterOn(router chi.Router) { @@ -31,9 +32,9 @@ func (h GetStatHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { noLOCProvider := false tempStorage := github_stat.TempStorageFile - if h.DebugToken != nil { + if h.DebugToken != "" { debugTokenInRequest := r.FormValue("debug_token") - if debugTokenInRequest == *h.DebugToken { + if debugTokenInRequest == h.DebugToken { if r.Form["no_cache"] != nil { noLOCProvider = true } @@ -56,7 +57,7 @@ func (h GetStatHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { matcher = &matchers[0] } - stat, err := h.Service.GetStat(user, repo, branch, filter, matcher, noLOCProvider, tempStorage) + stat, err := h.Service.GetStat(r.Context(), user, repo, branch, filter, matcher, noLOCProvider, tempStorage) if err != nil { rest.WriteResponse(w, err, true) return diff --git a/internal/server/rest/sorted_stat.go b/internal/server/rest/sorted_stat.go index ea13c25..826497d 100644 --- a/internal/server/rest/sorted_stat.go +++ b/internal/server/rest/sorted_stat.go @@ -28,7 +28,7 @@ func (st *SortedStat) marshalJson(w *Buffer) error { w.UnwriteByte() // remove newline inserted by json.Encoder } - if st.Children == nil { + if len(st.Children) == 0 { encode(st.LOC) return err } diff --git a/internal/service/github_stat/github_stat.go b/internal/service/github_stat/github_stat.go index 6b02647..f0e03c5 100644 --- a/internal/service/github_stat/github_stat.go +++ b/internal/service/github_stat/github_stat.go @@ -1,6 +1,7 @@ package github_stat import ( + "context" "fmt" "io" "log" @@ -12,8 +13,8 @@ import ( var ErrNoData = fmt.Errorf("no data") type LOCCacher interface { - SetLOCs(user, repo, branch string, locs []loc_count.LOCForPath) error - GetLOCs(user, repo, branch string) ([]loc_count.LOCForPath, error) // error may be ErrNoData + SetLOCs(ctx context.Context, user, repo, branch string, locs []loc_count.LOCForPath) error + GetLOCs(ctx context.Context, user, repo, branch string) ([]loc_count.LOCForPath, error) // error may be ErrNoData } type TempStorage int @@ -31,18 +32,34 @@ type FileForPath struct { } type ContentProvider interface { - GetContent(user, repo, branch string, tempStorage TempStorage) (_ []FileForPath, close func() error, _ error) + GetContent(ctx context.Context, user, repo, branch string, tempStorage TempStorage) (_ []FileForPath, close func() error, _ error) } type Service struct { LOCCacher LOCCacher // possibly nil ContentProvider ContentProvider + sem chan struct{} // semaphore for limiting number of concurrent work } -func (s *Service) GetStat(user, repo, branch string, filter, matcher *string, noLOCProvider bool, tempStorage TempStorage) (*loc_count.StatTree, error) { +func New(locCacher LOCCacher, contentProvider ContentProvider, maxParallelWork int) *Service { + return &Service{ + LOCCacher: locCacher, + ContentProvider: contentProvider, + sem: make(chan struct{}, maxParallelWork), + } +} + +func (s *Service) GetStat(ctx context.Context, user, repo, branch string, filter, matcher *string, noLOCProvider bool, tempStorage TempStorage) (*loc_count.StatTree, error) { + select { + case s.sem <- struct{}{}: + defer func() { <-s.sem }() + case <-ctx.Done(): + return nil, fmt.Errorf("wait in queue: %w", ctx.Err()) + } + if s.LOCCacher != nil { if !noLOCProvider { - locs, err := s.LOCCacher.GetLOCs(user, repo, branch) + locs, err := s.LOCCacher.GetLOCs(ctx, user, repo, branch) if err == nil { // TODO? return loc_count.BuildStatTree(locs, filter, matcher), nil } @@ -51,9 +68,9 @@ func (s *Service) GetStat(user, repo, branch string, filter, matcher *string, no } } - filesForPaths, close, err := s.ContentProvider.GetContent(user, repo, branch, tempStorage) + filesForPaths, close, err := s.ContentProvider.GetContent(ctx, user, repo, branch, tempStorage) if err != nil { - return nil, err + return nil, fmt.Errorf("get repo content: %w", err) } defer close() @@ -78,7 +95,7 @@ func (s *Service) GetStat(user, repo, branch string, filter, matcher *string, no log.Println("LOCs counted in", time.Since(start)) if s.LOCCacher != nil && !noLOCProvider { - err := s.LOCCacher.SetLOCs(user, repo, branch, locs) + err := s.LOCCacher.SetLOCs(ctx, user, repo, branch, locs) if err != nil { log.Println("Error saving LOCs:", err) }