diff --git a/pkg/config/router.go b/pkg/config/router.go index 6b97d7b63..29052b917 100644 --- a/pkg/config/router.go +++ b/pkg/config/router.go @@ -96,6 +96,7 @@ type BackendRule struct { ConnectionRetries int `json:"connection_retries" yaml:"connection_retries" toml:"connection_retries"` ConnectionTimeout time.Duration `json:"connection_timeout" yaml:"connection_timeout" toml:"connection_timeout"` KeepAlive time.Duration `json:"keep_alive" yaml:"keep_alive" toml:"keep_alive"` + TcpUserTimeout time.Duration `json:"tcp_user_timeout" yaml:"tcp_user_timeout" toml:"tcp_user_timeout"` } type FrontendRule struct { diff --git a/pkg/conn/instance.go b/pkg/conn/instance.go index e1570e1b5..2ef564eb1 100644 --- a/pkg/conn/instance.go +++ b/pkg/conn/instance.go @@ -5,6 +5,8 @@ import ( "encoding/binary" "fmt" "net" + "os" + "syscall" "time" "github.com/jackc/pgx/v5/pgproto3" @@ -157,6 +159,24 @@ func (pgi *PostgreSQLInstance) Receive() (pgproto3.BackendMessage, error) { return pgi.frontend.Receive() } +func setTCPUserTimeout(d time.Duration) func(string, string, syscall.RawConn) error { + return func(network, address string, c syscall.RawConn) error { + var sysErr error + var err = c.Control(func(fd uintptr) { + /* + #define TCP_USER_TIMEOUT 18 // How long for loss retry before timeout + */ + + sysErr = syscall.SetsockoptInt(int(fd), syscall.SOL_TCP, 0x12, + int(d.Milliseconds())) + }) + if sysErr != nil { + return os.NewSyscallError("setsockopt", sysErr) + } + return err + } +} + // NewInstanceConn creates a new instance connection to a PostgreSQL database. // // Parameters: @@ -166,10 +186,12 @@ func (pgi *PostgreSQLInstance) Receive() (pgproto3.BackendMessage, error) { // // Return: // - (DBInstance, error): The newly created instance connection and any error that occurred. -func NewInstanceConn(host string, shard string, tlsconfig *tls.Config, timout time.Duration, keepAlive time.Duration) (DBInstance, error) { +func NewInstanceConn(host string, shard string, tlsconfig *tls.Config, timout time.Duration, keepAlive time.Duration, tcpUserTimeout time.Duration) (DBInstance, error) { dd := net.Dialer{ Timeout: timout, KeepAlive: keepAlive, + + Control: setTCPUserTimeout(tcpUserTimeout), } netconn, err := dd.Dial("tcp", host) diff --git a/pkg/datashard/datashard.go b/pkg/datashard/datashard.go index be32a60cb..216c7df66 100644 --- a/pkg/datashard/datashard.go +++ b/pkg/datashard/datashard.go @@ -153,9 +153,10 @@ func (sh *Conn) Cancel() error { pgiTmp, err := conn.NewInstanceConn( sh.dedicated.Hostname(), sh.dedicated.ShardName(), - nil /* no tls for cancel */, + nil, /* no tls for cancel */ time.Second, time.Second, + time.Millisecond*9500, ) if err != nil { return err diff --git a/pkg/pool/dbpool.go b/pkg/pool/dbpool.go index acbf042c8..281a6c55b 100644 --- a/pkg/pool/dbpool.go +++ b/pkg/pool/dbpool.go @@ -434,7 +434,13 @@ func NewDBPool(mapping map[string]*config.Shard, sp *startup.StartupParams) DBPo keepAlive = rule.KeepAlive } - pgi, err := conn.NewInstanceConn(host, shardKey.Name, tlsconfig, connTimeout, keepAlive) + tcpUserTimeout := defaultTcpUserTimeout + + if rule.TcpUserTimeout != 0 { + tcpUserTimeout = rule.TcpUserTimeout + } + + pgi, err := conn.NewInstanceConn(host, shardKey.Name, tlsconfig, connTimeout, keepAlive, tcpUserTimeout) if err != nil { return nil, err } diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go index 5afdf7469..b1ad75dc9 100644 --- a/pkg/pool/pool.go +++ b/pkg/pool/pool.go @@ -13,6 +13,7 @@ const ( defaultInstanceConnectionRetries = 10 defaultInstanceConnectionTimeout = time.Second defaultKeepAlive = time.Second + defaultTcpUserTimeout = time.Millisecond * 9500 ) type ConnectionKepper interface { diff --git a/router/qrouter/local.go b/router/qrouter/local.go index fc5814182..18bc0c748 100644 --- a/router/qrouter/local.go +++ b/router/qrouter/local.go @@ -76,3 +76,14 @@ func (l *LocalQrouter) Route(_ context.Context, _ lyx.Node, _ session.SessionPar func (l *LocalQrouter) ListKeyRanges(ctx context.Context) ([]*kr.KeyRange, error) { return nil, nil } + +// TODO : unit tests +func (l *LocalQrouter) DataShardsRoutes() []*routingstate.DataShardRoute { + return []*routingstate.DataShardRoute{ + &routingstate.DataShardRoute{Shkey: kr.ShardKey{ + Name: l.ds.ID, + RW: false, + }, + }, + } +}