Skip to content

Commit

Permalink
Merge pull request #24 from TRON-US/support-db-timeout
Browse files Browse the repository at this point in the history
Add ability to change db timeout on the fly and support passing new db
  • Loading branch information
Eric Chen authored and Robin committed Sep 19, 2020
2 parents eaa3832 + 1699b99 commit 349d439
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 13 deletions.
10 changes: 8 additions & 2 deletions db/postgres/tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package postgres

import (
"context"
"time"

"github.com/tron-us/go-common/v2/constant"

Expand All @@ -11,13 +12,18 @@ import (

// RunInTransactionContext wraps around underlying go-pg's rollback-supported transaction execution
// with our custom context so it can be easily passed down.
func (db *TGPGDB) RunInTransactionContext(ctx context.Context, txFunc func(context.Context) error) error {
func (db *TGPGDB) RunInTransactionContext(ctx context.Context, txFunc func(context.Context, *TGPGDB) error) error {
return db.DB.RunInTransaction(func(tx *pg.Tx) error {
// Pass ctx with tx object down to the transaction execution
return txFunc(context.WithValue(ctx, constant.PostgresTxContext, tx))
return txFunc(context.WithValue(ctx, constant.PostgresTxContext, tx), db)
})
}

// WithTimeout needs to create a new db instance in order to pass into the next chain of commands.
func (db *TGPGDB) WithTimeout(timeout time.Duration) *TGPGDB {
return NewTGPGDB(db.DB.WithTimeout(timeout))
}

// The following functions are necessary to override include to support both transaction
// and transaction-less queries through the ctx's tx existence.

Expand Down
26 changes: 15 additions & 11 deletions env/db/pg.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package db

import (
"os"
"strconv"
"time"

Expand Down Expand Up @@ -34,6 +35,9 @@ var (

func init() {

if os.Getenv(env.EnvNamePrefix) == "" && os.Getenv(env.EnvNamePrefixEnv) == "" {
return
}
if _, duW := env.GetEnv("DB_URL"); duW != "" {
DBWriteURL = duW
// if slave url passed, use it as read default
Expand All @@ -50,68 +54,68 @@ func init() {
writeUserName = _un
} else {
hasError = true
log.Debug(constant.EmptyVarError, zap.String("env", envDbWUnKey))
log.Warn(constant.EmptyVarError, zap.String("env", envDbWUnKey))
}

envDbWPwdKey, _pwd := env.GetEnv("DB_WRITE_PASSWORD")
if _pwd != "" {
writePwd = _pwd
} else {
hasError = true
log.Debug(constant.EmptyVarError, zap.String("env", envDbWPwdKey))
log.Warn(constant.EmptyVarError, zap.String("env", envDbWPwdKey))
}

envDbWHost, _host := env.GetEnv("DB_WRITE_HOST")
if _host != "" {
writeHost = _host
} else {
hasError = true
log.Debug(constant.EmptyVarError, zap.String("env", envDbWHost))
log.Warn(constant.EmptyVarError, zap.String("env", envDbWHost))
}

envDbWNameKey, _dbName := env.GetEnv("DB_WRITE_NAME")
if _dbName != "" {
writeDbName = _dbName
} else {
hasError = true
log.Debug(constant.EmptyVarError, zap.String("env", envDbWNameKey))
log.Warn(constant.EmptyVarError, zap.String("env", envDbWNameKey))
}
envDbRUnKey, un := env.GetEnv("DB_READ_USERNAME")
if un != "" {
readUserName = un
} else {
hasError = true
log.Debug(constant.EmptyVarError, zap.String("env", envDbRUnKey))
log.Warn(constant.EmptyVarError, zap.String("env", envDbRUnKey))
}

envDbRPwdKey, pwd := env.GetEnv("DB_READ_PASSWORD")
if pwd != "" {
readPwd = pwd
} else {
hasError = true
log.Debug(constant.EmptyVarError, zap.String("env", envDbRPwdKey))
log.Warn(constant.EmptyVarError, zap.String("env", envDbRPwdKey))
}

envDbRHostKey, host := env.GetEnv("DB_READ_HOST")
if host != "" {
readHost = host
} else {
hasError = true
log.Debug(constant.EmptyVarError, zap.String("env", envDbRHostKey))
log.Warn(constant.EmptyVarError, zap.String("env", envDbRHostKey))
}

envDbRNameKey, dbName := env.GetEnv("DB_READ_NAME")
if dbName != "" {
readDbName = dbName
} else {
hasError = true
log.Debug(constant.EmptyVarError, zap.String("env", envDbRNameKey))
log.Warn(constant.EmptyVarError, zap.String("env", envDbRNameKey))
}

DBWriteURL = "postgresql://" + writeUserName + ":" + writePwd + "@" + writeHost + ":5432/" + writeDbName
DBReadURL = "postgresql://" + readUserName + ":" + readPwd + "@" + readHost + ":5432/" + readDbName
if hasError {
log.Debug("error to get DBWriteUrl or DBReadURL from env", zap.String("WriteUrl", DBWriteURL), zap.String("ReadUrl", DBReadURL))
log.Warn("error to get DBWriteUrl or DBReadURL from env", zap.String("WriteUrl", DBWriteURL), zap.String("ReadUrl", DBReadURL))
}
}

Expand All @@ -120,14 +124,14 @@ func init() {
}
if envKey, dbst := env.GetEnv("DB_STMT_TIMEOUT"); dbst != "" {
if toInt, err := strconv.ParseInt(dbst, 10, 64); err != nil {
log.Debug(constant.IntConversionError, zap.String("env", envKey), zap.Error(err))
log.Warn(constant.IntConversionError, zap.String("env", envKey), zap.Error(err))
} else {
DBStmtTimeout = time.Duration(toInt) * time.Second
}
}
if envKey, dbnc := env.GetEnv("DB_NUM_CONNS"); dbnc != "" {
if toInt, err := strconv.ParseInt(dbnc, 10, 64); err != nil {
log.Debug(constant.IntConversionError, zap.String("env", envKey), zap.Error(err))
log.Warn(constant.IntConversionError, zap.String("env", envKey), zap.Error(err))
} else {
DBNumConns = int(toInt)
}
Expand Down

0 comments on commit 349d439

Please sign in to comment.