Skip to content

Commit

Permalink
Fix the race condition during vttablet startup (#15731)
Browse files Browse the repository at this point in the history
Signed-off-by: Dirkjan Bussink <[email protected]>
  • Loading branch information
dbussink authored Apr 17, 2024
1 parent 4519c8f commit 178e6e8
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 18 deletions.
29 changes: 22 additions & 7 deletions go/test/endtoend/utils/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"os"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -129,7 +130,9 @@ func TestSetSuperReadOnlyMySQL(t *testing.T) {
func TestGetMysqlPort(t *testing.T) {
require.NotNil(t, mysqld)

port, err := mysqld.GetMysqlPort()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
port, err := mysqld.GetMysqlPort(ctx)

// The expected port is one less than the port returned by GetAndReservePort,
// since we are calling it a second time here to get a port.
Expand Down Expand Up @@ -161,7 +164,9 @@ func TestReplicationStatus(t *testing.T) {
conn, err := mysql.Connect(ctx, &mysqlParams)
require.NoError(t, err)

port, err := mysqld.GetMysqlPort()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
port, err := mysqld.GetMysqlPort(ctx)
require.NoError(t, err)
host := "localhost"

Expand Down Expand Up @@ -234,7 +239,9 @@ func TestSetReplicationPosition(t *testing.T) {
func TestSetAndResetReplication(t *testing.T) {
require.NotNil(t, mysqld)

port, err := mysqld.GetMysqlPort()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
port, err := mysqld.GetMysqlPort(ctx)
require.NoError(t, err)
host := "localhost"

Expand Down Expand Up @@ -387,7 +394,9 @@ func TestWaitForReplicationStart(t *testing.T) {
err := mysqlctl.WaitForReplicationStart(mysqld, 1)
assert.ErrorContains(t, err, "no replication status")

port, err := mysqld.GetMysqlPort()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
port, err := mysqld.GetMysqlPort(ctx)
require.NoError(t, err)
host := "localhost"

Expand All @@ -407,7 +416,9 @@ func TestStartReplication(t *testing.T) {
err := mysqld.StartReplication(map[string]string{})
assert.ErrorContains(t, err, "The server is not configured as replica")

port, err := mysqld.GetMysqlPort()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
port, err := mysqld.GetMysqlPort(ctx)
require.NoError(t, err)
host := "localhost"

Expand All @@ -425,7 +436,9 @@ func TestStartReplication(t *testing.T) {
func TestStopReplication(t *testing.T) {
require.NotNil(t, mysqld)

port, err := mysqld.GetMysqlPort()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
port, err := mysqld.GetMysqlPort(ctx)
require.NoError(t, err)
host := "localhost"

Expand All @@ -449,7 +462,9 @@ func TestStopReplication(t *testing.T) {
func TestStopSQLThread(t *testing.T) {
require.NotNil(t, mysqld)

port, err := mysqld.GetMysqlPort()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
port, err := mysqld.GetMysqlPort(ctx)
require.NoError(t, err)
host := "localhost"

Expand Down
2 changes: 1 addition & 1 deletion go/vt/mysqlctl/fakemysqldaemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ func (fmd *FakeMysqlDaemon) WaitForDBAGrants(ctx context.Context, waitTime time.
}

// GetMysqlPort is part of the MysqlDaemon interface.
func (fmd *FakeMysqlDaemon) GetMysqlPort() (int32, error) {
func (fmd *FakeMysqlDaemon) GetMysqlPort(ctx context.Context) (int32, error) {
if fmd.MysqlPort.Load() == -1 {
return 0, fmt.Errorf("FakeMysqlDaemon.GetMysqlPort returns an error")
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/mysqlctl/mysql_daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type MysqlDaemon interface {
WaitForDBAGrants(ctx context.Context, waitTime time.Duration) (err error)

// GetMysqlPort returns the current port mysql is listening on.
GetMysqlPort() (int32, error)
GetMysqlPort(ctx context.Context) (int32, error)

// GetServerID returns the servers ID.
GetServerID(ctx context.Context) (uint32, error)
Expand Down
6 changes: 5 additions & 1 deletion go/vt/mysqlctl/mysqld.go
Original file line number Diff line number Diff line change
Expand Up @@ -521,11 +521,15 @@ func (mysqld *Mysqld) WaitForDBAGrants(ctx context.Context, waitTime time.Durati
if waitTime == 0 {
return nil
}
params, err := mysqld.dbcfgs.DbaConnector().MysqlParams()
if err != nil {
return err
}
timer := time.NewTimer(waitTime)
ctx, cancel := context.WithTimeout(ctx, waitTime)
defer cancel()
for {
conn, connErr := dbconnpool.NewDBConnection(ctx, mysqld.dbcfgs.DbaConnector())
conn, connErr := mysql.Connect(ctx, params)
if connErr == nil {
res, fetchErr := conn.ExecuteFetch("SHOW GRANTS", 1000, false)
conn.Close()
Expand Down
18 changes: 16 additions & 2 deletions go/vt/mysqlctl/replication.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"strings"
"time"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/mysql/replication"
"vitess.io/vitess/go/netutil"
"vitess.io/vitess/go/vt/hook"
Expand Down Expand Up @@ -174,8 +175,21 @@ func (mysqld *Mysqld) RestartReplication(hookExtraEnv map[string]string) error {
}

// GetMysqlPort returns mysql port
func (mysqld *Mysqld) GetMysqlPort() (int32, error) {
qr, err := mysqld.FetchSuperQuery(context.TODO(), "SHOW VARIABLES LIKE 'port'")
func (mysqld *Mysqld) GetMysqlPort(ctx context.Context) (int32, error) {
// We cannot use the connection pool here. This check runs very early
// during MySQL startup, when we might still be loading things like grants.
// This means we need to use an isolated connection to avoid poisoning the
// DBA connection pool for further queries.
params, err := mysqld.dbcfgs.DbaConnector().MysqlParams()
if err != nil {
return 0, err
}
conn, err := mysql.Connect(ctx, params)
if err != nil {
return 0, err
}
defer conn.Close()
qr, err := conn.ExecuteFetch("SHOW VARIABLES LIKE 'port'", 1, false)
if err != nil {
return 0, err
}
Expand Down
7 changes: 5 additions & 2 deletions go/vt/mysqlctl/replication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"net"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -133,12 +134,14 @@ func TestGetMysqlPort(t *testing.T) {
testMysqld := NewMysqld(dbc)
defer testMysqld.Close()

res, err := testMysqld.GetMysqlPort()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
res, err := testMysqld.GetMysqlPort(ctx)
assert.Equal(t, int32(12), res)
assert.NoError(t, err)

db.AddQuery("SHOW VARIABLES LIKE 'port'", &sqltypes.Result{})
res, err = testMysqld.GetMysqlPort()
res, err = testMysqld.GetMysqlPort(ctx)
assert.ErrorContains(t, err, "no port variable in mysql")
assert.Equal(t, int32(0), res)
}
Expand Down
16 changes: 12 additions & 4 deletions go/vt/vttablet/tabletmanager/tm_init.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ func (tm *TabletManager) Start(tablet *topodatapb.Tablet, config *tabletenv.Tabl
if err := tm.checkPrimaryShip(ctx, si); err != nil {
return err
}
if err := tm.checkMysql(); err != nil {
if err := tm.checkMysql(ctx); err != nil {
return err
}
if err := tm.initTablet(ctx); err != nil {
Expand Down Expand Up @@ -702,7 +702,7 @@ func (tm *TabletManager) checkPrimaryShip(ctx context.Context, si *topo.ShardInf
return nil
}

func (tm *TabletManager) checkMysql() error {
func (tm *TabletManager) checkMysql(ctx context.Context) error {
appConfig, err := tm.DBConfigs.AppWithDB().MysqlParams()
if err != nil {
return err
Expand All @@ -717,7 +717,7 @@ func (tm *TabletManager) checkMysql() error {
tm.tmState.UpdateTablet(func(tablet *topodatapb.Tablet) {
tablet.MysqlHostname = tablet.Hostname
})
mysqlPort, err := tm.MysqlDaemon.GetMysqlPort()
mysqlPort, err := tm.MysqlDaemon.GetMysqlPort(ctx)
if err != nil {
log.Warningf("Cannot get current mysql port, will keep retrying every %v: %v", mysqlPortRetryInterval, err)
go tm.findMysqlPort(mysqlPortRetryInterval)
Expand All @@ -730,10 +730,18 @@ func (tm *TabletManager) checkMysql() error {
return nil
}

// portCheckTimeout bounds a single MySQL port lookup made through
// getMysqlPort.
const portCheckTimeout = 5 * time.Second

// getMysqlPort asks the MySQL daemon for its port using a fresh context
// derived from context.Background() that expires after portCheckTimeout.
// The daemon's result and error are returned unchanged.
func (tm *TabletManager) getMysqlPort() (int32, error) {
	lookupCtx, release := context.WithTimeout(context.Background(), portCheckTimeout)
	// Release the timeout's resources as soon as the lookup returns.
	defer release()
	return tm.MysqlDaemon.GetMysqlPort(lookupCtx)
}

func (tm *TabletManager) findMysqlPort(retryInterval time.Duration) {
for {
time.Sleep(retryInterval)
mport, err := tm.MysqlDaemon.GetMysqlPort()
mport, err := tm.getMysqlPort()
if err != nil || mport == 0 {
continue
}
Expand Down

0 comments on commit 178e6e8

Please sign in to comment.