From cae23203cb72b3f216cfd7f27f5b39a9ea19f169 Mon Sep 17 00:00:00 2001 From: Denis Volkov <3149929+Denchick@users.noreply.github.com> Date: Fri, 15 Nov 2024 11:59:31 +0100 Subject: [PATCH] Prefer hosts in the same availability zone (#813) * introduce AZ, some workarounds * on my way to fix build * finnally fix build * introduces strategies * small fix * add AZ everywhere * fix build * refactor code a bit * add tests, refactor a bit * InstancePoolImpl -> DBPool * drop mirror.go --- coordinator/provider/coordinator.go | 7 +- coordinator/provider/shards.go | 5 +- pkg/config/router.go | 75 ++++++++-- pkg/conn/instance.go | 36 +++-- pkg/coord/adapter.go | 4 +- pkg/coord/local/clocal.go | 6 +- pkg/datashard/datashard.go | 1 + pkg/meta/meta.go | 4 +- pkg/mock/conn/mock_instance.go | 14 ++ pkg/mock/pool/mock_pool.go | 164 +-------------------- pkg/models/datashards/datashard.go | 10 +- pkg/pool/dbpool.go | 207 ++++++++++++++------------- pkg/pool/dbpool_test.go | 141 +++++++++++++++--- pkg/pool/pool.go | 15 +- pkg/pool/shardpool.go | 57 ++++---- pkg/pool/shardpool_test.go | 16 +-- pkg/tsa/tsa.go | 1 + qdb/etcdqdb.go | 4 +- qdb/memqdb_test.go | 4 +- qdb/models.go | 8 +- router/qrouter/proxy_routing_test.go | 88 +++--------- router/route/route.go | 11 +- router/server/mirror.go | 42 ------ router/server/multishard.go | 11 +- router/server/shard.go | 4 +- 25 files changed, 447 insertions(+), 488 deletions(-) delete mode 100644 router/server/mirror.go diff --git a/coordinator/provider/coordinator.go b/coordinator/provider/coordinator.go index 40fcc24d4..93c59752e 100644 --- a/coordinator/provider/coordinator.go +++ b/coordinator/provider/coordinator.go @@ -23,7 +23,6 @@ import ( "github.com/pg-sharding/spqr/pkg/datatransfers" "github.com/pg-sharding/spqr/pkg/meta" "github.com/pg-sharding/spqr/pkg/models/topology" - proto "github.com/pg-sharding/spqr/pkg/protos" "github.com/pg-sharding/spqr/pkg/shard" "github.com/pg-sharding/spqr/qdb/ops" @@ -1852,7 +1851,7 @@ func (qc 
*qdbCoordinator) ProcClient(ctx context.Context, nconn net.Conn, pt por // TODO : unit tests func (qc *qdbCoordinator) AddDataShard(ctx context.Context, shard *datashards.DataShard) error { - return qc.db.AddShard(ctx, qdb.NewShard(shard.ID, shard.Cfg.Hosts)) + return qc.db.AddShard(ctx, qdb.NewShard(shard.ID, shard.Cfg.RawHosts)) } func (qc *qdbCoordinator) AddWorldShard(_ context.Context, _ *datashards.DataShard) error { @@ -1876,7 +1875,7 @@ func (qc *qdbCoordinator) ListShards(ctx context.Context) ([]*datashards.DataSha shards = append(shards, &datashards.DataShard{ ID: shard.ID, Cfg: &config.Shard{ - Hosts: shard.Hosts, + RawHosts: shard.RawHosts, }, }) } @@ -1887,7 +1886,7 @@ func (qc *qdbCoordinator) ListShards(ctx context.Context) ([]*datashards.DataSha // TODO : unit tests func (qc *qdbCoordinator) UpdateCoordinator(ctx context.Context, address string) error { return qc.traverseRouters(ctx, func(cc *grpc.ClientConn) error { - c := proto.NewTopologyServiceClient(cc) + c := routerproto.NewTopologyServiceClient(cc) spqrlog.Zero.Debug().Str("address", address).Msg("updating coordinator address") _, err := c.UpdateCoordinator(ctx, &routerproto.UpdateCoordinatorRequest{ Address: address, diff --git a/coordinator/provider/shards.go b/coordinator/provider/shards.go index 754ee4331..5a2cd33ec 100644 --- a/coordinator/provider/shards.go +++ b/coordinator/provider/shards.go @@ -3,7 +3,6 @@ package provider import ( "context" - routerproto "github.com/pg-sharding/spqr/pkg/protos" "github.com/pg-sharding/spqr/pkg/shard" "github.com/pg-sharding/spqr/pkg/txstatus" "google.golang.org/protobuf/types/known/emptypb" @@ -74,7 +73,7 @@ func (s *ShardServer) GetShard(ctx context.Context, shardRequest *protos.ShardRe } type CoordShardInfo struct { - underlying *routerproto.BackendConnectionsInfo + underlying *protos.BackendConnectionsInfo router string } @@ -98,7 +97,7 @@ func (c *CoordShardInfo) ListPreparedStatements() []shard.PreparedStatementsMgrD return nil } -func 
NewCoordShardInfo(conn *routerproto.BackendConnectionsInfo, router string) shard.Shardinfo { +func NewCoordShardInfo(conn *protos.BackendConnectionsInfo, router string) shard.Shardinfo { return &CoordShardInfo{ underlying: conn, router: router, diff --git a/pkg/config/router.go b/pkg/config/router.go index 5fdcbe43a..05d5e4dce 100644 --- a/pkg/config/router.go +++ b/pkg/config/router.go @@ -6,6 +6,7 @@ import ( "log" "os" "strings" + "sync" "time" "github.com/BurntSushi/toml" @@ -44,6 +45,9 @@ type Router struct { PidFileName string `json:"pid_filename" toml:"pid_filename" yaml:"pid_filename"` LogFileName string `json:"log_filename" toml:"log_filename" yaml:"log_filename"` + AvailabilityZone string `json:"availability_zone" toml:"availability_zone" yaml:"availability_zone"` + PreferSameAvailabilityZone bool `json:"prefer_same_availability_zone" toml:"prefer_same_availability_zone" yaml:"prefer_same_availability_zone"` + Host string `json:"host" toml:"host" yaml:"host"` RouterPort string `json:"router_port" toml:"router_port" yaml:"router_port"` RouterROPort string `json:"router_ro_port" toml:"router_ro_port" yaml:"router_ro_port"` @@ -87,16 +91,17 @@ type QRouter struct { } type BackendRule struct { - DB string `json:"db" yaml:"db" toml:"db"` - Usr string `json:"usr" yaml:"usr" toml:"usr"` - AuthRules map[string]*AuthBackendCfg `json:"auth_rules" yaml:"auth_rules" toml:"auth_rules"` // TODO validate - DefaultAuthRule *AuthBackendCfg `json:"auth_rule" yaml:"auth_rule" toml:"auth_rule"` - PoolDefault bool `json:"pool_default" yaml:"pool_default" toml:"pool_default"` - ConnectionLimit int `json:"connection_limit" yaml:"connection_limit" toml:"connection_limit"` - ConnectionRetries int `json:"connection_retries" yaml:"connection_retries" toml:"connection_retries"` - ConnectionTimeout time.Duration `json:"connection_timeout" yaml:"connection_timeout" toml:"connection_timeout"` - KeepAlive time.Duration `json:"keep_alive" yaml:"keep_alive" toml:"keep_alive"` - 
TcpUserTimeout time.Duration `json:"tcp_user_timeout" yaml:"tcp_user_timeout" toml:"tcp_user_timeout"` + DB string `json:"db" yaml:"db" toml:"db"` + Usr string `json:"usr" yaml:"usr" toml:"usr"` + AuthRules map[string]*AuthBackendCfg `json:"auth_rules" yaml:"auth_rules" toml:"auth_rules"` // TODO validate + DefaultAuthRule *AuthBackendCfg `json:"auth_rule" yaml:"auth_rule" toml:"auth_rule"` + PoolDefault bool `json:"pool_default" yaml:"pool_default" toml:"pool_default"` + + ConnectionLimit int `json:"connection_limit" yaml:"connection_limit" toml:"connection_limit"` + ConnectionRetries int `json:"connection_retries" yaml:"connection_retries" toml:"connection_retries"` + ConnectionTimeout time.Duration `json:"connection_timeout" yaml:"connection_timeout" toml:"connection_timeout"` + KeepAlive time.Duration `json:"keep_alive" yaml:"keep_alive" toml:"keep_alive"` + TcpUserTimeout time.Duration `json:"tcp_user_timeout" yaml:"tcp_user_timeout" toml:"tcp_user_timeout"` } type FrontendRule struct { @@ -119,9 +124,53 @@ const ( ) type Shard struct { - Hosts []string `json:"hosts" toml:"hosts" yaml:"hosts"` - Type ShardType `json:"type" toml:"type" yaml:"type"` - TLS *TLSConfig `json:"tls" yaml:"tls" toml:"tls"` + RawHosts []string `json:"hosts" toml:"hosts" yaml:"hosts"` // format host:port:availability_zone + parsedHosts []Host + parsedAddresses []string + once sync.Once + + Type ShardType `json:"type" toml:"type" yaml:"type"` + TLS *TLSConfig `json:"tls" yaml:"tls" toml:"tls"` +} + +type Host struct { + Address string // format host:port + AZ string // Availability zone +} + +// parseHosts parses the raw hosts into a slice of Hosts. +// The format of the RawHost is host:port:availability_zone. +// If the availability_zone is not provided, it is empty. 
+// If the port is not provided, it does not matter +func (s *Shard) parseHosts() { + for _, rawHost := range s.RawHosts { + host := Host{} + parts := strings.Split(rawHost, ":") + if len(parts) > 3 { + log.Printf("invalid host format: expected 'host:port:availability_zone', got '%s'", rawHost) + continue + } else if len(parts) == 3 { + host.AZ = parts[2] + host.Address = fmt.Sprintf("%s:%s", parts[0], parts[1]) + } else { + host.Address = rawHost + } + + s.parsedHosts = append(s.parsedHosts, host) + s.parsedAddresses = append(s.parsedAddresses, host.Address) + } +} + +func (s *Shard) Hosts() []string { + s.once.Do(s.parseHosts) + + return s.parsedAddresses +} + +func (s *Shard) HostsAZ() []Host { + s.once.Do(s.parseHosts) + + return s.parsedHosts } func ValueOrDefaultInt(value int, def int) int { diff --git a/pkg/conn/instance.go b/pkg/conn/instance.go index bacd2a167..856944113 100644 --- a/pkg/conn/instance.go +++ b/pkg/conn/instance.go @@ -29,6 +29,7 @@ type DBInstance interface { ReqBackendSsl(*tls.Config) error Hostname() string + AvailabilityZone() string ShardName() string Close() error @@ -45,6 +46,7 @@ type PostgreSQLInstance struct { frontend *pgproto3.Frontend hostname string + az string // availability zone shardname string status InstanceStatus @@ -118,6 +120,17 @@ func (pgi *PostgreSQLInstance) Hostname() string { return pgi.hostname } +// AvailabilityZone returns the availability zone of the PostgreSQLInstance. +// +// Parameters: +// - None. +// +// Returns: +// - string: The availability zone of the PostgreSQLInstance. +func (pgi *PostgreSQLInstance) AvailabilityZone() string { + return pgi.az +} + // ShardName returns the shard name of the PostgreSQLInstance. // // Parameters: @@ -187,20 +200,26 @@ func setTCPUserTimeout(d time.Duration) func(string, string, syscall.RawConn) er // NewInstanceConn creates a new instance connection to a PostgreSQL database. // // Parameters: -// - host (string): The host of the PostgreSQL database. 
-// - shard (string): The shard name of the PostgreSQL database. -// - tlsconfig (*tls.Config): The TLS configuration for the connection. +// - host (string): The hostname of the PostgreSQL instance. +// - availabilityZone (string): The availability zone of the PostgreSQL instance. +// - shardname (string): The name of the shard. +// - tlsconfig (*tls.Config): The TLS configuration to use for the SSL/TLS handshake. +// - timeout (time.Duration): The timeout for the connection. +// - keepAlive (time.Duration): The keep alive duration for the connection. +// - tcpUserTimeout (time.Duration): The TCP user timeout duration for the connection. // -// Return: -// - (DBInstance, error): The newly created instance connection and any error that occurred. -func NewInstanceConn(host string, shard string, tlsconfig *tls.Config, timout time.Duration, keepAlive time.Duration, tcpUserTimeout time.Duration) (DBInstance, error) { +// Returns: +// - DBInstance: The new PostgreSQLInstance. +// - error: An error if there was a problem creating the new PostgreSQLInstance. 
+func NewInstanceConn(host string, availabilityZone string, shardname string, tlsconfig *tls.Config, timeout time.Duration, keepAlive time.Duration, tcpUserTimeout time.Duration) (DBInstance, error) { dd := net.Dialer{ - Timeout: timout, + Timeout: timeout, KeepAlive: keepAlive, Control: setTCPUserTimeout(tcpUserTimeout), } + // assuming here host is in the form of hostname:port netconn, err := dd.Dial("tcp", host) if err != nil { return nil, err @@ -208,7 +227,8 @@ func NewInstanceConn(host string, shard string, tlsconfig *tls.Config, timout ti instance := &PostgreSQLInstance{ hostname: host, - shardname: shard, + az: availabilityZone, + shardname: shardname, conn: netconn, status: NotInitialized, tlsconfig: tlsconfig, diff --git a/pkg/coord/adapter.go b/pkg/coord/adapter.go index c19ca00e7..f8875e391 100644 --- a/pkg/coord/adapter.go +++ b/pkg/coord/adapter.go @@ -625,7 +625,7 @@ func (a *Adapter) ListShards(ctx context.Context) ([]*datashards.DataShard, erro for _, shard := range shards { ds = append(ds, &datashards.DataShard{ ID: shard.Id, - Cfg: &config.Shard{Hosts: shard.Hosts}, + Cfg: &config.Shard{RawHosts: shard.Hosts}, }) } return ds, err @@ -647,7 +647,7 @@ func (a *Adapter) GetShard(ctx context.Context, shardID string) (*datashards.Dat resp, err := c.GetShard(ctx, &proto.ShardRequest{Id: shardID}) return &datashards.DataShard{ ID: resp.Shard.Id, - Cfg: &config.Shard{Hosts: resp.Shard.Hosts}, + Cfg: &config.Shard{RawHosts: resp.Shard.Hosts}, }, err } diff --git a/pkg/coord/local/clocal.go b/pkg/coord/local/clocal.go index 57dae711c..8dc2cc577 100644 --- a/pkg/coord/local/clocal.go +++ b/pkg/coord/local/clocal.go @@ -282,7 +282,7 @@ func (lc *LocalCoordinator) ListShards(ctx context.Context) ([]*datashards.DataS retShards = append(retShards, &datashards.DataShard{ ID: sh.ID, Cfg: &config.Shard{ - Hosts: sh.Hosts, + RawHosts: sh.RawHosts, }, }) } @@ -669,8 +669,8 @@ func (lc *LocalCoordinator) AddDataShard(ctx context.Context, ds *datashards.Dat 
lc.DataShardCfgs[ds.ID] = ds.Cfg return lc.qdb.AddShard(ctx, &qdb.Shard{ - ID: ds.ID, - Hosts: ds.Cfg.Hosts, + ID: ds.ID, + RawHosts: ds.Cfg.RawHosts, }) } diff --git a/pkg/datashard/datashard.go b/pkg/datashard/datashard.go index f9322638e..4a39abbfc 100644 --- a/pkg/datashard/datashard.go +++ b/pkg/datashard/datashard.go @@ -154,6 +154,7 @@ func (sh *Conn) TxServed() int64 { func (sh *Conn) Cancel() error { pgiTmp, err := conn.NewInstanceConn( sh.dedicated.Hostname(), + sh.dedicated.AvailabilityZone(), sh.dedicated.ShardName(), nil, /* no tls for cancel */ time.Second, diff --git a/pkg/meta/meta.go b/pkg/meta/meta.go index c88001d1c..17fdafaed 100644 --- a/pkg/meta/meta.go +++ b/pkg/meta/meta.go @@ -227,8 +227,8 @@ func processCreate(ctx context.Context, astmt spqrparser.Statement, mngr EntityM return cli.CreateKeyRange(ctx, req) case *spqrparser.ShardDefinition: dataShard := datashards.NewDataShard(stmt.Id, &config.Shard{ - Hosts: stmt.Hosts, - Type: config.DataShard, + RawHosts: stmt.Hosts, + Type: config.DataShard, }) if err := mngr.AddDataShard(ctx, dataShard); err != nil { return err diff --git a/pkg/mock/conn/mock_instance.go b/pkg/mock/conn/mock_instance.go index 8b96ea135..9862757df 100644 --- a/pkg/mock/conn/mock_instance.go +++ b/pkg/mock/conn/mock_instance.go @@ -36,6 +36,20 @@ func (m *MockDBInstance) EXPECT() *MockDBInstanceMockRecorder { return m.recorder } +// AvailabilityZone mocks base method. +func (m *MockDBInstance) AvailabilityZone() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AvailabilityZone") + ret0, _ := ret[0].(string) + return ret0 +} + +// AvailabilityZone indicates an expected call of AvailabilityZone. +func (mr *MockDBInstanceMockRecorder) AvailabilityZone() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AvailabilityZone", reflect.TypeOf((*MockDBInstance)(nil).AvailabilityZone)) +} + // Cancel mocks base method. 
func (m *MockDBInstance) Cancel(csm *pgproto3.CancelRequest) error { m.ctrl.T.Helper() diff --git a/pkg/mock/pool/mock_pool.go b/pkg/mock/pool/mock_pool.go index 17aa4e881..a0f613369 100644 --- a/pkg/mock/pool/mock_pool.go +++ b/pkg/mock/pool/mock_pool.go @@ -12,7 +12,6 @@ import ( kr "github.com/pg-sharding/spqr/pkg/models/kr" pool "github.com/pg-sharding/spqr/pkg/pool" shard "github.com/pg-sharding/spqr/pkg/shard" - tsa "github.com/pg-sharding/spqr/pkg/tsa" ) // MockConnectionKepper is a mock of ConnectionKepper interface. @@ -198,7 +197,7 @@ func (m *MockMultiShardPool) EXPECT() *MockMultiShardPoolMockRecorder { } // ConnectionHost mocks base method. -func (m *MockMultiShardPool) ConnectionHost(clid uint, shardKey kr.ShardKey, host string) (shard.Shard, error) { +func (m *MockMultiShardPool) ConnectionHost(clid uint, shardKey kr.ShardKey, host config.Host) (shard.Shard, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ConnectionHost", clid, shardKey, host) ret0, _ := ret[0].(shard.Shard) @@ -330,164 +329,3 @@ func (mr *MockPoolIteratorMockRecorder) ForEachPool(cb interface{}) *gomock.Call mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ForEachPool", reflect.TypeOf((*MockPoolIterator)(nil).ForEachPool), cb) } - -// MockDBPool is a mock of DBPool interface. -type MockDBPool struct { - ctrl *gomock.Controller - recorder *MockDBPoolMockRecorder -} - -// MockDBPoolMockRecorder is the mock recorder for MockDBPool. -type MockDBPoolMockRecorder struct { - mock *MockDBPool -} - -// NewMockDBPool creates a new mock instance. -func NewMockDBPool(ctrl *gomock.Controller) *MockDBPool { - mock := &MockDBPool{ctrl: ctrl} - mock.recorder = &MockDBPoolMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockDBPool) EXPECT() *MockDBPoolMockRecorder { - return m.recorder -} - -// ConnectionHost mocks base method. 
-func (m *MockDBPool) ConnectionHost(clid uint, shardKey kr.ShardKey, host string) (shard.Shard, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ConnectionHost", clid, shardKey, host) - ret0, _ := ret[0].(shard.Shard) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ConnectionHost indicates an expected call of ConnectionHost. -func (mr *MockDBPoolMockRecorder) ConnectionHost(clid, shardKey, host interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionHost", reflect.TypeOf((*MockDBPool)(nil).ConnectionHost), clid, shardKey, host) -} - -// ConnectionWithTSA mocks base method. -func (m *MockDBPool) ConnectionWithTSA(clid uint, shardKey kr.ShardKey, tsa tsa.TSA) (shard.Shard, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ConnectionWithTSA", clid, shardKey, tsa) - ret0, _ := ret[0].(shard.Shard) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ConnectionWithTSA indicates an expected call of ConnectionWithTSA. -func (mr *MockDBPoolMockRecorder) ConnectionWithTSA(clid, shardKey, tsa interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ConnectionWithTSA", reflect.TypeOf((*MockDBPool)(nil).ConnectionWithTSA), clid, shardKey, tsa) -} - -// Discard mocks base method. -func (m *MockDBPool) Discard(sh shard.Shard) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Discard", sh) - ret0, _ := ret[0].(error) - return ret0 -} - -// Discard indicates an expected call of Discard. -func (mr *MockDBPoolMockRecorder) Discard(sh interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Discard", reflect.TypeOf((*MockDBPool)(nil).Discard), sh) -} - -// ForEach mocks base method. 
-func (m *MockDBPool) ForEach(cb func(shard.Shardinfo) error) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ForEach", cb) - ret0, _ := ret[0].(error) - return ret0 -} - -// ForEach indicates an expected call of ForEach. -func (mr *MockDBPoolMockRecorder) ForEach(cb interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ForEach", reflect.TypeOf((*MockDBPool)(nil).ForEach), cb) -} - -// ForEachPool mocks base method. -func (m *MockDBPool) ForEachPool(cb func(pool.Pool) error) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ForEachPool", cb) - ret0, _ := ret[0].(error) - return ret0 -} - -// ForEachPool indicates an expected call of ForEachPool. -func (mr *MockDBPoolMockRecorder) ForEachPool(cb interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ForEachPool", reflect.TypeOf((*MockDBPool)(nil).ForEachPool), cb) -} - -// Put mocks base method. -func (m *MockDBPool) Put(host shard.Shard) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Put", host) - ret0, _ := ret[0].(error) - return ret0 -} - -// Put indicates an expected call of Put. -func (mr *MockDBPoolMockRecorder) Put(host interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Put", reflect.TypeOf((*MockDBPool)(nil).Put), host) -} - -// SetRule mocks base method. -func (m *MockDBPool) SetRule(rule *config.BackendRule) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetRule", rule) -} - -// SetRule indicates an expected call of SetRule. -func (mr *MockDBPoolMockRecorder) SetRule(rule interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetRule", reflect.TypeOf((*MockDBPool)(nil).SetRule), rule) -} - -// SetShuffleHosts mocks base method. 
-func (m *MockDBPool) SetShuffleHosts(arg0 bool) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "SetShuffleHosts", arg0) -} - -// SetShuffleHosts indicates an expected call of SetShuffleHosts. -func (mr *MockDBPoolMockRecorder) SetShuffleHosts(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetShuffleHosts", reflect.TypeOf((*MockDBPool)(nil).SetShuffleHosts), arg0) -} - -// ShardMapping mocks base method. -func (m *MockDBPool) ShardMapping() map[string]*config.Shard { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ShardMapping") - ret0, _ := ret[0].(map[string]*config.Shard) - return ret0 -} - -// ShardMapping indicates an expected call of ShardMapping. -func (mr *MockDBPoolMockRecorder) ShardMapping() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShardMapping", reflect.TypeOf((*MockDBPool)(nil).ShardMapping)) -} - -// View mocks base method. -func (m *MockDBPool) View() pool.Statistics { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "View") - ret0, _ := ret[0].(pool.Statistics) - return ret0 -} - -// View indicates an expected call of View. -func (mr *MockDBPoolMockRecorder) View() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "View", reflect.TypeOf((*MockDBPool)(nil).View)) -} diff --git a/pkg/models/datashards/datashard.go b/pkg/models/datashards/datashard.go index 07f57a9a0..b6ed00f3e 100644 --- a/pkg/models/datashards/datashard.go +++ b/pkg/models/datashards/datashard.go @@ -36,7 +36,7 @@ func NewDataShard(name string, cfg *config.Shard) *DataShard { // - *proto.Shard: The converted proto.Shard object. func DataShardToProto(shard *DataShard) *proto.Shard { return &proto.Shard{ - Hosts: shard.Cfg.Hosts, + Hosts: shard.Cfg.Hosts(), Id: shard.ID, } } @@ -52,8 +52,8 @@ func DataShardToProto(shard *DataShard) *proto.Shard { // - *DataShard: The created DataShard instance. 
func DataShardFromProto(shard *proto.Shard) *DataShard { return NewDataShard(shard.Id, &config.Shard{ - Hosts: shard.Hosts, - Type: config.DataShard, + RawHosts: shard.Hosts, + Type: config.DataShard, }) } @@ -68,7 +68,7 @@ func DataShardFromProto(shard *proto.Shard) *DataShard { // - *DataShard: The created DataShard instance. func DataShardFromDb(shard *qdb.Shard) *DataShard { return NewDataShard(shard.ID, &config.Shard{ - Hosts: shard.Hosts, - Type: config.DataShard, + RawHosts: shard.RawHosts, + Type: config.DataShard, }) } diff --git a/pkg/pool/dbpool.go b/pkg/pool/dbpool.go index 0e98eaf27..c7d5c3e6e 100644 --- a/pkg/pool/dbpool.go +++ b/pkg/pool/dbpool.go @@ -4,6 +4,7 @@ import ( "fmt" "math/rand" "net" + "sort" "strings" "sync" "time" @@ -22,38 +23,31 @@ import ( type TsaKey struct { Tsa tsa.TSA Host string + AZ string } -type InstancePoolImpl struct { +type DBPool struct { Pool - pool MultiShardPool - shardMapping map[string]*config.Shard - - shuffleHosts bool - + pool MultiShardPool + shardMapping map[string]*config.Shard cacheTSAchecks sync.Map + checker tsa.TSAChecker - checker tsa.TSAChecker + ShuffleHosts bool + PreferAZ string } // ConnectionHost implements DBPool. -func (s *InstancePoolImpl) ConnectionHost(clid uint, shardKey kr.ShardKey, host string) (shard.Shard, error) { +func (s *DBPool) ConnectionHost(clid uint, shardKey kr.ShardKey, host config.Host) (shard.Shard, error) { return s.pool.ConnectionHost(clid, shardKey, host) } // View implements DBPool. // Subtle: this method shadows the method (Pool).View of InstancePoolImpl.Pool. -func (s *InstancePoolImpl) View() Statistics { +func (s *DBPool) View() Statistics { panic("unimplemented") } -// SetShuffleHosts implements DBPool. -func (s *InstancePoolImpl) SetShuffleHosts(val bool) { - s.shuffleHosts = val -} - -var _ DBPool = &InstancePoolImpl{} - // traverseHostsMatchCB traverses the list of hosts and invokes the provided callback function // for each host until the callback returns true. 
It returns the shard that satisfies the callback // condition. If no shard satisfies the condition, it returns nil. @@ -68,22 +62,21 @@ var _ DBPool = &InstancePoolImpl{} // - shard.Shard: The shard that satisfies the callback condition, or nil if no shard satisfies the condition. // // TODO : unit tests -func (s *InstancePoolImpl) traverseHostsMatchCB( - clid uint, - key kr.ShardKey, hosts []string, cb func(shard.Shard) bool, tsa tsa.TSA) shard.Shard { - +func (s *DBPool) traverseHostsMatchCB(clid uint, key kr.ShardKey, hosts []config.Host, cb func(shard.Shard) bool, tsa tsa.TSA) shard.Shard { for _, host := range hosts { sh, err := s.pool.ConnectionHost(clid, key, host) if err != nil { s.cacheTSAchecks.Store(TsaKey{ Tsa: tsa, - Host: host, + Host: host.Address, + AZ: host.AZ, }, false) spqrlog.Zero.Error(). Err(err). - Str("host", host). + Str("host", host.Address). + Str("az", host.AZ). Uint("client", clid). Msg("failed to get connection to host for client") continue @@ -100,7 +93,7 @@ func (s *InstancePoolImpl) traverseHostsMatchCB( return nil } -// SelectReadOnlyShardHost selects a read-only shard host from the given list of hosts based on the provided client ID and shard key. +// selectReadOnlyShardHost selects a read-only shard host from the given list of hosts based on the provided client ID and shard key. // It traverses the hosts and performs checks to ensure the selected shard host is suitable for read-only operations. // If a suitable shard host is found, it is returned along with a nil error. // If no suitable shard host is found, an error is returned with a message indicating the reason for failure. @@ -115,9 +108,7 @@ func (s *InstancePoolImpl) traverseHostsMatchCB( // - error: An error if no suitable shard host is found. 
// // TODO : unit tests -func (s *InstancePoolImpl) SelectReadOnlyShardHost( - clid uint, - key kr.ShardKey, hosts []string, targetSessionAttrs tsa.TSA) (shard.Shard, error) { +func (s *DBPool) selectReadOnlyShardHost(clid uint, key kr.ShardKey, hosts []config.Host, tsa tsa.TSA) (shard.Shard, error) { totalMsg := make([]string, 0) sh := s.traverseHostsMatchCB(clid, key, hosts, func(shard shard.Shard) bool { if ch, reason, err := s.checker.CheckTSA(shard); err != nil { @@ -125,15 +116,17 @@ func (s *InstancePoolImpl) SelectReadOnlyShardHost( _ = s.pool.Discard(shard) s.cacheTSAchecks.Store(TsaKey{ - Tsa: targetSessionAttrs, + Tsa: tsa, Host: shard.Instance().Hostname(), + AZ: shard.Instance().AvailabilityZone(), }, false) return false } else { s.cacheTSAchecks.Store(TsaKey{ - Tsa: targetSessionAttrs, + Tsa: tsa, Host: shard.Instance().Hostname(), + AZ: shard.Instance().AvailabilityZone(), }, !ch) if ch { @@ -144,7 +137,7 @@ func (s *InstancePoolImpl) SelectReadOnlyShardHost( return true } - }, targetSessionAttrs) + }, tsa) if sh != nil { return sh, nil } @@ -152,7 +145,7 @@ func (s *InstancePoolImpl) SelectReadOnlyShardHost( return nil, fmt.Errorf("shard %s failed to find replica within %s", key.Name, strings.Join(totalMsg, ";")) } -// SelectReadWriteShardHost selects a read-write shard host from the given list of hosts based on the provided client ID and shard key. +// selectReadWriteShardHost selects a read-write shard host from the given list of hosts based on the provided client ID and shard key. // It traverses the hosts and checks if each shard is available and suitable for read-write operations. // If a suitable shard is found, it is returned along with no error. // If no suitable shard is found, an error is returned indicating the failure reason. @@ -167,9 +160,7 @@ func (s *InstancePoolImpl) SelectReadOnlyShardHost( // - error: An error if no suitable shard host is found. 
// // TODO : unit tests -func (s *InstancePoolImpl) SelectReadWriteShardHost( - clid uint, - key kr.ShardKey, hosts []string, targetSessionAttrs tsa.TSA) (shard.Shard, error) { +func (s *DBPool) selectReadWriteShardHost(clid uint, key kr.ShardKey, hosts []config.Host, tsa tsa.TSA) (shard.Shard, error) { totalMsg := make([]string, 0) sh := s.traverseHostsMatchCB(clid, key, hosts, func(shard shard.Shard) bool { if ch, reason, err := s.checker.CheckTSA(shard); err != nil { @@ -177,15 +168,17 @@ func (s *InstancePoolImpl) SelectReadWriteShardHost( _ = s.pool.Discard(shard) s.cacheTSAchecks.Store(TsaKey{ - Tsa: targetSessionAttrs, + Tsa: tsa, Host: shard.Instance().Hostname(), + AZ: shard.Instance().AvailabilityZone(), }, false) return false } else { s.cacheTSAchecks.Store(TsaKey{ - Tsa: targetSessionAttrs, + Tsa: tsa, Host: shard.Instance().Hostname(), + AZ: shard.Instance().AvailabilityZone(), }, ch) if !ch { @@ -196,7 +189,7 @@ func (s *InstancePoolImpl) SelectReadWriteShardHost( return true } - }, targetSessionAttrs) + }, tsa) if sh != nil { return sh, nil } @@ -218,53 +211,18 @@ func (s *InstancePoolImpl) SelectReadWriteShardHost( // - error: An error if the connection cannot be established. // // TODO : unit tests -func (s *InstancePoolImpl) ConnectionWithTSA( - clid uint, - key kr.ShardKey, - targetSessionAttrs tsa.TSA) (shard.Shard, error) { +func (s *DBPool) ConnectionWithTSA(clid uint, key kr.ShardKey, targetSessionAttrs tsa.TSA) (shard.Shard, error) { spqrlog.Zero.Debug(). Uint("client", clid). Str("shard", key.Name). Str("tsa", string(targetSessionAttrs)). 
Msg("acquiring new instance connection for client to shard with target session attrs") - var hostOrder []string - var posCache []string - var negCache []string - - if _, ok := s.shardMapping[key.Name]; !ok { - return nil, fmt.Errorf("shard with name %q not found", key.Name) + hostOrder, err := s.BuildHostOrder(key, targetSessionAttrs) + if err != nil { + return nil, err } - for _, host := range s.shardMapping[key.Name].Hosts { - tsaKey := TsaKey{ - Tsa: targetSessionAttrs, - Host: host, - } - - if res, ok := s.cacheTSAchecks.Load(tsaKey); ok { - if res.(bool) { - posCache = append(posCache, host) - } else { - negCache = append(negCache, host) - } - } else { - // assume ok - posCache = append(posCache, host) - } - } - - if s.shuffleHosts { - rand.Shuffle(len(posCache), func(i, j int) { - posCache[i], posCache[j] = posCache[j], posCache[i] - }) - rand.Shuffle(len(negCache), func(i, j int) { - negCache[i], negCache[j] = negCache[j], negCache[i] - }) - } - - hostOrder = append(posCache, negCache...) - /* pool.Connection will reoder hosts in such way, that preferred tsa will go first */ switch targetSessionAttrs { case "": @@ -278,45 +236,98 @@ func (s *InstancePoolImpl) ConnectionWithTSA( s.cacheTSAchecks.Store(TsaKey{ Tsa: config.TargetSessionAttrsAny, - Host: host, + Host: host.Address, + AZ: host.AZ, }, false) spqrlog.Zero.Error(). Err(err). - Str("host", host). + Str("host", host.Address). + Str("availability-zone", host.AZ). Uint("client", clid). 
Msg("failed to get connection to host for client") continue } s.cacheTSAchecks.Store(TsaKey{ Tsa: config.TargetSessionAttrsAny, - Host: host, + Host: host.Address, + AZ: host.AZ, }, true) return shard, nil } return nil, fmt.Errorf("failed to get connection to any shard host within %s", total_msg) case config.TargetSessionAttrsRO: - return s.SelectReadOnlyShardHost(clid, key, hostOrder, targetSessionAttrs) + return s.selectReadOnlyShardHost(clid, key, hostOrder, targetSessionAttrs) case config.TargetSessionAttrsPS: - if res, err := s.SelectReadOnlyShardHost(clid, key, hostOrder, targetSessionAttrs); err != nil { - return s.SelectReadWriteShardHost(clid, key, hostOrder, targetSessionAttrs) + if res, err := s.selectReadOnlyShardHost(clid, key, hostOrder, targetSessionAttrs); err != nil { + return s.selectReadWriteShardHost(clid, key, hostOrder, targetSessionAttrs) } else { return res, nil } case config.TargetSessionAttrsRW: - return s.SelectReadWriteShardHost(clid, key, hostOrder, targetSessionAttrs) + return s.selectReadWriteShardHost(clid, key, hostOrder, targetSessionAttrs) default: return nil, fmt.Errorf("failed to match correct target session attrs") } } +func (s *DBPool) BuildHostOrder(key kr.ShardKey, targetSessionAttrs tsa.TSA) ([]config.Host, error) { + var hostOrder []config.Host + var posCache []config.Host + var negCache []config.Host + + if _, ok := s.shardMapping[key.Name]; !ok { + return nil, fmt.Errorf("shard with name %q not found", key.Name) + } + + for _, host := range s.shardMapping[key.Name].HostsAZ() { + tsaKey := TsaKey{ + Tsa: targetSessionAttrs, + Host: host.Address, + AZ: host.AZ, + } + + if res, ok := s.cacheTSAchecks.Load(tsaKey); ok { + if res.(bool) { + posCache = append(posCache, host) + } else { + negCache = append(negCache, host) + } + } else { + + posCache = append(posCache, host) + } + } + + if s.ShuffleHosts { + rand.Shuffle(len(posCache), func(i, j int) { + posCache[i], posCache[j] = posCache[j], posCache[i] + }) + 
rand.Shuffle(len(negCache), func(i, j int) {
+			negCache[i], negCache[j] = negCache[j], negCache[i]
+		})
+	}
+
+	if len(s.PreferAZ) > 0 {
+		sort.SliceStable(posCache, func(i, j int) bool {
+			return posCache[i].AZ == s.PreferAZ && posCache[j].AZ != s.PreferAZ
+		})
+		sort.SliceStable(negCache, func(i, j int) bool {
+			return negCache[i].AZ == s.PreferAZ && negCache[j].AZ != s.PreferAZ
+		})
+	}
+
+	hostOrder = append(posCache, negCache...)
+	return hostOrder, nil
+}
+
 // SetRule initializes the backend rule in the instance pool.
 // It takes a pointer to a BackendRule as a parameter and saves it
 //
 // Parameters:
 // - rule: A pointer to a BackendRule representing the backend rule to be initialized.
-func (s *InstancePoolImpl) SetRule(rule *config.BackendRule) {
+func (s *DBPool) SetRule(rule *config.BackendRule) {
 	s.pool.SetRule(rule)
 }
 
@@ -324,7 +335,7 @@ func (s *InstancePoolImpl) SetRule(rule *config.BackendRule) {
 //
 // Returns:
 // - map[string]*config.Shard: The shard mapping of the instance pool.
-func (s *InstancePoolImpl) ShardMapping() map[string]*config.Shard {
+func (s *DBPool) ShardMapping() map[string]*config.Shard {
 	return s.shardMapping
 }
 
@@ -336,7 +347,7 @@ func (s *InstancePoolImpl) ShardMapping() map[string]*config.Shard {
 //
 // Returns:
 // - error: An error if the callback function returns an error.
-func (s *InstancePoolImpl) ForEach(cb func(sh shard.Shardinfo) error) error {
+func (s *DBPool) ForEach(cb func(sh shard.Shardinfo) error) error {
 	return s.pool.ForEach(cb)
 }
 
@@ -351,7 +362,7 @@ func (s *InstancePoolImpl) ForEach(cb func(sh shard.Shardinfo) error) error {
 //
 // Returns:
 // - error: An error if the shard is discarded or if there is an error putting the shard into the pool.
 //
 // TODO : unit tests
-func (s *InstancePoolImpl) Put(sh shard.Shard) error {
+func (s *DBPool) Put(sh shard.Shard) error {
 	if sh.Sync() != 0 {
 		spqrlog.Zero.Error().
 			Uint("shard", spqrlog.GetPointer(sh)).
@@ -377,7 +388,7 @@ func (s *InstancePoolImpl) Put(sh shard.Shard) error {
 //
 // Returns:
 // - error: An error if the callback function returns an error.
-func (s *InstancePoolImpl) ForEachPool(cb func(pool Pool) error) error { +func (s *DBPool) ForEachPool(cb func(pool Pool) error) error { return s.pool.ForEachPool(cb) } @@ -389,7 +400,7 @@ func (s *InstancePoolImpl) ForEachPool(cb func(pool Pool) error) error { // // Returns: // - error: An error if the removal fails, nil otherwise. -func (s *InstancePoolImpl) Discard(sh shard.Shard) error { +func (s *DBPool) Discard(sh shard.Shard) error { return s.pool.Discard(sh) } @@ -405,10 +416,10 @@ func (s *InstancePoolImpl) Discard(sh shard.Shard) error { // // Returns: // - DBPool: A DBPool interface that represents the created pool. -func NewDBPool(mapping map[string]*config.Shard, sp *startup.StartupParams) DBPool { - allocator := func(shardKey kr.ShardKey, host string, rule *config.BackendRule) (shard.Shard, error) { +func NewDBPool(mapping map[string]*config.Shard, startupParams *startup.StartupParams, preferAZ string) *DBPool { + allocator := func(shardKey kr.ShardKey, host config.Host, rule *config.BackendRule) (shard.Shard, error) { shardConfig := mapping[shardKey.Name] - hostname, _, _ := net.SplitHostPort(host) // TODO try to remove this + hostname, _, _ := net.SplitHostPort(host.Address) // TODO try to remove this tlsconfig, err := shardConfig.TLS.Init(hostname) if err != nil { return nil, err @@ -418,28 +429,28 @@ func NewDBPool(mapping map[string]*config.Shard, sp *startup.StartupParams) DBPo keepAlive := config.ValueOrDefaultDuration(rule.KeepAlive, defaultKeepAlive) tcpUserTimeout := config.ValueOrDefaultDuration(rule.TcpUserTimeout, defaultTcpUserTimeout) - pgi, err := conn.NewInstanceConn(host, shardKey.Name, tlsconfig, connTimeout, keepAlive, tcpUserTimeout) + pgi, err := conn.NewInstanceConn(host.Address, host.AZ, shardKey.Name, tlsconfig, connTimeout, keepAlive, tcpUserTimeout) if err != nil { return nil, err } - return datashard.NewShard(shardKey, pgi, mapping[shardKey.Name], rule, sp) + return datashard.NewShard(shardKey, pgi, mapping[shardKey.Name], 
rule, startupParams) } - return &InstancePoolImpl{ + return &DBPool{ pool: NewPool(allocator), shardMapping: mapping, - shuffleHosts: true, + ShuffleHosts: true, + PreferAZ: preferAZ, cacheTSAchecks: sync.Map{}, checker: tsa.NewTSAChecker(), } } -func NewDBPoolFromMultiPool(mapping map[string]*config.Shard, sp *startup.StartupParams, mp MultiShardPool, shuffleHosts bool, tsaRecheckDuration time.Duration) DBPool { - return &InstancePoolImpl{ +func NewDBPoolFromMultiPool(mapping map[string]*config.Shard, sp *startup.StartupParams, mp MultiShardPool, tsaRecheckDuration time.Duration) *DBPool { + return &DBPool{ pool: mp, shardMapping: mapping, - shuffleHosts: shuffleHosts, cacheTSAchecks: sync.Map{}, checker: tsa.NewTSACheckerWithDuration(tsaRecheckDuration), } diff --git a/pkg/pool/dbpool_test.go b/pkg/pool/dbpool_test.go index ece5ce521..89df68009 100644 --- a/pkg/pool/dbpool_test.go +++ b/pkg/pool/dbpool_test.go @@ -13,6 +13,7 @@ import ( "github.com/pg-sharding/spqr/pkg/models/kr" "github.com/pg-sharding/spqr/pkg/pool" "github.com/pg-sharding/spqr/pkg/startup" + "github.com/pg-sharding/spqr/pkg/tsa" "github.com/pg-sharding/spqr/pkg/txstatus" "github.com/stretchr/testify/assert" ) @@ -32,22 +33,25 @@ func TestDbPoolOrderCaching(t *testing.T) { dbpool := pool.NewDBPoolFromMultiPool(map[string]*config.Shard{ key.Name: { - Hosts: []string{ + RawHosts: []string{ "h1", "h2", "h3", }, }, - }, &startup.StartupParams{}, underyling_pool, false, time.Hour) + }, &startup.StartupParams{}, underyling_pool, time.Hour) ins1 := mockinst.NewMockDBInstance(ctrl) ins1.EXPECT().Hostname().AnyTimes().Return("h1") + ins1.EXPECT().AvailabilityZone().AnyTimes().Return("") ins2 := mockinst.NewMockDBInstance(ctrl) ins2.EXPECT().Hostname().AnyTimes().Return("h2") + ins2.EXPECT().AvailabilityZone().AnyTimes().Return("") ins3 := mockinst.NewMockDBInstance(ctrl) ins3.EXPECT().Hostname().AnyTimes().Return("h3") + ins3.EXPECT().AvailabilityZone().AnyTimes().Return("") h1 := 
mockshard.NewMockShard(ctrl) h1.EXPECT().Instance().AnyTimes().Return(ins1) @@ -68,9 +72,9 @@ func TestDbPoolOrderCaching(t *testing.T) { h1, h2, h3, } - underyling_pool.EXPECT().ConnectionHost(clId, key, "h1").Times(1).Return(h1, nil) - underyling_pool.EXPECT().ConnectionHost(clId, key, "h2").Times(1).Return(h2, nil) - underyling_pool.EXPECT().ConnectionHost(clId, key, "h3").Times(1).Return(h3, nil) + underyling_pool.EXPECT().ConnectionHost(clId, key, config.Host{Address: "h1"}).Times(1).Return(h1, nil) + underyling_pool.EXPECT().ConnectionHost(clId, key, config.Host{Address: "h2"}).Times(1).Return(h2, nil) + underyling_pool.EXPECT().ConnectionHost(clId, key, config.Host{Address: "h3"}).Times(1).Return(h3, nil) for ind, h := range hs { @@ -110,16 +114,18 @@ func TestDbPoolOrderCaching(t *testing.T) { sh, err := dbpool.ConnectionWithTSA(clId, key, config.TargetSessionAttrsRW) - assert.Equal(sh, h3) + assert.Equal(sh.Instance().Hostname(), h3.Instance().Hostname()) + assert.Equal(sh.Instance().AvailabilityZone(), h3.Instance().AvailabilityZone()) assert.NoError(err) /* next time expect only one call */ - underyling_pool.EXPECT().ConnectionHost(clId, key, "h3").Times(1).Return(h3, nil) + underyling_pool.EXPECT().ConnectionHost(clId, key, config.Host{Address: "h3"}).Times(1).Return(h3, nil) sh, err = dbpool.ConnectionWithTSA(clId, key, config.TargetSessionAttrsRW) - assert.Equal(sh, h3) + assert.Equal(sh.Instance().Hostname(), h3.Instance().Hostname()) + assert.Equal(sh.Instance().AvailabilityZone(), h3.Instance().AvailabilityZone()) assert.NoError(err) } @@ -139,22 +145,25 @@ func TestDbPoolReadOnlyOrderDistribution(t *testing.T) { dbpool := pool.NewDBPoolFromMultiPool(map[string]*config.Shard{ key.Name: { - Hosts: []string{ + RawHosts: []string{ "h1", "h2", "h3", }, }, - }, &startup.StartupParams{}, underyling_pool, false, time.Hour) + }, &startup.StartupParams{}, underyling_pool, time.Hour) ins1 := mockinst.NewMockDBInstance(ctrl) 
ins1.EXPECT().Hostname().AnyTimes().Return("h1") + ins1.EXPECT().AvailabilityZone().AnyTimes().Return("") ins2 := mockinst.NewMockDBInstance(ctrl) ins2.EXPECT().Hostname().AnyTimes().Return("h2") + ins2.EXPECT().AvailabilityZone().AnyTimes().Return("") ins3 := mockinst.NewMockDBInstance(ctrl) ins3.EXPECT().Hostname().AnyTimes().Return("h3") + ins3.EXPECT().AvailabilityZone().AnyTimes().Return("") h1 := mockshard.NewMockShard(ctrl) h1.EXPECT().Instance().AnyTimes().Return(ins1) @@ -175,9 +184,9 @@ func TestDbPoolReadOnlyOrderDistribution(t *testing.T) { h1, h2, h3, } - underyling_pool.EXPECT().ConnectionHost(clId, key, "h1").AnyTimes().Return(h1, nil) - underyling_pool.EXPECT().ConnectionHost(clId, key, "h2").AnyTimes().Return(h2, nil) - underyling_pool.EXPECT().ConnectionHost(clId, key, "h3").Times(1).Return(h3, nil) + underyling_pool.EXPECT().ConnectionHost(clId, key, config.Host{Address: "h1"}).AnyTimes().Return(h1, nil) + underyling_pool.EXPECT().ConnectionHost(clId, key, config.Host{Address: "h2"}).AnyTimes().Return(h2, nil) + underyling_pool.EXPECT().ConnectionHost(clId, key, config.Host{Address: "h3"}).Times(1).Return(h3, nil) for ind, h := range hs { @@ -217,11 +226,12 @@ func TestDbPoolReadOnlyOrderDistribution(t *testing.T) { sh, err := dbpool.ConnectionWithTSA(clId, key, config.TargetSessionAttrsRW) - assert.Equal(sh, h3) + assert.Equal(sh.Instance().Hostname(), h3.Instance().Hostname()) + assert.Equal(sh.Instance().AvailabilityZone(), h3.Instance().AvailabilityZone()) assert.NoError(err) - underyling_pool.EXPECT().ConnectionHost(clId, key, "h3").MaxTimes(1).Return(h3, nil) + underyling_pool.EXPECT().ConnectionHost(clId, key, config.Host{Address: "h3"}).MaxTimes(1).Return(h3, nil) underyling_pool.EXPECT().Put(h3).Return(nil).MaxTimes(1) @@ -234,7 +244,7 @@ func TestDbPoolReadOnlyOrderDistribution(t *testing.T) { cnth1 := 0 cnth2 := 0 - dbpool.SetShuffleHosts(true) + dbpool.ShuffleHosts = true for i := 0; i < repeattimes; i++ { sh, err = 
dbpool.ConnectionWithTSA(clId, key, config.TargetSessionAttrsRO)
@@ -258,3 +268,102 @@ func TestDbPoolReadOnlyOrderDistribution(t *testing.T) {
 	assert.Less(diff, 90)
 	assert.Equal(repeattimes, cnth1+cnth2)
 }
+
+func TestBuildHostOrder(t *testing.T) {
+	ctrl := gomock.NewController(t)
+
+	underyling_pool := mockpool.NewMockMultiShardPool(ctrl)
+
+	key := kr.ShardKey{
+		Name: "sh1",
+	}
+
+	dbpool := pool.NewDBPoolFromMultiPool(map[string]*config.Shard{
+		key.Name: {
+			RawHosts: []string{
+				"sas-123.db.yandex.net:6432:sas",
+				"sas-234.db.yandex.net:6432:sas",
+				"vla-123.db.yandex.net:6432:vla",
+				"vla-234.db.yandex.net:6432:vla",
+				"klg-123.db.yandex.net:6432:klg",
+				"klg-234.db.yandex.net:6432:klg",
+			},
+		},
+	}, &startup.StartupParams{}, underyling_pool, time.Hour)
+
+	tests := []struct {
+		name               string
+		shardKey           kr.ShardKey
+		targetSessionAttrs tsa.TSA
+		shuffleHosts       bool
+		preferAZ           string
+		expectedHosts      []string
+	}{
+		{
+			name:               "No shuffle, no preferred AZ",
+			shardKey:           kr.ShardKey{Name: "sh1"},
+			targetSessionAttrs: config.TargetSessionAttrsAny,
+			shuffleHosts:       false,
+			preferAZ:           "",
+			expectedHosts: []string{
+				"sas-123.db.yandex.net:6432",
+				"sas-234.db.yandex.net:6432",
+				"vla-123.db.yandex.net:6432",
+				"vla-234.db.yandex.net:6432",
+				"klg-123.db.yandex.net:6432",
+				"klg-234.db.yandex.net:6432",
+			},
+		},
+		{
+			name:               "Shuffle hosts",
+			shardKey:           kr.ShardKey{Name: "sh1"},
+			targetSessionAttrs: config.TargetSessionAttrsAny,
+			shuffleHosts:       true,
+			preferAZ:           "",
+			expectedHosts: []string{
+				"sas-123.db.yandex.net:6432",
+				"sas-234.db.yandex.net:6432",
+				"vla-123.db.yandex.net:6432",
+				"vla-234.db.yandex.net:6432",
+				"klg-123.db.yandex.net:6432",
+				"klg-234.db.yandex.net:6432",
+			},
+		},
+		{
+			name:               "Preferred AZ",
+			shardKey:           kr.ShardKey{Name: "sh1"},
+			targetSessionAttrs: config.TargetSessionAttrsAny,
+			shuffleHosts:       false,
+			preferAZ:           "klg",
+			expectedHosts: []string{
+				"klg-123.db.yandex.net:6432",
+				"klg-234.db.yandex.net:6432",
"sas-123.db.yandex.net:6432", + "sas-234.db.yandex.net:6432", + "vla-123.db.yandex.net:6432", + "vla-234.db.yandex.net:6432", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dbpool.ShuffleHosts = tt.shuffleHosts + dbpool.PreferAZ = tt.preferAZ + + hostOrder, err := dbpool.BuildHostOrder(tt.shardKey, tt.targetSessionAttrs) + assert.NoError(t, err) + + var hostAddresses []string + for _, host := range hostOrder { + hostAddresses = append(hostAddresses, host.Address) + } + + if tt.shuffleHosts { + assert.ElementsMatch(t, tt.expectedHosts, hostAddresses) + } else { + assert.Equal(t, tt.expectedHosts, hostAddresses) + } + }) + } +} diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go index 2596fed82..79a3c6f86 100644 --- a/pkg/pool/pool.go +++ b/pkg/pool/pool.go @@ -6,7 +6,6 @@ import ( "github.com/pg-sharding/spqr/pkg/config" "github.com/pg-sharding/spqr/pkg/models/kr" "github.com/pg-sharding/spqr/pkg/shard" - "github.com/pg-sharding/spqr/pkg/tsa" ) const ( @@ -47,7 +46,7 @@ type MultiShardPool interface { shard.ShardIterator PoolIterator - ConnectionHost(clid uint, shardKey kr.ShardKey, host string) (shard.Shard, error) + ConnectionHost(clid uint, shardKey kr.ShardKey, host config.Host) (shard.Shard, error) SetRule(rule *config.BackendRule) } @@ -56,14 +55,4 @@ type PoolIterator interface { ForEachPool(cb func(p Pool) error) error } -type ConnectionAllocFn func(shardKey kr.ShardKey, host string, rule *config.BackendRule) (shard.Shard, error) - -type DBPool interface { - MultiShardPool - - ShardMapping() map[string]*config.Shard - - ConnectionWithTSA(clid uint, shardKey kr.ShardKey, tsa tsa.TSA) (shard.Shard, error) - - SetShuffleHosts(bool) -} +type ConnectionAllocFn func(shardKey kr.ShardKey, host config.Host, rule *config.BackendRule) (shard.Shard, error) diff --git a/pkg/pool/shardpool.go b/pkg/pool/shardpool.go index b0f051c9c..9e3594e47 100644 --- a/pkg/pool/shardpool.go +++ b/pkg/pool/shardpool.go @@ -29,11 +29,12 @@ type 
shardPool struct { /* dedicated */ host string + az string - ConnectionLimit int - ConnectionRetries int - ConnectionRetrySleepSlice int - ConnectionRetryRandomSleep int + connectionLimit int + connectionRetries int + connectionRetrySleepSlice int + connectionRetryRandomSleep int } var _ Pool = &shardPool{} @@ -49,21 +50,23 @@ var _ Pool = &shardPool{} // // Returns: // - Pool: The created instance of shardPool. -func NewShardPool(allocFn ConnectionAllocFn, host string, beRule *config.BackendRule) Pool { +func NewShardPool(allocFn ConnectionAllocFn, host config.Host, beRule *config.BackendRule) Pool { connLimit := config.ValueOrDefaultInt(beRule.ConnectionLimit, defaultInstanceConnectionLimit) connRetries := config.ValueOrDefaultInt(beRule.ConnectionRetries, defaultInstanceConnectionRetries) ret := &shardPool{ - mu: sync.Mutex{}, - pool: nil, - active: make(map[uint]shard.Shard), - alloc: allocFn, - beRule: beRule, - host: host, - ConnectionLimit: connLimit, - ConnectionRetries: connRetries, - ConnectionRetrySleepSlice: 50, - ConnectionRetryRandomSleep: 10, + mu: sync.Mutex{}, + pool: nil, + active: make(map[uint]shard.Shard), + alloc: allocFn, + beRule: beRule, + host: host.Address, + az: host.AZ, + + connectionLimit: connLimit, + connectionRetries: connRetries, + connectionRetrySleepSlice: 50, + connectionRetryRandomSleep: 10, } ret.queue = make(chan struct{}, connLimit) @@ -114,10 +117,10 @@ func (h *shardPool) View() Statistics { // TODO : unit tests func (h *shardPool) Connection(clid uint, shardKey kr.ShardKey) (shard.Shard, error) { if err := func() error { - for rep := 0; rep < h.ConnectionRetries; rep++ { + for rep := 0; rep < h.connectionRetries; rep++ { select { // TODO: configure waits using backend rule - case <-time.After(time.Duration(h.ConnectionRetrySleepSlice) * time.Millisecond * time.Duration(1+rand.Int31()%int32(h.ConnectionRetryRandomSleep))): + case <-time.After(time.Duration(h.connectionRetrySleepSlice) * time.Millisecond * 
time.Duration(1+rand.Int31()%int32(h.connectionRetryRandomSleep))): spqrlog.Zero.Info(). Uint("client", clid). Str("host", h.host). @@ -156,7 +159,7 @@ func (h *shardPool) Connection(clid uint, shardKey kr.ShardKey) (shard.Shard, er // do not hold lock on poolRW while allocate new connection var err error - sh, err = h.alloc(shardKey, h.host, h.beRule) + sh, err = h.alloc(shardKey, config.Host{Address: h.host, AZ: h.az}, h.beRule) if err != nil { // return acquired token h.queue <- struct{}{} @@ -356,11 +359,11 @@ func (c *cPool) ForEachPool(cb func(p Pool) error) error { // - error: The error that occurred during the connection process. // // TODO : unit tests -func (c *cPool) ConnectionHost(clid uint, shardKey kr.ShardKey, host string) (shard.Shard, error) { +func (c *cPool) ConnectionHost(clid uint, shardKey kr.ShardKey, host config.Host) (shard.Shard, error) { var pool Pool - if val, ok := c.pools.Load(host); !ok { + if val, ok := c.pools.Load(host.Address); !ok { pool = NewShardPool(c.alloc, host, c.beRule) - c.pools.Store(host, pool) + c.pools.Store(host.Address, pool) } else { pool = val.(Pool) } @@ -378,12 +381,11 @@ func (c *cPool) ConnectionHost(clid uint, shardKey kr.ShardKey, host string) (sh // - error: The error that occurred during the put operation. 
// // TODO : unit tests -func (c *cPool) Put(host shard.Shard) error { - if val, ok := c.pools.Load(host.Instance().Hostname()); ok { - return val.(Pool).Put(host) +func (c *cPool) Put(sh shard.Shard) error { + if val, ok := c.pools.Load(sh.Instance().Hostname()); ok { + return val.(Pool).Put(sh) } else { - /* very bad */ - panic(host) + panic(fmt.Sprintf("cPool.Put failed, hostname %s not found", sh.Instance().Hostname())) } } @@ -402,8 +404,7 @@ func (c *cPool) Discard(sh shard.Shard) error { if val, ok := c.pools.Load(sh.Instance().Hostname()); ok { return val.(Pool).Discard(sh) } else { - /* very bad */ - panic(sh) + panic(fmt.Sprintf("cPool.Discard failed, hostname %s not found", sh.Instance().Hostname())) } } diff --git a/pkg/pool/shardpool_test.go b/pkg/pool/shardpool_test.go index cebc3247d..09832dbc9 100644 --- a/pkg/pool/shardpool_test.go +++ b/pkg/pool/shardpool_test.go @@ -35,9 +35,9 @@ func TestShardPoolConnectionAcquirePut(t *testing.T) { shardconn.EXPECT().ID().AnyTimes().Return(uint(1234)) shardconn.EXPECT().TxStatus().AnyTimes().Return(txstatus.TXIDLE) - shp := pool.NewShardPool(func(shardKey kr.ShardKey, host string, rule *config.BackendRule) (shard.Shard, error) { + shp := pool.NewShardPool(func(shardKey kr.ShardKey, host config.Host, rule *config.BackendRule) (shard.Shard, error) { return shardconn, nil - }, "h1", &config.BackendRule{ + }, config.Host{Address: "h1"}, &config.BackendRule{ ConnectionLimit: 1, }) @@ -82,9 +82,9 @@ func TestShardPoolConnectionAcquireDiscard(t *testing.T) { shardconn.EXPECT().Close().Times(1) - shp := pool.NewShardPool(func(shardKey kr.ShardKey, host string, rule *config.BackendRule) (shard.Shard, error) { + shp := pool.NewShardPool(func(shardKey kr.ShardKey, host config.Host, rule *config.BackendRule) (shard.Shard, error) { return shardconn, nil - }, "h1", &config.BackendRule{ + }, config.Host{Address: "h1"}, &config.BackendRule{ ConnectionLimit: 1, }) @@ -121,9 +121,9 @@ func TestShardPoolAllocFnError(t *testing.T) 
{ ins := mockinst.NewMockDBInstance(ctrl) ins.EXPECT().Hostname().AnyTimes().Return("h1") - shp := pool.NewShardPool(func(shardKey kr.ShardKey, host string, rule *config.BackendRule) (shard.Shard, error) { + shp := pool.NewShardPool(func(shardKey kr.ShardKey, host config.Host, rule *config.BackendRule) (shard.Shard, error) { return nil, errors.New("bad") - }, "h1", &config.BackendRule{ + }, config.Host{Address: "h1"}, &config.BackendRule{ ConnectionLimit: 1, }) @@ -177,7 +177,7 @@ func TestShardPoolConnectionAcquireLimit(t *testing.T) { var mu sync.Mutex - shp := pool.NewShardPool(func(shardKey kr.ShardKey, host string, rule *config.BackendRule) (shard.Shard, error) { + shp := pool.NewShardPool(func(shardKey kr.ShardKey, host config.Host, rule *config.BackendRule) (shard.Shard, error) { mu.Lock() defer mu.Unlock() @@ -191,7 +191,7 @@ func TestShardPoolConnectionAcquireLimit(t *testing.T) { assert.Fail("connection pool overflow") return nil, errors.New("bad") - }, "h1", &config.BackendRule{ + }, config.Host{Address: "h1"}, &config.BackendRule{ ConnectionLimit: connLimit, ConnectionRetries: 1, }) diff --git a/pkg/tsa/tsa.go b/pkg/tsa/tsa.go index e6c5abbf7..e106d6145 100644 --- a/pkg/tsa/tsa.go +++ b/pkg/tsa/tsa.go @@ -11,6 +11,7 @@ import ( "github.com/pg-sharding/spqr/pkg/txstatus" ) +// TSA is stands for target_session_attrs, type TSA string type TSAChecker interface { diff --git a/qdb/etcdqdb.go b/qdb/etcdqdb.go index 99b313437..902dd85b4 100644 --- a/qdb/etcdqdb.go +++ b/qdb/etcdqdb.go @@ -816,7 +816,7 @@ func (q *EtcdQDB) ListRouters(ctx context.Context) ([]*Router, error) { func (q *EtcdQDB) AddShard(ctx context.Context, shard *Shard) error { spqrlog.Zero.Debug(). Str("id", shard.ID). - Strs("hosts", shard.Hosts). + Strs("hosts", shard.RawHosts). 
Msg("etcdqdb: add shard") bytes, err := json.Marshal(shard) @@ -882,7 +882,7 @@ func (q *EtcdQDB) GetShard(ctx context.Context, id string) (*Shard, error) { for _, shard := range resp.Kvs { // The Port field is always for a while. - shardInfo.Hosts = append(shardInfo.Hosts, string(shard.Value)) + shardInfo.RawHosts = append(shardInfo.RawHosts, string(shard.Value)) } return shardInfo, nil diff --git a/qdb/memqdb_test.go b/qdb/memqdb_test.go index 5ecbf99e5..29a40c39e 100644 --- a/qdb/memqdb_test.go +++ b/qdb/memqdb_test.go @@ -15,8 +15,8 @@ var mockDistribution = &qdb.Distribution{ ID: "123", } var mockShard = &qdb.Shard{ - ID: "shard_id", - Hosts: []string{"host1", "host2"}, + ID: "shard_id", + RawHosts: []string{"host1", "host2"}, } var mockKeyRange = &qdb.KeyRange{ LowerBound: [][]byte{{1, 2}}, diff --git a/qdb/models.go b/qdb/models.go index 588b2b165..0c5133a92 100644 --- a/qdb/models.go +++ b/qdb/models.go @@ -57,14 +57,14 @@ func (r Router) Addr() string { } type Shard struct { - ID string `json:"id"` - Hosts []string `json:"hosts"` + ID string `json:"id"` + RawHosts []string `json:"hosts"` // format host:port:availability_zone } func NewShard(ID string, hosts []string) *Shard { return &Shard{ - ID: ID, - Hosts: hosts, + ID: ID, + RawHosts: hosts, } } diff --git a/router/qrouter/proxy_routing_test.go b/router/qrouter/proxy_routing_test.go index 90c131a53..5741b1d4d 100644 --- a/router/qrouter/proxy_routing_test.go +++ b/router/qrouter/proxy_routing_test.go @@ -48,12 +48,8 @@ func TestMultiShardRouting(t *testing.T) { lc := local.NewLocalCoordinator(db) pr, err := qrouter.NewProxyRouter(map[string]*config.Shard{ - "sh1": { - Hosts: nil, - }, - "sh2": { - Hosts: nil, - }, + "sh1": {}, + "sh2": {}, }, lc, &config.QRouter{}) assert.NoError(err) @@ -177,12 +173,8 @@ func TestComment(t *testing.T) { lc := local.NewLocalCoordinator(db) pr, err := qrouter.NewProxyRouter(map[string]*config.Shard{ - "sh1": { - Hosts: nil, - }, - "sh2": { - Hosts: nil, - }, + "sh1": {}, 
+ "sh2": {}, }, lc, &config.QRouter{ DefaultRouteBehaviour: "BLOCK", }) @@ -281,12 +273,8 @@ func TestCTE(t *testing.T) { lc := local.NewLocalCoordinator(db) pr, err := qrouter.NewProxyRouter(map[string]*config.Shard{ - "sh1": { - Hosts: nil, - }, - "sh2": { - Hosts: nil, - }, + "sh1": {}, + "sh2": {}, }, lc, &config.QRouter{ DefaultRouteBehaviour: "BLOCK", }) @@ -537,12 +525,8 @@ func TestSingleShard(t *testing.T) { lc := local.NewLocalCoordinator(db) pr, err := qrouter.NewProxyRouter(map[string]*config.Shard{ - "sh1": { - Hosts: nil, - }, - "sh2": { - Hosts: nil, - }, + "sh1": {}, + "sh2": {}, }, lc, &config.QRouter{ DefaultRouteBehaviour: "BLOCK", }) @@ -942,12 +926,8 @@ func TestInsertOffsets(t *testing.T) { lc := local.NewLocalCoordinator(db) pr, err := qrouter.NewProxyRouter(map[string]*config.Shard{ - "sh1": { - Hosts: nil, - }, - "sh2": { - Hosts: nil, - }, + "sh1": {}, + "sh2": {}, }, lc, &config.QRouter{ DefaultRouteBehaviour: "BLOCK", }) @@ -1153,12 +1133,8 @@ func TestJoins(t *testing.T) { lc := local.NewLocalCoordinator(db) pr, err := qrouter.NewProxyRouter(map[string]*config.Shard{ - "sh1": { - Hosts: nil, - }, - "sh2": { - Hosts: nil, - }, + "sh1": {}, + "sh2": {}, }, lc, &config.QRouter{}) assert.NoError(err) @@ -1287,12 +1263,8 @@ func TestUnnest(t *testing.T) { lc := local.NewLocalCoordinator(db) pr, err := qrouter.NewProxyRouter(map[string]*config.Shard{ - "sh1": { - Hosts: nil, - }, - "sh2": { - Hosts: nil, - }, + "sh1": {}, + "sh2": {}, }, lc, &config.QRouter{ DefaultRouteBehaviour: "BLOCK", }) @@ -1411,12 +1383,8 @@ func TestCopySingleShard(t *testing.T) { lc := local.NewLocalCoordinator(db) pr, err := qrouter.NewProxyRouter(map[string]*config.Shard{ - "sh1": { - Hosts: nil, - }, - "sh2": { - Hosts: nil, - }, + "sh1": {}, + "sh2": {}, }, lc, &config.QRouter{ DefaultRouteBehaviour: "BLOCK", }) @@ -1479,12 +1447,8 @@ func TestSetStmt(t *testing.T) { lc := local.NewLocalCoordinator(db) pr, err := qrouter.NewProxyRouter(map[string]*config.Shard{ - 
"sh1": { - Hosts: nil, - }, - "sh2": { - Hosts: nil, - }, + "sh1": {}, + "sh2": {}, }, lc, &config.QRouter{ DefaultRouteBehaviour: "BLOCK", }) @@ -1560,12 +1524,8 @@ func TestMiscRouting(t *testing.T) { lc := local.NewLocalCoordinator(db) pr, err := qrouter.NewProxyRouter(map[string]*config.Shard{ - "sh1": { - Hosts: nil, - }, - "sh2": { - Hosts: nil, - }, + "sh1": {}, + "sh2": {}, }, lc, &config.QRouter{ DefaultRouteBehaviour: "BLOCK", }) @@ -1664,12 +1624,8 @@ func TestHashRouting(t *testing.T) { lc := local.NewLocalCoordinator(db) pr, err := qrouter.NewProxyRouter(map[string]*config.Shard{ - "sh1": { - Hosts: nil, - }, - "sh2": { - Hosts: nil, - }, + "sh1": {}, + "sh2": {}, }, lc, &config.QRouter{ DefaultRouteBehaviour: "BLOCK", }) diff --git a/router/route/route.go b/router/route/route.go index 89cece01c..8c450b4e0 100644 --- a/router/route/route.go +++ b/router/route/route.go @@ -44,7 +44,7 @@ type Route struct { frRule *config.FrontendRule clPool client.Pool - servPool pool.DBPool + servPool *pool.DBPool mu sync.Mutex // protects this @@ -58,10 +58,15 @@ func NewRoute(beRule *config.BackendRule, frRule *config.FrontendRule, mapping m sp.SearchPath = frRule.SearchPath } + var preferAZ string + if config.RouterConfig().PreferSameAvailabilityZone { + preferAZ = config.RouterConfig().AvailabilityZone + } + route := &Route{ beRule: beRule, frRule: frRule, - servPool: pool.NewDBPool(mapping, sp), + servPool: pool.NewDBPool(mapping, sp, preferAZ), clPool: client.NewClientPool(), params: shard.ParameterSet{}, } @@ -109,7 +114,7 @@ func (r *Route) Params() (shard.ParameterSet, error) { return r.params, nil } -func (r *Route) ServPool() pool.DBPool { +func (r *Route) ServPool() *pool.DBPool { return r.servPool } diff --git a/router/server/mirror.go b/router/server/mirror.go deleted file mode 100644 index 046851600..000000000 --- a/router/server/mirror.go +++ /dev/null @@ -1,42 +0,0 @@ -package server - -import ( - "github.com/jackc/pgx/v5/pgproto3" - 
"github.com/pg-sharding/spqr/pkg/pool" - "github.com/pg-sharding/spqr/pkg/shard" -) - -func NewMultiShardServer(pool pool.DBPool) (Server, error) { - ret := &MultiShardServer{ - pool: pool, - activeShards: []shard.Shard{}, - } - - return ret, nil -} - -type LoadMirroringServer struct { - Server - main Server - mirror Server -} - -var _ Server = &LoadMirroringServer{} - -func NewLoadMirroringServer(source Server, dest Server) *LoadMirroringServer { - return &LoadMirroringServer{ - main: source, - mirror: dest, - } -} - -func (LoadMirroringServer) Send(query pgproto3.FrontendMessage) error { - return nil -} -func (LoadMirroringServer) Receive() (pgproto3.BackendMessage, error) { - return nil, nil -} - -func (m *LoadMirroringServer) Datashards() []shard.Shard { - return []shard.Shard{} -} diff --git a/router/server/multishard.go b/router/server/multishard.go index 780051f6e..1e8a47a1e 100644 --- a/router/server/multishard.go +++ b/router/server/multishard.go @@ -43,13 +43,22 @@ type MultiShardServer struct { multistate MultishardState - pool pool.DBPool + pool *pool.DBPool status txstatus.TXStatus copyBuf []*pgproto3.CopyOutResponse } +func NewMultiShardServer(pool *pool.DBPool) (Server, error) { + ret := &MultiShardServer{ + pool: pool, + activeShards: []shard.Shard{}, + } + + return ret, nil +} + // HasPrepareStatement implements Server. 
func (m *MultiShardServer) HasPrepareStatement(hash uint64) (bool, *prepstatement.PreparedStatementDescriptor) { panic("unimplemented") diff --git a/router/server/shard.go b/router/server/shard.go index 3c37510dc..7f1995833 100644 --- a/router/server/shard.go +++ b/router/server/shard.go @@ -19,7 +19,7 @@ import ( var ErrShardUnavailable = fmt.Errorf("shard is unavailable, try again later") type ShardServer struct { - pool pool.DBPool + pool *pool.DBPool shard shard.Shard // protects shard mu sync.RWMutex @@ -34,7 +34,7 @@ func (srv *ShardServer) RequestData() { srv.shard.RequestData() } -func NewShardServer(spool pool.DBPool) *ShardServer { +func NewShardServer(spool *pool.DBPool) *ShardServer { return &ShardServer{ pool: spool, mu: sync.RWMutex{},