Skip to content

Commit

Permalink
Prefer hosts in the same availability zone (#813)
Browse files Browse the repository at this point in the history
* introduce AZ, some workarounds

* on my way to fix build

* finnally fix build

* introduces strategies

* small fix

* add AZ everywhere

* fix build

* refactor code a bit

* add tests, refactor a bit

* InstancePoolImpl -> DBPool

* drop mirror.go
  • Loading branch information
Denchick authored Nov 15, 2024
1 parent 054cb1e commit cae2320
Show file tree
Hide file tree
Showing 25 changed files with 447 additions and 488 deletions.
7 changes: 3 additions & 4 deletions coordinator/provider/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"github.com/pg-sharding/spqr/pkg/datatransfers"
"github.com/pg-sharding/spqr/pkg/meta"
"github.com/pg-sharding/spqr/pkg/models/topology"
proto "github.com/pg-sharding/spqr/pkg/protos"
"github.com/pg-sharding/spqr/pkg/shard"

"github.com/pg-sharding/spqr/qdb/ops"
Expand Down Expand Up @@ -1852,7 +1851,7 @@ func (qc *qdbCoordinator) ProcClient(ctx context.Context, nconn net.Conn, pt por

// TODO : unit tests
func (qc *qdbCoordinator) AddDataShard(ctx context.Context, shard *datashards.DataShard) error {
return qc.db.AddShard(ctx, qdb.NewShard(shard.ID, shard.Cfg.Hosts))
return qc.db.AddShard(ctx, qdb.NewShard(shard.ID, shard.Cfg.RawHosts))
}

func (qc *qdbCoordinator) AddWorldShard(_ context.Context, _ *datashards.DataShard) error {
Expand All @@ -1876,7 +1875,7 @@ func (qc *qdbCoordinator) ListShards(ctx context.Context) ([]*datashards.DataSha
shards = append(shards, &datashards.DataShard{
ID: shard.ID,
Cfg: &config.Shard{
Hosts: shard.Hosts,
RawHosts: shard.RawHosts,
},
})
}
Expand All @@ -1887,7 +1886,7 @@ func (qc *qdbCoordinator) ListShards(ctx context.Context) ([]*datashards.DataSha
// TODO : unit tests
func (qc *qdbCoordinator) UpdateCoordinator(ctx context.Context, address string) error {
return qc.traverseRouters(ctx, func(cc *grpc.ClientConn) error {
c := proto.NewTopologyServiceClient(cc)
c := routerproto.NewTopologyServiceClient(cc)
spqrlog.Zero.Debug().Str("address", address).Msg("updating coordinator address")
_, err := c.UpdateCoordinator(ctx, &routerproto.UpdateCoordinatorRequest{
Address: address,
Expand Down
5 changes: 2 additions & 3 deletions coordinator/provider/shards.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package provider
import (
"context"

routerproto "github.com/pg-sharding/spqr/pkg/protos"
"github.com/pg-sharding/spqr/pkg/shard"
"github.com/pg-sharding/spqr/pkg/txstatus"
"google.golang.org/protobuf/types/known/emptypb"
Expand Down Expand Up @@ -74,7 +73,7 @@ func (s *ShardServer) GetShard(ctx context.Context, shardRequest *protos.ShardRe
}

type CoordShardInfo struct {
underlying *routerproto.BackendConnectionsInfo
underlying *protos.BackendConnectionsInfo
router string
}

Expand All @@ -98,7 +97,7 @@ func (c *CoordShardInfo) ListPreparedStatements() []shard.PreparedStatementsMgrD
return nil
}

func NewCoordShardInfo(conn *routerproto.BackendConnectionsInfo, router string) shard.Shardinfo {
func NewCoordShardInfo(conn *protos.BackendConnectionsInfo, router string) shard.Shardinfo {
return &CoordShardInfo{
underlying: conn,
router: router,
Expand Down
75 changes: 62 additions & 13 deletions pkg/config/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"log"
"os"
"strings"
"sync"
"time"

"github.com/BurntSushi/toml"
Expand Down Expand Up @@ -44,6 +45,9 @@ type Router struct {
PidFileName string `json:"pid_filename" toml:"pid_filename" yaml:"pid_filename"`
LogFileName string `json:"log_filename" toml:"log_filename" yaml:"log_filename"`

AvailabilityZone string `json:"availability_zone" toml:"availability_zone" yaml:"availability_zone"`
PreferSameAvailabilityZone bool `json:"prefer_same_availability_zone" toml:"prefer_same_availability_zone" yaml:"prefer_same_availability_zone"`

Host string `json:"host" toml:"host" yaml:"host"`
RouterPort string `json:"router_port" toml:"router_port" yaml:"router_port"`
RouterROPort string `json:"router_ro_port" toml:"router_ro_port" yaml:"router_ro_port"`
Expand Down Expand Up @@ -87,16 +91,17 @@ type QRouter struct {
}

type BackendRule struct {
DB string `json:"db" yaml:"db" toml:"db"`
Usr string `json:"usr" yaml:"usr" toml:"usr"`
AuthRules map[string]*AuthBackendCfg `json:"auth_rules" yaml:"auth_rules" toml:"auth_rules"` // TODO validate
DefaultAuthRule *AuthBackendCfg `json:"auth_rule" yaml:"auth_rule" toml:"auth_rule"`
PoolDefault bool `json:"pool_default" yaml:"pool_default" toml:"pool_default"`
ConnectionLimit int `json:"connection_limit" yaml:"connection_limit" toml:"connection_limit"`
ConnectionRetries int `json:"connection_retries" yaml:"connection_retries" toml:"connection_retries"`
ConnectionTimeout time.Duration `json:"connection_timeout" yaml:"connection_timeout" toml:"connection_timeout"`
KeepAlive time.Duration `json:"keep_alive" yaml:"keep_alive" toml:"keep_alive"`
TcpUserTimeout time.Duration `json:"tcp_user_timeout" yaml:"tcp_user_timeout" toml:"tcp_user_timeout"`
DB string `json:"db" yaml:"db" toml:"db"`
Usr string `json:"usr" yaml:"usr" toml:"usr"`
AuthRules map[string]*AuthBackendCfg `json:"auth_rules" yaml:"auth_rules" toml:"auth_rules"` // TODO validate
DefaultAuthRule *AuthBackendCfg `json:"auth_rule" yaml:"auth_rule" toml:"auth_rule"`
PoolDefault bool `json:"pool_default" yaml:"pool_default" toml:"pool_default"`

ConnectionLimit int `json:"connection_limit" yaml:"connection_limit" toml:"connection_limit"`
ConnectionRetries int `json:"connection_retries" yaml:"connection_retries" toml:"connection_retries"`
ConnectionTimeout time.Duration `json:"connection_timeout" yaml:"connection_timeout" toml:"connection_timeout"`
KeepAlive time.Duration `json:"keep_alive" yaml:"keep_alive" toml:"keep_alive"`
TcpUserTimeout time.Duration `json:"tcp_user_timeout" yaml:"tcp_user_timeout" toml:"tcp_user_timeout"`
}

type FrontendRule struct {
Expand All @@ -119,9 +124,53 @@ const (
)

type Shard struct {
Hosts []string `json:"hosts" toml:"hosts" yaml:"hosts"`
Type ShardType `json:"type" toml:"type" yaml:"type"`
TLS *TLSConfig `json:"tls" yaml:"tls" toml:"tls"`
RawHosts []string `json:"hosts" toml:"hosts" yaml:"hosts"` // format host:port:availability_zone
parsedHosts []Host
parsedAddresses []string
once sync.Once

Type ShardType `json:"type" toml:"type" yaml:"type"`
TLS *TLSConfig `json:"tls" yaml:"tls" toml:"tls"`
}

type Host struct {
Address string // format host:port
AZ string // Availability zone
}

// parseHosts parses the raw hosts into a slice of Hosts.
// The format of the RawHost is host:port:availability_zone.
// If the availability_zone is not provided, it is empty.
// If the port is not provided, it does not matter
func (s *Shard) parseHosts() {
for _, rawHost := range s.RawHosts {
host := Host{}
parts := strings.Split(rawHost, ":")
if len(parts) > 3 {
log.Printf("invalid host format: expected 'host:port:availability_zone', got '%s'", rawHost)
continue
} else if len(parts) == 3 {
host.AZ = parts[2]
host.Address = fmt.Sprintf("%s:%s", parts[0], parts[1])
} else {
host.Address = rawHost
}

s.parsedHosts = append(s.parsedHosts, host)
s.parsedAddresses = append(s.parsedAddresses, host.Address)
}
}

func (s *Shard) Hosts() []string {
s.once.Do(s.parseHosts)

return s.parsedAddresses
}

func (s *Shard) HostsAZ() []Host {
s.once.Do(s.parseHosts)

return s.parsedHosts
}

func ValueOrDefaultInt(value int, def int) int {
Expand Down
36 changes: 28 additions & 8 deletions pkg/conn/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type DBInstance interface {
ReqBackendSsl(*tls.Config) error

Hostname() string
AvailabilityZone() string
ShardName() string

Close() error
Expand All @@ -45,6 +46,7 @@ type PostgreSQLInstance struct {
frontend *pgproto3.Frontend

hostname string
az string // availability zone
shardname string
status InstanceStatus

Expand Down Expand Up @@ -118,6 +120,17 @@ func (pgi *PostgreSQLInstance) Hostname() string {
return pgi.hostname
}

// AvailabilityZone returns the availability zone of the PostgreSQLInstance.
//
// Parameters:
// - None.
//
// Returns:
// - string: The availability zone of the PostgreSQLInstance.
func (pgi *PostgreSQLInstance) AvailabilityZone() string {
return pgi.az
}

// ShardName returns the shard name of the PostgreSQLInstance.
//
// Parameters:
Expand Down Expand Up @@ -187,28 +200,35 @@ func setTCPUserTimeout(d time.Duration) func(string, string, syscall.RawConn) er
// NewInstanceConn creates a new instance connection to a PostgreSQL database.
//
// Parameters:
// - host (string): The host of the PostgreSQL database.
// - shard (string): The shard name of the PostgreSQL database.
// - tlsconfig (*tls.Config): The TLS configuration for the connection.
// - host (string): The hostname of the PostgreSQL instance.
// - availabilityZone (string): The availability zone of the PostgreSQL instance.
// - shardname (string): The name of the shard.
// - tlsconfig (*tls.Config): The TLS configuration to use for the SSL/TLS handshake.
// - timeout (time.Duration): The timeout for the connection.
// - keepAlive (time.Duration): The keep alive duration for the connection.
// - tcpUserTimeout (time.Duration): The TCP user timeout duration for the connection.
//
// Return:
// - (DBInstance, error): The newly created instance connection and any error that occurred.
func NewInstanceConn(host string, shard string, tlsconfig *tls.Config, timout time.Duration, keepAlive time.Duration, tcpUserTimeout time.Duration) (DBInstance, error) {
// Returns:
// - DBInstance: The new PostgreSQLInstance.
// - error: An error if there was a problem creating the new PostgreSQLInstance.
func NewInstanceConn(host string, availabilityZone string, shardname string, tlsconfig *tls.Config, timeout time.Duration, keepAlive time.Duration, tcpUserTimeout time.Duration) (DBInstance, error) {
dd := net.Dialer{
Timeout: timout,
Timeout: timeout,
KeepAlive: keepAlive,

Control: setTCPUserTimeout(tcpUserTimeout),
}

// assuming here host is in the form of hostname:port
netconn, err := dd.Dial("tcp", host)
if err != nil {
return nil, err
}

instance := &PostgreSQLInstance{
hostname: host,
shardname: shard,
az: availabilityZone,
shardname: shardname,
conn: netconn,
status: NotInitialized,
tlsconfig: tlsconfig,
Expand Down
4 changes: 2 additions & 2 deletions pkg/coord/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ func (a *Adapter) ListShards(ctx context.Context) ([]*datashards.DataShard, erro
for _, shard := range shards {
ds = append(ds, &datashards.DataShard{
ID: shard.Id,
Cfg: &config.Shard{Hosts: shard.Hosts},
Cfg: &config.Shard{RawHosts: shard.Hosts},
})
}
return ds, err
Expand All @@ -647,7 +647,7 @@ func (a *Adapter) GetShard(ctx context.Context, shardID string) (*datashards.Dat
resp, err := c.GetShard(ctx, &proto.ShardRequest{Id: shardID})
return &datashards.DataShard{
ID: resp.Shard.Id,
Cfg: &config.Shard{Hosts: resp.Shard.Hosts},
Cfg: &config.Shard{RawHosts: resp.Shard.Hosts},
}, err
}

Expand Down
6 changes: 3 additions & 3 deletions pkg/coord/local/clocal.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (lc *LocalCoordinator) ListShards(ctx context.Context) ([]*datashards.DataS
retShards = append(retShards, &datashards.DataShard{
ID: sh.ID,
Cfg: &config.Shard{
Hosts: sh.Hosts,
RawHosts: sh.RawHosts,
},
})
}
Expand Down Expand Up @@ -669,8 +669,8 @@ func (lc *LocalCoordinator) AddDataShard(ctx context.Context, ds *datashards.Dat
lc.DataShardCfgs[ds.ID] = ds.Cfg

return lc.qdb.AddShard(ctx, &qdb.Shard{
ID: ds.ID,
Hosts: ds.Cfg.Hosts,
ID: ds.ID,
RawHosts: ds.Cfg.RawHosts,
})
}

Expand Down
1 change: 1 addition & 0 deletions pkg/datashard/datashard.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ func (sh *Conn) TxServed() int64 {
func (sh *Conn) Cancel() error {
pgiTmp, err := conn.NewInstanceConn(
sh.dedicated.Hostname(),
sh.dedicated.AvailabilityZone(),
sh.dedicated.ShardName(),
nil, /* no tls for cancel */
time.Second,
Expand Down
4 changes: 2 additions & 2 deletions pkg/meta/meta.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ func processCreate(ctx context.Context, astmt spqrparser.Statement, mngr EntityM
return cli.CreateKeyRange(ctx, req)
case *spqrparser.ShardDefinition:
dataShard := datashards.NewDataShard(stmt.Id, &config.Shard{
Hosts: stmt.Hosts,
Type: config.DataShard,
RawHosts: stmt.Hosts,
Type: config.DataShard,
})
if err := mngr.AddDataShard(ctx, dataShard); err != nil {
return err
Expand Down
14 changes: 14 additions & 0 deletions pkg/mock/conn/mock_instance.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit cae2320

Please sign in to comment.