From b4606ffaf4278c27811d99cf1e411be71902026b Mon Sep 17 00:00:00 2001 From: Marcus Weiner Date: Tue, 10 Sep 2019 14:03:38 +0200 Subject: [PATCH 1/2] Cleanup log handling --- api/api.go | 8 +++--- api/api_test.go | 3 +- api/instance_test.go | 6 ++-- api/log.go | 6 ++-- api/middleware_test.go | 6 ++-- api/payments_test.go | 6 ++-- api/utils_test.go | 4 +-- cmd/migrate_cmd.go | 4 +-- cmd/multi_cmd.go | 8 +++--- cmd/root_cmd.go | 6 ++-- cmd/serve_cmd.go | 18 ++++++------ conf/configuration.go | 16 ++++++----- models/connection.go | 14 +++++----- models/connection_logger.go | 56 +++++++++++++++++++++++++++++++++++++ 14 files changed, 110 insertions(+), 51 deletions(-) create mode 100644 models/connection_logger.go diff --git a/api/api.go b/api/api.go index f3213cb..4bd2e9d 100644 --- a/api/api.go +++ b/api/api.go @@ -73,12 +73,12 @@ func waitForTermination(log logrus.FieldLogger, done <-chan struct{}) { } // NewAPI instantiates a new REST API using the default version. -func NewAPI(globalConfig *conf.GlobalConfiguration, db *gorm.DB) *API { - return NewAPIWithVersion(context.Background(), globalConfig, db, defaultVersion) +func NewAPI(globalConfig *conf.GlobalConfiguration, log logrus.FieldLogger, db *gorm.DB) *API { + return NewAPIWithVersion(context.Background(), globalConfig, log, db, defaultVersion) } // NewAPIWithVersion instantiates a new REST API. -func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfiguration, db *gorm.DB, version string) *API { +func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfiguration, log logrus.FieldLogger, db *gorm.DB, version string) *API { api := &API{ config: globalConfig, db: db, @@ -87,7 +87,7 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati } xffmw, _ := xff.Default() - logger := newStructuredLogger(logrus.StandardLogger()) + logger := newStructuredLogger(log) r := newRouter() r.UseBypass(xffmw.Handler) diff --git a/api/api_test.go b/api/api_test.go index 62316d4..619aad6 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -6,6 +6,7 @@ import ( "net/http/httptest" "testing" + "github.com/sirupsen/logrus" "github.com/sirupsen/logrus/hooks/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -25,7 +26,7 @@ func TestTraceWrapper(t *testing.T) { ctx, err := WithInstanceConfig(context.Background(), globalConfig.SMTP, config, "") require.NoError(t, err) - api := NewAPIWithVersion(ctx, globalConfig, nil, "") + api := NewAPIWithVersion(ctx, globalConfig, logrus.StandardLogger(), nil, "") server := httptest.NewServer(api.handler) defer server.Close() diff --git a/api/instance_test.go b/api/instance_test.go index 7cd21f0..d45e1a0 100644 --- a/api/instance_test.go +++ b/api/instance_test.go @@ -25,14 +25,14 @@ type InstanceTestSuite struct { } func (ts *InstanceTestSuite) SetupTest() { - globalConfig, err := conf.LoadGlobal("test.env") + globalConfig, log, err := conf.LoadGlobal("test.env") require.NoError(ts.T(), err) globalConfig.OperatorToken = operatorToken globalConfig.MultiInstanceMode = true - db, err := models.Connect(globalConfig) + db, err := models.Connect(globalConfig, log) require.NoError(ts.T(), err) - api := NewAPI(globalConfig, db) + api := NewAPI(globalConfig, log, db) ts.API = api // Cleanup existing instance diff --git a/api/log.go b/api/log.go index 0b0f173..9540adb 100644 --- a/api/log.go +++ b/api/log.go @@ -10,16 +10,16 @@ import ( "github.com/sirupsen/logrus" ) -func newStructuredLogger(logger *logrus.Logger) func(next http.Handler) http.Handler { +func newStructuredLogger(logger logrus.FieldLogger) func(next http.Handler) http.Handler { return chimiddleware.RequestLogger(&structuredLogger{logger}) } type structuredLogger struct { - Logger *logrus.Logger + Logger logrus.FieldLogger } func (l *structuredLogger) NewLogEntry(r *http.Request) chimiddleware.LogEntry { - entry := &structuredLoggerEntry{Logger: logrus.NewEntry(l.Logger)} + entry := &structuredLoggerEntry{Logger: l.Logger} logFields := logrus.Fields{ "component": "api", "method": r.Method, diff --git a/api/middleware_test.go b/api/middleware_test.go index d28b304..d79656a 100644 --- a/api/middleware_test.go +++ b/api/middleware_test.go @@ -21,13 +21,13 @@ type MiddlewareTestSuite struct { } func (ts *MiddlewareTestSuite) SetupTest() { - globalConfig, err := conf.LoadGlobal("test.env") + globalConfig, log, err := conf.LoadGlobal("test.env") require.NoError(ts.T(), err) globalConfig.MultiInstanceMode = true - db, err := models.Connect(globalConfig) + db, err := models.Connect(globalConfig, log) require.NoError(ts.T(), err) - api := NewAPI(globalConfig, db) + api := NewAPI(globalConfig, log, db) ts.API = api } diff --git a/api/payments_test.go b/api/payments_test.go index a3e6532..0763c2a 100644 --- a/api/payments_test.go +++ b/api/payments_test.go @@ -253,7 +253,7 @@ func TestPaymentsRefund(t *testing.T) { err = signHTTPRequest(r, testAdminToken("magical-unicorn", ""), test.Config.JWT.Secret) require.NoError(t, err) - NewAPIWithVersion(ctx, test.GlobalConfig, test.DB, defaultVersion).handler.ServeHTTP(w, r) + NewAPIWithVersion(ctx, test.GlobalConfig, logrus.StandardLogger(), test.DB, defaultVersion).handler.ServeHTTP(w, r) rsp := new(models.Transaction) extractPayload(t, http.StatusOK, w, rsp) @@ -674,7 +674,7 @@ func TestPaymentPreauthorize(t *testing.T) { globalConfig := new(conf.GlobalConfiguration) ctx, err := WithInstanceConfig(context.Background(), globalConfig.SMTP, test.Config, "") require.NoError(t, err) - NewAPIWithVersion(ctx, test.GlobalConfig, test.DB, "").handler.ServeHTTP(recorder, req) + NewAPIWithVersion(ctx, test.GlobalConfig, logrus.StandardLogger(), test.DB, "").handler.ServeHTTP(recorder, req) rsp := payments.PreauthorizationResult{} extractPayload(t, http.StatusOK, recorder, &rsp) @@ -715,7 +715,7 @@ func TestPaymentPreauthorize(t *testing.T) { globalConfig := new(conf.GlobalConfiguration) ctx, err := WithInstanceConfig(context.Background(), globalConfig.SMTP, test.Config, "") require.NoError(t, err) - NewAPIWithVersion(ctx, test.GlobalConfig, test.DB, "").handler.ServeHTTP(recorder, req) + NewAPIWithVersion(ctx, test.GlobalConfig, logrus.StandardLogger(), test.DB, "").handler.ServeHTTP(recorder, req) rsp := payments.PreauthorizationResult{} extractPayload(t, http.StatusOK, recorder, &rsp) diff --git a/api/utils_test.go b/api/utils_test.go index 1b68d06..fa5318d 100644 --- a/api/utils_test.go +++ b/api/utils_test.go @@ -55,7 +55,7 @@ func db(t *testing.T) (*gorm.DB, *conf.GlobalConfiguration, *conf.Configuration, globalConfig.DB.Driver = "sqlite3" globalConfig.DB.URL = f.Name() - db, err := models.Connect(globalConfig) + db, err := models.Connect(globalConfig, logrus.StandardLogger()) if err != nil { assert.FailNow(t, "failed to connect to db: "+err.Error()) } @@ -386,7 +386,7 @@ func (r *RouteTest) TestEndpoint(method string, url string, body io.Reader, toke globalConfig := new(conf.GlobalConfiguration) ctx, err := WithInstanceConfig(context.Background(), globalConfig.SMTP, r.Config, "") require.NoError(r.T, err) - NewAPIWithVersion(ctx, r.GlobalConfig, r.DB, "").handler.ServeHTTP(recorder, req) + NewAPIWithVersion(ctx, r.GlobalConfig, logrus.StandardLogger(), r.DB, "").handler.ServeHTTP(recorder, req) return recorder } diff --git a/cmd/migrate_cmd.go b/cmd/migrate_cmd.go index 072915c..d379d38 100644 --- a/cmd/migrate_cmd.go +++ b/cmd/migrate_cmd.go @@ -15,8 +15,8 @@ var migrateCmd = cobra.Command{ }, } -func migrate(globalConfig *conf.GlobalConfiguration, config *conf.Configuration) { - db, err := models.Connect(globalConfig) +func migrate(globalConfig *conf.GlobalConfiguration, log logrus.FieldLogger, config *conf.Configuration) { + db, err := models.Connect(globalConfig, log) if err != nil { logrus.Fatalf("Error opening database: %+v", err) } diff --git a/cmd/multi_cmd.go b/cmd/multi_cmd.go index 34d1d8b..fb3752d 100644 --- a/cmd/multi_cmd.go +++ b/cmd/multi_cmd.go @@ -18,7 +18,7 @@ var multiCmd = cobra.Command{ } func multi(cmd *cobra.Command, args []string) { - globalConfig, err := conf.LoadGlobal(configFile) + globalConfig, log, err := conf.LoadGlobal(configFile) if err != nil { logrus.Fatalf("Failed to load configuration: %+v", err) } @@ -26,20 +26,20 @@ func multi(cmd *cobra.Command, args []string) { logrus.Fatal("Operator token secret is required") } - db, err := models.Connect(globalConfig) + db, err := models.Connect(globalConfig, log.WithField("component", "db")) if err != nil { logrus.Fatalf("Error opening database: %+v", err) } defer db.Close() - bgDB, err := models.Connect(globalConfig) + bgDB, err := models.Connect(globalConfig, log.WithField("component", "db").WithField("bgdb", true)) if err != nil { logrus.Fatalf("Error opening database: %+v", err) } defer bgDB.Close() globalConfig.MultiInstanceMode = true - api := api.NewAPIWithVersion(context.Background(), globalConfig, db.Debug(), Version) + api := api.NewAPIWithVersion(context.Background(), globalConfig, log, db.Debug(), Version) l := fmt.Sprintf("%v:%v", globalConfig.API.Host, globalConfig.API.Port) logrus.Infof("GoCommerce API started on: %s", l) diff --git a/cmd/root_cmd.go b/cmd/root_cmd.go index eb479fc..5194b18 100644 --- a/cmd/root_cmd.go +++ b/cmd/root_cmd.go @@ -25,8 +25,8 @@ func RootCmd() *cobra.Command { return &rootCmd } -func execWithConfig(cmd *cobra.Command, fn func(globalConfig *conf.GlobalConfiguration, config *conf.Configuration)) { - globalConfig, err := conf.LoadGlobal(configFile) +func execWithConfig(cmd *cobra.Command, fn func(globalConfig *conf.GlobalConfiguration, log logrus.FieldLogger, config *conf.Configuration)) { + globalConfig, log, err := conf.LoadGlobal(configFile) if err != nil { logrus.Fatalf("Failed to load configuration: %+v", err) } @@ -35,5 +35,5 @@ func execWithConfig(cmd *cobra.Command, fn func(globalConfig *conf.GlobalConfigu logrus.Fatalf("Failed to load configuration: %+v", err) } - fn(globalConfig, config) + fn(globalConfig, log, config) } diff --git a/cmd/serve_cmd.go b/cmd/serve_cmd.go index e7ba741..0f34592 100644 --- a/cmd/serve_cmd.go +++ b/cmd/serve_cmd.go @@ -19,29 +19,29 @@ var serveCmd = cobra.Command{ }, } -func serve(globalConfig *conf.GlobalConfiguration, config *conf.Configuration) { - db, err := models.Connect(globalConfig) +func serve(globalConfig *conf.GlobalConfiguration, log logrus.FieldLogger, config *conf.Configuration) { + db, err := models.Connect(globalConfig, log.WithField("component", "db")) if err != nil { - logrus.Fatalf("Error opening database: %+v", err) + log.Fatalf("Error opening database: %+v", err) } defer db.Close() - bgDB, err := models.Connect(globalConfig) + bgDB, err := models.Connect(globalConfig, log.WithField("component", "db").WithField("bgdb", true)) if err != nil { - logrus.Fatalf("Error opening database: %+v", err) + log.Fatalf("Error opening database: %+v", err) } defer bgDB.Close() ctx, err := api.WithInstanceConfig(context.Background(), globalConfig.SMTP, config, "") if err != nil { - logrus.Fatalf("Error loading instance config: %+v", err) + log.Fatalf("Error loading instance config: %+v", err) } - api := api.NewAPIWithVersion(ctx, globalConfig, db, Version) + api := api.NewAPIWithVersion(ctx, globalConfig, log, db, Version) l := fmt.Sprintf("%v:%v", globalConfig.API.Host, globalConfig.API.Port) - logrus.Infof("GoCommerce API started on: %s", l) + log.Infof("GoCommerce API started on: %s", l) - models.RunHooks(bgDB, logrus.WithField("component", "hooks")) + models.RunHooks(bgDB, log.WithField("component", "hooks")) api.ListenAndServe(l) } diff --git a/conf/configuration.go b/conf/configuration.go index 22a4d73..1371d3f 100644 --- a/conf/configuration.go +++ b/conf/configuration.go @@ -5,6 +5,7 @@ import ( "github.com/joho/godotenv" "github.com/kelseyhightower/envconfig" + "github.com/sirupsen/logrus" ) // DBConfiguration holds all the database related configuration. @@ -39,7 +40,7 @@ type GlobalConfiguration struct { } DB DBConfiguration Logging LoggingConfig `envconfig:"LOG"` - OperatorToken string `split_words:"true"` + OperatorToken string `split_words:"true"` MultiInstanceMode bool SMTP SMTPConfiguration `json:"smtp"` } @@ -118,19 +119,20 @@ func loadEnvironment(filename string) error { } // LoadGlobal will construct the core config from the file -func LoadGlobal(filename string) (*GlobalConfiguration, error) { +func LoadGlobal(filename string) (*GlobalConfiguration, *logrus.Entry, error) { if err := loadEnvironment(filename); err != nil { - return nil, err + return nil, nil, err } config := new(GlobalConfiguration) if err := envconfig.Process("gocommerce", config); err != nil { - return nil, err + return nil, nil, err } - if _, err := ConfigureLogging(&config.Logging); err != nil { - return nil, err + log, err := ConfigureLogging(&config.Logging) + if err != nil { + return nil, nil, err } - return config, nil + return config, log, nil } // LoadConfig loads the per-instance configuration from a file diff --git a/models/connection.go b/models/connection.go index 648933e..9e83d2f 100644 --- a/models/connection.go +++ b/models/connection.go @@ -5,10 +5,9 @@ import ( _ "github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/dialers/mysql" _ "github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/dialers/postgres" _ "github.com/go-sql-driver/mysql" + "github.com/jinzhu/gorm" _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" - - "github.com/jinzhu/gorm" "github.com/netlify/gocommerce/conf" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -21,7 +20,7 @@ import ( var Namespace string // Connect will connect to that storage engine -func Connect(config *conf.GlobalConfiguration) (*gorm.DB, error) { +func Connect(config *conf.GlobalConfiguration, log logrus.FieldLogger) (*gorm.DB, error) { if config.DB.Namespace != "" { Namespace = config.DB.Namespace } @@ -34,9 +33,8 @@ func Connect(config *conf.GlobalConfiguration) (*gorm.DB, error) { return nil, errors.Wrap(err, "opening database connection") } - if logrus.StandardLogger().Level == logrus.DebugLevel { - db.LogMode(true) - } + db.SetLogger(NewDBLogger(log)) + db.LogMode(true) err = db.DB().Ping() if err != nil { @@ -44,7 +42,9 @@ func Connect(config *conf.GlobalConfiguration) (*gorm.DB, error) { } if config.DB.Automigrate { - if err := AutoMigrate(db); err != nil { + migDB := db.New() + migDB.SetLogger(NewDBLogger(log.WithField("task", "migration"))) + if err := AutoMigrate(migDB); err != nil { return nil, errors.Wrap(err, "migrating tables") } } diff --git a/models/connection_logger.go b/models/connection_logger.go new file mode 100644 index 0000000..ddbd999 --- /dev/null +++ b/models/connection_logger.go @@ -0,0 +1,56 @@ +package models + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/sirupsen/logrus" +) + +type DBLogger struct { + logrus.FieldLogger +} + +func NewDBLogger(log logrus.FieldLogger) *DBLogger { + return &DBLogger{log} +} + +func (dbl *DBLogger) Print(params ...interface{}) { + if len(params) <= 1 { + return + } + + level := params[0] + log := dbl.WithField("gorm_level", level) + + if entry, ok := dbl.FieldLogger.(*logrus.Entry); ok && entry.Logger.Level >= logrus.TraceLevel { + log = log.WithField("gorm_source", params[1]) + } + + if level != "sql" { + log.Debug(params[2:]...) + return + } + + dur := params[2].(time.Duration) + sql := params[3].(string) + sqlValues := params[4].([]interface{}) + rows := params[5].(int64) + + values := "" + if valuesJSON, err := json.Marshal(sqlValues); err == nil { + values = string(valuesJSON) + } else { + values = fmt.Sprintf("%+v", sqlValues) + } + + log. + WithField("dur_ns", dur.Nanoseconds()). + WithField("dur", dur). + WithField("sql", strings.ReplaceAll(sql, `"`, `'`)). + WithField("values", strings.ReplaceAll(values, `"`, `'`)). + WithField("rows", rows). + Debug("sql query") +} From e50215cf432ea358e60f6da949051de5a0775634 Mon Sep 17 00:00:00 2001 From: Marcus Weiner Date: Tue, 10 Sep 2019 14:07:41 +0200 Subject: [PATCH 2/2] Use custom logger for sql debug logging --- api/api.go | 4 +++- api/db_logger.go | 28 ++++++++++++++++++++++++++++ api/download.go | 18 ++++++++++-------- api/instance.go | 12 +++++++----- api/order.go | 20 +++++++++++--------- api/payments.go | 28 +++++++++++++++------------- api/reports.go | 9 +++++---- api/user.go | 26 ++++++++++++++------------ context/context.go | 16 ++++++++++++++++ models/connection_logger.go | 6 +----- 10 files changed, 110 insertions(+), 57 deletions(-) create mode 100644 api/db_logger.go diff --git a/api/api.go b/api/api.go index 4bd2e9d..5bd1b7e 100644 --- a/api/api.go +++ b/api/api.go @@ -98,6 +98,7 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati r.Route("/", func(r *router) { r.UseBypass(logger) + r.Use(api.loggingDB) if globalConfig.MultiInstanceMode { r.Use(api.loadInstanceConfig) } @@ -147,9 +148,10 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati if globalConfig.MultiInstanceMode { // Operator microservice API - r.WithBypass(logger).With(api.verifyOperatorRequest).Get("/", api.GetAppManifest) + r.WithBypass(logger).With(api.loggingDB).With(api.verifyOperatorRequest).Get("/", api.GetAppManifest) r.Route("/instances", func(r *router) { r.UseBypass(logger) + r.Use(api.loggingDB) r.Use(api.verifyOperatorRequest) r.Post("/", api.CreateInstance) diff --git a/api/db_logger.go b/api/db_logger.go new file mode 100644 index 0000000..b75c978 --- /dev/null +++ b/api/db_logger.go @@ -0,0 +1,28 @@ +package api + +import ( + "context" + "net/http" + + "github.com/jinzhu/gorm" + gcontext "github.com/netlify/gocommerce/context" + "github.com/netlify/gocommerce/models" +) + +func (a *API) loggingDB(w http.ResponseWriter, r *http.Request) (context.Context, error) { + if a.db == nil { + return r.Context(), nil + } + + log := getLogEntry(r) + db := a.db.New() + db.SetLogger(models.NewDBLogger(log)) + + return gcontext.WithDB(r.Context(), db), nil +} + +// DB provides callers with a database instance configured for request logging +func (a *API) DB(r *http.Request) *gorm.DB { + ctx := r.Context() + return gcontext.GetDB(ctx) +} diff --git a/api/download.go b/api/download.go index c19ebd3..b2e9a93 100644 --- a/api/download.go +++ b/api/download.go @@ -15,13 +15,14 @@ const maxIPsPerDay = 50 // DownloadURL returns a signed URL to download a purchased asset. func (a *API) DownloadURL(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() + db := a.DB(r) downloadID := chi.URLParam(r, "download_id") logEntrySetField(r, "download_id", downloadID) claims := gcontext.GetClaims(ctx) assets := gcontext.GetAssetStore(ctx) download := &models.Download{} - if result := a.db.Where("id = ?", downloadID).First(download); result.Error != nil { + if result := db.Where("id = ?", downloadID).First(download); result.Error != nil { if result.RecordNotFound() { return notFoundError("Download not found") } @@ -29,7 +30,7 @@ func (a *API) DownloadURL(w http.ResponseWriter, r *http.Request) error { } order := &models.Order{} - if result := a.db.Where("id = ?", download.OrderID).First(order); result.Error != nil { + if result := db.Where("id = ?", download.OrderID).First(order); result.Error != nil { if result.RecordNotFound() { return notFoundError("Download order not found") } @@ -44,7 +45,7 @@ func (a *API) DownloadURL(w http.ResponseWriter, r *http.Request) error { return unauthorizedError("This download has not been paid yet") } - rows, err := a.db.Model(&models.Event{}). + rows, err := db.Model(&models.Event{}). Select("count(distinct(ip))"). Where("order_id = ? and created_at > ? and changes = 'download'", order.ID, time.Now().Add(-24*time.Hour)). Rows() @@ -66,7 +67,7 @@ func (a *API) DownloadURL(w http.ResponseWriter, r *http.Request) error { return internalServerError("Error signing download").WithInternalError(err) } - tx := a.db.Begin() + tx := db.Begin() tx.Model(download).Updates(map[string]interface{}{"download_count": gorm.Expr("download_count + 1")}) var subject string if claims != nil { @@ -81,12 +82,13 @@ func (a *API) DownloadURL(w http.ResponseWriter, r *http.Request) error { // DownloadList lists all purchased downloads for an order or a user. func (a *API) DownloadList(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() + db := a.DB(r) orderID := gcontext.GetOrderID(ctx) log := getLogEntry(r) order := &models.Order{} if orderID != "" { - if result := a.db.Where("id = ?", orderID).First(order); result.Error != nil { + if result := db.Where("id = ?", orderID).First(order); result.Error != nil { if result.RecordNotFound() { return notFoundError("Download order not found") } @@ -106,10 +108,10 @@ func (a *API) DownloadList(w http.ResponseWriter, r *http.Request) error { } } - orderTable := a.db.NewScope(models.Order{}).QuotedTableName() - downloadsTable := a.db.NewScope(models.Download{}).QuotedTableName() + orderTable := db.NewScope(models.Order{}).QuotedTableName() + downloadsTable := db.NewScope(models.Download{}).QuotedTableName() - query := a.db.Joins("join " + orderTable + " ON " + downloadsTable + ".order_id = " + orderTable + ".id and " + orderTable + ".payment_state = 'paid'") + query := db.Joins("join " + orderTable + " ON " + downloadsTable + ".order_id = " + orderTable + ".id and " + orderTable + ".payment_state = 'paid'") if order != nil { query = query.Where(orderTable+".id = ?", order.ID) } else { diff --git a/api/instance.go b/api/instance.go index e7590d4..243bb66 100644 --- a/api/instance.go +++ b/api/instance.go @@ -16,7 +16,7 @@ func (a *API) loadInstance(w http.ResponseWriter, r *http.Request) (context.Cont instanceID := chi.URLParam(r, "instance_id") logEntrySetField(r, "instance_id", instanceID) - i, err := models.GetInstance(a.db, instanceID) + i, err := models.GetInstance(a.DB(r), instanceID) if err != nil { if models.IsNotFoundError(err) { return nil, notFoundError("Instance not found") @@ -47,12 +47,14 @@ type InstanceResponse struct { } func (a *API) CreateInstance(w http.ResponseWriter, r *http.Request) error { + db := a.DB(r) + params := InstanceRequestParams{} if err := json.NewDecoder(r.Body).Decode(¶ms); err != nil { return badRequestError("Error decoding params: %v", err) } - _, err := models.GetInstanceByUUID(a.db, params.UUID) + _, err := models.GetInstanceByUUID(db, params.UUID) if err != nil { if !models.IsNotFoundError(err) { return internalServerError("Database error looking up instance").WithInternalError(err) @@ -66,7 +68,7 @@ func (a *API) CreateInstance(w http.ResponseWriter, r *http.Request) error { UUID: params.UUID, BaseConfig: params.BaseConfig, } - if err = models.CreateInstance(a.db, &i); err != nil { + if err = models.CreateInstance(db, &i); err != nil { return internalServerError("Database error creating instance").WithInternalError(err) } @@ -95,7 +97,7 @@ func (a *API) UpdateInstance(w http.ResponseWriter, r *http.Request) error { i.BaseConfig = params.BaseConfig } - if err := models.UpdateInstance(a.db, i); err != nil { + if err := models.UpdateInstance(a.DB(r), i); err != nil { return internalServerError("Database error updating instance").WithInternalError(err) } return sendJSON(w, http.StatusOK, i) @@ -103,7 +105,7 @@ func (a *API) UpdateInstance(w http.ResponseWriter, r *http.Request) error { func (a *API) DeleteInstance(w http.ResponseWriter, r *http.Request) error { i := gcontext.GetInstance(r.Context()) - if err := models.DeleteInstance(a.db, i); err != nil { + if err := models.DeleteInstance(a.DB(r), i); err != nil { return internalServerError("Database error deleting instance").WithInternalError(err) } diff --git a/api/order.go b/api/order.go index 63654ae..dd032b0 100644 --- a/api/order.go +++ b/api/order.go @@ -85,6 +85,7 @@ func (a *API) withOrderID(w http.ResponseWriter, r *http.Request) (context.Conte // ClaimOrders will look for any orders with no user id belonging to an email and claim them func (a *API) ClaimOrders(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() + db := a.DB(r) log := getLogEntry(r) instanceID := gcontext.GetInstanceID(ctx) @@ -103,7 +104,7 @@ func (a *API) ClaimOrders(w http.ResponseWriter, r *http.Request) error { }) // now find all the order associated with that email - query := orderQuery(a.db) + query := orderQuery(db) query = query.Where(&models.Order{ InstanceID: instanceID, UserID: "", @@ -115,7 +116,7 @@ func (a *API) ClaimOrders(w http.ResponseWriter, r *http.Request) error { return internalServerError("Failed to query for orders with email: %s", claims.Email).WithInternalError(res.Error) } - tx := a.db.Begin() + tx := db.Begin() // create the user user := models.User{ @@ -154,7 +155,7 @@ func (a *API) ReceiptView(w http.ResponseWriter, r *http.Request) error { logEntrySetField(r, "order_id", id) order := &models.Order{} - if result := orderQuery(a.db).Preload("Transactions").First(order, "id = ?", id); result.Error != nil { + if result := orderQuery(a.DB(r)).Preload("Transactions").First(order, "id = ?", id); result.Error != nil { if result.RecordNotFound() { return notFoundError("Order not found") } @@ -198,7 +199,7 @@ func (a *API) ResendOrderReceipt(w http.ResponseWriter, r *http.Request) error { } order := &models.Order{} - if result := orderQuery(a.db).Preload("Transactions").First(order, "id = ?", id); result.Error != nil { + if result := orderQuery(a.DB(r)).Preload("Transactions").First(order, "id = ?", id); result.Error != nil { if result.RecordNotFound() { return notFoundError("Order not found") } @@ -246,7 +247,7 @@ func (a *API) OrderList(w http.ResponseWriter, r *http.Request) error { var err error params := r.URL.Query() - query := orderQuery(a.db) + query := orderQuery(a.DB(r)) query, err = parseOrderParams(query, params) if err != nil { return badRequestError("Bad parameters in query: %v", err) @@ -286,7 +287,7 @@ func (a *API) OrderView(w http.ResponseWriter, r *http.Request) error { log := getLogEntry(r) order := &models.Order{} - if result := orderQuery(a.db).First(order, "id = ?", id); result.Error != nil { + if result := orderQuery(a.DB(r)).First(order, "id = ?", id); result.Error != nil { if result.RecordNotFound() { return notFoundError("Order not found") } @@ -338,7 +339,7 @@ func (a *API) OrderCreate(w http.ResponseWriter, r *http.Request) error { "email": params.Email, "currency": params.Currency, }).Debug("Created order, starting to process request") - tx := a.db.Begin() + tx := a.DB(r).Begin() order.IP = r.RemoteAddr order.MetaData = params.MetaData @@ -425,6 +426,7 @@ func (a *API) OrderCreate(w http.ResponseWriter, r *http.Request) error { // There are also blocks to changing certain fields after the state has been locked func (a *API) OrderUpdate(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() + db := a.DB(r) orderID := gcontext.GetOrderID(ctx) log := getLogEntry(r) claims := gcontext.GetClaims(ctx) @@ -440,7 +442,7 @@ func (a *API) OrderUpdate(w http.ResponseWriter, r *http.Request) error { // verify that the order exists existingOrder := new(models.Order) - rsp := orderQuery(a.db).First(existingOrder, "id = ?", orderID) + rsp := orderQuery(db).First(existingOrder, "id = ?", orderID) if rsp.RecordNotFound() { return notFoundError("Failed to find order with id '%s'", orderID) } @@ -486,7 +488,7 @@ func (a *API) OrderUpdate(w http.ResponseWriter, r *http.Request) error { changes = append(changes, "vatnumber") } - tx := a.db.Begin() + tx := db.Begin() // // handle the addresses diff --git a/api/payments.go b/api/payments.go index 40e8bcf..3045fa8 100644 --- a/api/payments.go +++ b/api/payments.go @@ -45,7 +45,7 @@ func (a *API) PaymentListForUser(w http.ResponseWriter, r *http.Request) error { return notFoundError("Couldn't find a record for " + userID) } - trans, httpErr := queryForTransactions(a.db, log, "user_id = ?", userID) + trans, httpErr := queryForTransactions(a.DB(r), log, "user_id = ?", userID) if httpErr != nil { return httpErr } @@ -60,7 +60,7 @@ func (a *API) PaymentListForOrder(w http.ResponseWriter, r *http.Request) error orderID := gcontext.GetOrderID(ctx) claims := gcontext.GetClaims(ctx) - order, httpErr := queryForOrder(a.db, orderID, log) + order, httpErr := queryForOrder(a.DB(r), orderID, log) if httpErr != nil { return httpErr } @@ -138,7 +138,7 @@ func (a *API) PaymentCreate(w http.ResponseWriter, r *http.Request) error { } orderID := gcontext.GetOrderID(ctx) - tx := a.db.Begin() + tx := a.DB(r).Begin() order := &models.Order{} loader := tx. Preload("LineItems"). @@ -237,9 +237,10 @@ func (a *API) PaymentCreate(w http.ResponseWriter, r *http.Request) error { func (a *API) PaymentConfirm(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() log := getLogEntry(r) + db := a.DB(r) payID := chi.URLParam(r, "payment_id") - trans, httpErr := a.getTransaction(payID) + trans, httpErr := getTransaction(db, payID) if httpErr != nil { return httpErr } @@ -260,7 +261,7 @@ func (a *API) PaymentConfirm(w http.ResponseWriter, r *http.Request) error { } order := &models.Order{} - if rsp := a.db.Find(order, "id = ?", trans.OrderID); rsp.Error != nil { + if rsp := db.Find(order, "id = ?", trans.OrderID); rsp.Error != nil { if rsp.RecordNotFound() { return notFoundError("Order not found") } @@ -286,7 +287,7 @@ func (a *API) PaymentConfirm(w http.ResponseWriter, r *http.Request) error { return internalServerError("Error on provider while trying to confirm: %v. Try again later.", err) } - tx := a.db.Begin() + tx := db.Begin() if trans.InvoiceNumber == 0 { invoiceNumber, err := models.NextInvoiceNumber(tx, order.InstanceID) @@ -311,7 +312,7 @@ func (a *API) PaymentConfirm(w http.ResponseWriter, r *http.Request) error { func (a *API) PaymentList(w http.ResponseWriter, r *http.Request) error { log := getLogEntry(r) instanceID := gcontext.GetInstanceID(r.Context()) - query := a.db.Where("instance_id = ?", instanceID) + query := a.DB(r).Where("instance_id = ?", instanceID) query, err := parsePaymentQueryParams(query, r.URL.Query()) if err != nil { @@ -328,7 +329,7 @@ func (a *API) PaymentList(w http.ResponseWriter, r *http.Request) error { // PaymentView returns information about a single payment. It is only available to admins. func (a *API) PaymentView(w http.ResponseWriter, r *http.Request) error { payID := chi.URLParam(r, "payment_id") - trans, httpErr := a.getTransaction(payID) + trans, httpErr := getTransaction(a.DB(r), payID) if httpErr != nil { return httpErr } @@ -339,6 +340,7 @@ func (a *API) PaymentView(w http.ResponseWriter, r *http.Request) error { // refunds if desired. It is only available to admins. func (a *API) PaymentRefund(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() + db := a.DB(r) config := gcontext.GetConfig(ctx) params := PaymentParams{Currency: "USD"} err := json.NewDecoder(r.Body).Decode(¶ms) @@ -347,7 +349,7 @@ func (a *API) PaymentRefund(w http.ResponseWriter, r *http.Request) error { } payID := chi.URLParam(r, "payment_id") - trans, httpErr := a.getTransaction(payID) + trans, httpErr := getTransaction(db, payID) if httpErr != nil { return httpErr } @@ -369,7 +371,7 @@ func (a *API) PaymentRefund(w http.ResponseWriter, r *http.Request) error { } log := getLogEntry(r) - order, httpErr := queryForOrder(a.db, trans.OrderID, log) + order, httpErr := queryForOrder(db, trans.OrderID, log) if httpErr != nil { return httpErr } @@ -398,7 +400,7 @@ func (a *API) PaymentRefund(w http.ResponseWriter, r *http.Request) error { Status: models.PendingState, } - tx := a.db.Begin() + tx := db.Begin() tx.Create(m) provID := provider.Name() log.Debugf("Starting refund to %s", provID) @@ -480,8 +482,8 @@ func (a *API) PreauthorizePayment(w http.ResponseWriter, r *http.Request) error // ------------------------------------------------------------------------------------------------ // Helpers // ------------------------------------------------------------------------------------------------ -func (a *API) getTransaction(payID string) (*models.Transaction, *HTTPError) { - trans, err := models.GetTransaction(a.db, payID) +func getTransaction(db *gorm.DB, payID string) (*models.Transaction, *HTTPError) { + trans, err := models.GetTransaction(db, payID) if err != nil { return nil, internalServerError("Error while querying for transactions").WithInternalError(err) } diff --git a/api/reports.go b/api/reports.go index 9494c1c..4663fb1 100644 --- a/api/reports.go +++ b/api/reports.go @@ -26,7 +26,7 @@ type productsRow struct { func (a *API) SalesReport(w http.ResponseWriter, r *http.Request) error { instanceID := gcontext.GetInstanceID(r.Context()) - query := a.db. + query := a.DB(r). Model(&models.Order{}). Select("sum(total) as total, sum(sub_total) as subtotal, sum(taxes) as taxes, currency, count(*) as orders"). Where("payment_state = 'paid' AND instance_id = ?", instanceID). @@ -57,10 +57,11 @@ func (a *API) SalesReport(w http.ResponseWriter, r *http.Request) error { // ProductsReport list the products sold within a period func (a *API) ProductsReport(w http.ResponseWriter, r *http.Request) error { + db := a.DB(r) instanceID := gcontext.GetInstanceID(r.Context()) - ordersTable := a.db.NewScope(models.Order{}).QuotedTableName() - itemsTable := a.db.NewScope(models.LineItem{}).QuotedTableName() - query := a.db. + ordersTable := db.NewScope(models.Order{}).QuotedTableName() + itemsTable := db.NewScope(models.LineItem{}).QuotedTableName() + query := db. Model(&models.LineItem{}). Select("sku, path, sum(quantity * price) as total, currency"). Joins("JOIN " + ordersTable + " ON " + ordersTable + ".id = " + itemsTable + ".order_id " + "AND " + ordersTable + ".payment_state = 'paid'"). diff --git a/api/user.go b/api/user.go index ad634d8..7ac84a6 100644 --- a/api/user.go +++ b/api/user.go @@ -20,7 +20,7 @@ func (a *API) withUser(w http.ResponseWriter, r *http.Request) (context.Context, logEntrySetField(r, "user_id", userID) ctx := r.Context() - if u, err := models.GetUser(a.db, userID); err != nil { + if u, err := models.GetUser(a.DB(r), userID); err != nil { return nil, internalServerError("problem while querying for userID: %s", userID).WithInternalError(err) } else if u != nil { ctx = gcontext.WithUser(ctx, u) @@ -81,15 +81,16 @@ func persistUserName(tx *gorm.DB, order *models.Order, claims *claims.JWTClaims) // limit # of records to return (max) func (a *API) UserList(w http.ResponseWriter, r *http.Request) error { log := getLogEntry(r) + db := a.DB(r) - query, err := parseUserQueryParams(a.db, r.URL.Query()) + query, err := parseUserQueryParams(db, r.URL.Query()) if err != nil { return badRequestError("Bad parameters in query: %v", err) } log.Debug("Parsed url params") - orderTable := a.db.NewScope(models.Order{}).QuotedTableName() - userTable := a.db.NewScope(models.User{}).QuotedTableName() + orderTable := db.NewScope(models.Order{}).QuotedTableName() + userTable := db.NewScope(models.User{}).QuotedTableName() query = query. Joins("LEFT JOIN " + orderTable + " ON " + userTable + ".id = " + orderTable + ".user_id"). Group(userTable + ".id") @@ -131,7 +132,7 @@ func (a *API) UserView(w http.ResponseWriter, r *http.Request) error { } orders := []models.Order{} - a.db.Where("user_id = ?", user.ID).Find(&orders).Count(&user.OrderCount) + a.DB(r).Where("user_id = ?", user.ID).Find(&orders).Count(&user.OrderCount) return sendJSON(w, http.StatusOK, user) } @@ -146,7 +147,7 @@ func (a *API) AddressList(w http.ResponseWriter, r *http.Request) error { } addrs := []models.Address{} - results := a.db.Where("user_id = ?", userID).Find(&addrs) + results := a.DB(r).Where("user_id = ?", userID).Find(&addrs) if results.Error != nil { return internalServerError("problem while querying for userID: %s", userID).WithInternalError(results.Error) } @@ -168,7 +169,7 @@ func (a *API) AddressView(w http.ResponseWriter, r *http.Request) error { ID: addrID, UserID: userID, } - results := a.db.First(addr) + results := a.DB(r).First(addr) if results.Error != nil { return internalServerError("problem while querying for userID: %s", userID).WithInternalError(results.Error) } @@ -190,7 +191,7 @@ func (a *API) UserDelete(w http.ResponseWriter, r *http.Request) error { return nil } - rsp := a.db.Delete(user) + rsp := a.DB(r).Delete(user) if rsp.Error != nil { return internalServerError("error while deleting user").WithInternalError(rsp.Error) } @@ -201,8 +202,9 @@ func (a *API) UserDelete(w http.ResponseWriter, r *http.Request) error { func (a *API) UserBulkDelete(w http.ResponseWriter, r *http.Request) error { log := getLogEntry(r) + db := a.DB(r) - query, err := parseUserBulkDeleteParams(a.db, r.URL.Query()) + query, err := parseUserBulkDeleteParams(db, r.URL.Query()) if err != nil { return badRequestError("Bad parameters in query: %v", err) } @@ -212,7 +214,7 @@ func (a *API) UserBulkDelete(w http.ResponseWriter, r *http.Request) error { return internalServerError("error while deleting user").WithInternalError(result.Error) } - tx := a.db.Begin() + tx := db.Begin() defer func() { if r := recover(); r != nil { tx.Rollback() @@ -246,7 +248,7 @@ func (a *API) AddressDelete(w http.ResponseWriter, r *http.Request) error { return nil } - rsp := a.db.Delete(&models.Address{ID: addrID}) + rsp := a.DB(r).Delete(&models.Address{ID: addrID}) if rsp.RecordNotFound() { log.Warn("Attempted to delete an address that doesn't exist") return nil @@ -282,7 +284,7 @@ func (a *API) CreateNewAddress(w http.ResponseWriter, r *http.Request) error { ID: uuid.NewRandom().String(), UserID: userID, } - rsp := a.db.Create(&addr) + rsp := a.DB(r).Create(&addr) if rsp.Error != nil { return internalServerError("failed to save address").WithInternalError(rsp.Error) } diff --git a/context/context.go b/context/context.go index fe93089..ffc43e4 100644 --- a/context/context.go +++ b/context/context.go @@ -6,6 +6,7 @@ import ( "context" "github.com/dgrijalva/jwt-go" + "github.com/jinzhu/gorm" "github.com/netlify/gocommerce/assetstores" "github.com/netlify/gocommerce/claims" @@ -36,6 +37,7 @@ const ( orderIDKey = contextKey("order_id") instanceIDKey = contextKey("instance_id") instanceKey = contextKey("instance") + dbKey = contextKey("db") ) // WithConfig adds the tenant configuration to the context. @@ -252,3 +254,17 @@ func GetInstance(ctx context.Context) *models.Instance { } return obj.(*models.Instance) } + +// GetDB reads the database from the context. +func GetDB(ctx context.Context) *gorm.DB { + obj := ctx.Value(dbKey) + if obj == nil { + return nil + } + return obj.(*gorm.DB) +} + +// WithDB adds the database to the context. +func WithDB(ctx context.Context, db *gorm.DB) context.Context { + return context.WithValue(ctx, dbKey, db) +} diff --git a/models/connection_logger.go b/models/connection_logger.go index ddbd999..b61f6a0 100644 --- a/models/connection_logger.go +++ b/models/connection_logger.go @@ -23,11 +23,7 @@ func (dbl *DBLogger) Print(params ...interface{}) { } level := params[0] - log := dbl.WithField("gorm_level", level) - - if entry, ok := dbl.FieldLogger.(*logrus.Entry); ok && entry.Logger.Level >= logrus.TraceLevel { - log = log.WithField("gorm_source", params[1]) - } + log := dbl.WithField("gorm_level", level).WithField("db_src", params[1]) if level != "sql" { log.Debug(params[2:]...)