Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

conncheck: use unix.Poll instead of syscall.Read #1456

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 16 additions & 21 deletions conncheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,14 @@
package mysql

import (
"errors"
"io"
"fmt"
"net"
"syscall"
)

var errUnexpectedRead = errors.New("unexpected read from socket")
"golang.org/x/sys/unix"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work on Windows?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it doesn't, as current implementation too.

)

func connCheck(conn net.Conn) error {
var sysErr error

sysConn, ok := conn.(syscall.Conn)
if !ok {
return nil
Expand All @@ -32,24 +29,22 @@ func connCheck(conn net.Conn) error {
return err
}

err = rawConn.Read(func(fd uintptr) bool {
var buf [1]byte
n, err := syscall.Read(int(fd), buf[:])
switch {
case n == 0 && err == nil:
sysErr = io.EOF
case n > 0:
sysErr = errUnexpectedRead
case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK:
sysErr = nil
default:
sysErr = err
var pollErr error
err = rawConn.Control(func(fd uintptr) {
fds := []unix.PollFd{
{Fd: int32(fd), Events: unix.POLLIN | unix.POLLERR},
}
n, err := unix.Poll(fds, 0)
if err != nil {
pollErr = fmt.Errorf("poll: %w", err)
}
if n > 0 {
// fmt.Errorf("poll: %v", fds[0].Revents)
pollErr = errUnexpectedEvent
}
return true
})
if err != nil {
return err
}

return sysErr
return pollErr
}
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/go-sql-driver/mysql

go 1.18

require golang.org/x/sys v0.10.0 // indirect
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
53 changes: 40 additions & 13 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"database/sql/driver"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"math"
Expand Down Expand Up @@ -44,12 +45,24 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {

// check packet sync [8 bit]
if data[3] != mc.sequence {
var syncErr error
if data[3] > mc.sequence {
return nil, ErrPktSyncMul
syncErr = ErrPktSyncMul
} else {
syncErr = ErrPktSync
}
return nil, ErrPktSync

if prevData != nil {
return nil, syncErr
} else {
// log and ignore seqno mismatch error.
// MySQL sometimes sends wrong sequence no.
mc.cfg.Logger.Print(syncErr)
mc.sequence = data[3] + 1
}
} else {
mc.sequence++
}
mc.sequence++

// packets with length 0 terminate a previous packet which is a
// multiple of (2^24)-1 bytes long
Expand Down Expand Up @@ -89,6 +102,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
}
}

// used in conncheck.go
var errUnexpectedEvent = errors.New("recieved unexpected event")

// Write packet buffer 'data'
func (mc *mysqlConn) writePacket(data []byte) error {
pktLen := len(data) - 4
Expand All @@ -111,18 +127,29 @@ func (mc *mysqlConn) writePacket(data []byte) error {
}
var err error
if mc.cfg.CheckConnLiveness {
if mc.cfg.ReadTimeout != 0 {
err = conn.SetReadDeadline(time.Now().Add(mc.cfg.ReadTimeout))
}
if err == nil {
err = connCheck(conn)
err = connCheck(conn)
if err != nil {
if err == errUnexpectedEvent {
_ = conn.SetReadDeadline(time.Now().Add(time.Second))
var data []byte
data, err = mc.readPacket()

if err == nil {
if data[0] == iERR {
err = mc.handleErrorPacket(data)
} else {
err = fmt.Errorf("unexpected packet: % x", data[:128])
}
} else {
err = fmt.Errorf("readPacket(): %w", err)
}
}

mc.cfg.Logger.Print("checkConn() failed: ", err)
mc.Close()
return driver.ErrBadConn
}
}
if err != nil {
mc.cfg.Logger.Print("closing bad idle connection: ", err)
mc.Close()
return driver.ErrBadConn
}
}

for {
Expand Down
39 changes: 33 additions & 6 deletions packets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ package mysql
import (
"bytes"
"errors"
"fmt"
"net"
"testing"
"time"
Expand Down Expand Up @@ -132,31 +133,57 @@ func TestReadPacketSingleByte(t *testing.T) {
}
}

type mockLogger struct {
bytes.Buffer
}

func (ml *mockLogger) Print(v ...any) {
ml.WriteString(fmt.Sprint(v...) + "\n")
}

func TestReadPacketWrongSequenceID(t *testing.T) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
cfg: NewConfig(),
}
logger := &mockLogger{}
mc.cfg.Logger = Logger(logger)

// too low sequence id
conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
conn.maxReads = 1
mc.sequence = 1
_, err := mc.readPacket()
if err != ErrPktSync {
t.Errorf("expected ErrPktSync, got %v", err)
data, err := mc.readPacket()
if err != nil {
t.Errorf("expected nil, got %v", err)
}
if len(data) != 1 || data[0] != 0xff {
t.Errorf("expected [0xff], got % x", data)
}
logMsg := logger.String()
if logMsg != ErrPktSync.Error()+"\n" {
t.Errorf("expected ErrPktSync.Error(), got %q", logMsg)
}

// reset
conn.reads = 0
mc.sequence = 0
mc.buf = newBuffer(conn)
logger.Reset()

// too high sequence id
conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff}
_, err = mc.readPacket()
if err != ErrPktSyncMul {
t.Errorf("expected ErrPktSyncMul, got %v", err)
data, err = mc.readPacket()
if err != nil {
t.Errorf("expected nil, got %v", err)
}
if len(data) != 1 || data[0] != 0xff {
t.Errorf("expected [0xff], got % x", data)
}
logMsg = logger.String()
if logMsg != ErrPktSyncMul.Error()+"\n" {
t.Errorf("expected ErrPktSync.Error(), got %q", logMsg)
}
}

Expand Down