Skip to content

Commit

Permalink
Set TCP USER TIMEOUT (#723)
Browse files Browse the repository at this point in the history
  • Loading branch information
reshke authored Jul 18, 2024
1 parent 59fa35f commit 46815b5
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 3 deletions.
1 change: 1 addition & 0 deletions pkg/config/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ type BackendRule struct {
ConnectionRetries int `json:"connection_retries" yaml:"connection_retries" toml:"connection_retries"`
ConnectionTimeout time.Duration `json:"connection_timeout" yaml:"connection_timeout" toml:"connection_timeout"`
KeepAlive time.Duration `json:"keep_alive" yaml:"keep_alive" toml:"keep_alive"`
TcpUserTimeout time.Duration `json:"tcp_user_timeout" yaml:"tcp_user_timeout" toml:"tcp_user_timeout"`
}

type FrontendRule struct {
Expand Down
24 changes: 23 additions & 1 deletion pkg/conn/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"encoding/binary"
"fmt"
"net"
"os"
"syscall"
"time"

"github.com/jackc/pgx/v5/pgproto3"
Expand Down Expand Up @@ -157,6 +159,24 @@ func (pgi *PostgreSQLInstance) Receive() (pgproto3.BackendMessage, error) {
return pgi.frontend.Receive()
}

func setTCPUserTimeout(d time.Duration) func(string, string, syscall.RawConn) error {
return func(network, address string, c syscall.RawConn) error {
var sysErr error
var err = c.Control(func(fd uintptr) {
/*
#define TCP_USER_TIMEOUT 18 // How long for loss retry before timeout
*/

sysErr = syscall.SetsockoptInt(int(fd), syscall.SOL_TCP, 0x12,
int(d.Milliseconds()))
})
if sysErr != nil {
return os.NewSyscallError("setsockopt", sysErr)
}
return err
}
}

// NewInstanceConn creates a new instance connection to a PostgreSQL database.
//
// Parameters:
Expand All @@ -166,10 +186,12 @@ func (pgi *PostgreSQLInstance) Receive() (pgproto3.BackendMessage, error) {
//
// Return:
// - (DBInstance, error): The newly created instance connection and any error that occurred.
func NewInstanceConn(host string, shard string, tlsconfig *tls.Config, timout time.Duration, keepAlive time.Duration) (DBInstance, error) {
func NewInstanceConn(host string, shard string, tlsconfig *tls.Config, timout time.Duration, keepAlive time.Duration, tcpUserTimeout time.Duration) (DBInstance, error) {
dd := net.Dialer{
Timeout: timout,
KeepAlive: keepAlive,

Control: setTCPUserTimeout(tcpUserTimeout),
}

netconn, err := dd.Dial("tcp", host)
Expand Down
3 changes: 2 additions & 1 deletion pkg/datashard/datashard.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,10 @@ func (sh *Conn) Cancel() error {
pgiTmp, err := conn.NewInstanceConn(
sh.dedicated.Hostname(),
sh.dedicated.ShardName(),
nil /* no tls for cancel */,
nil, /* no tls for cancel */
time.Second,
time.Second,
time.Millisecond*9500,
)
if err != nil {
return err
Expand Down
8 changes: 7 additions & 1 deletion pkg/pool/dbpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,13 @@ func NewDBPool(mapping map[string]*config.Shard, sp *startup.StartupParams) DBPo
keepAlive = rule.KeepAlive
}

pgi, err := conn.NewInstanceConn(host, shardKey.Name, tlsconfig, connTimeout, keepAlive)
tcpUserTimeout := defaultTcpUserTimeout

if rule.TcpUserTimeout != 0 {
tcpUserTimeout = rule.TcpUserTimeout
}

pgi, err := conn.NewInstanceConn(host, shardKey.Name, tlsconfig, connTimeout, keepAlive, tcpUserTimeout)
if err != nil {
return nil, err
}
Expand Down
1 change: 1 addition & 0 deletions pkg/pool/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ const (
defaultInstanceConnectionRetries = 10
defaultInstanceConnectionTimeout = time.Second
defaultKeepAlive = time.Second
defaultTcpUserTimeout = time.Millisecond * 9500
)

type ConnectionKepper interface {
Expand Down
11 changes: 11 additions & 0 deletions router/qrouter/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,14 @@ func (l *LocalQrouter) Route(_ context.Context, _ lyx.Node, _ session.SessionPar
func (l *LocalQrouter) ListKeyRanges(ctx context.Context) ([]*kr.KeyRange, error) {
return nil, nil
}

// TODO : unit tests
func (l *LocalQrouter) DataShardsRoutes() []*routingstate.DataShardRoute {
return []*routingstate.DataShardRoute{
&routingstate.DataShardRoute{Shkey: kr.ShardKey{
Name: l.ds.ID,
RW: false,
},
},
}
}

0 comments on commit 46815b5

Please sign in to comment.