Skip to content

Commit

Permalink
uacp: add support for read timeouts
Browse files Browse the repository at this point in the history
This prevents io.ReadFull to hang.
  • Loading branch information
jgould authored and magiconair committed Mar 23, 2023
1 parent 119433f commit 48b28c6
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 7 deletions.
12 changes: 11 additions & 1 deletion config.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func NewDialer(cfg *Config) *uacp.Dialer {
// ApplyConfig applies the config options to the default configuration.
// todo(fs): Can we find a better name?
//
// Note: Starting with v0.5 this function will will return an error.
// Note: Starting with v0.5 this function will return an error.
func ApplyConfig(opts ...Option) *Config {
cfg := &Config{
sechan: DefaultClientConfig(),
Expand Down Expand Up @@ -501,6 +501,16 @@ func DialTimeout(d time.Duration) Option {
}
}

// ReadTimeout sets the timeout for every read operation.
func ReadTimeout(d time.Duration) Option {
return func(cfg *Config) {
initDialer(cfg)
cfg.dialer.ReadTimeout = d
}
}
}
}

// MaxMessageSize sets the maximum message size for the UACP handshake.
func MaxMessageSize(n uint32) Option {
return func(cfg *Config) {
Expand Down
36 changes: 30 additions & 6 deletions uacp/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ type Dialer struct {
// ClientACK defines the connection parameters requested by the client.
// Defaults to DefaultClientACK.
ClientACK *Acknowledge

// ReadTimeout sets a read timeout for reading a full response from the
// underlying network connection. ReadTimeout is ignored if it is <= 0.
ReadTimeout time.Duration
}

func (d *Dialer) Dial(ctx context.Context, endpoint string) (*Conn, error) {
Expand All @@ -88,6 +92,7 @@ func (d *Dialer) Dial(ctx context.Context, endpoint string) (*Conn, error) {
c.Close()
return nil, err
}
conn.readTimeout = d.ReadTimeout

debug.Printf("uacp %d: start HEL/ACK handshake", conn.id)
if err := conn.Handshake(ctx, endpoint); err != nil {
Expand Down Expand Up @@ -174,7 +179,8 @@ type Conn struct {
id uint32
ack *Acknowledge

closeOnce sync.Once
closeOnce sync.Once
readTimeout time.Duration
}

func NewConn(c *net.TCPConn, ack *Acknowledge) (*Conn, error) {
Expand Down Expand Up @@ -351,15 +357,25 @@ const hdrlen = 8
// The size of b must be at least ReceiveBufSize. Otherwise,
// the function returns an error.
func (c *Conn) Receive() ([]byte, error) {
// TODO(kung-foo): allow user-specified buffer
// TODO(kung-foo): sync.Pool
// todo(kung-foo): allow user-specified buffer
// todo(kung-foo): sync.Pool
b := make([]byte, c.ack.ReceiveBufSize)

if _, err := io.ReadFull(c, b[:hdrlen]); err != nil {
if c.readTimeout > 0 {
if err := c.SetReadDeadline(time.Now().Add(c.readTimeout)); err != nil {
return nil, errors.Errorf("uacp: failed to set read timeout: %w", err)
}
}

n, err := c.Read(b[:hdrlen])
if err != nil {
// todo(fs): do not wrap this error since it hides io.EOF
// todo(fs): use golang.org/x/xerrors
return nil, err
}
if n != hdrlen {
return nil, errors.Errorf("uacp: short read on header. got %d bytes, want %d ", n, hdrlen)
}

var h Header
if _, err := h.Decode(b[:hdrlen]); err != nil {
Expand All @@ -370,18 +386,26 @@ func (c *Conn) Receive() ([]byte, error) {
return nil, errors.Errorf("uacp: message too large: %d > %d bytes", h.MessageSize, c.ack.ReceiveBufSize)
}

if _, err := io.ReadFull(c, b[hdrlen:h.MessageSize]); err != nil {
n, err = c.Read(b[hdrlen:h.MessageSize])
if err != nil {
// todo(fs): do not wrap this error since it hides io.EOF
// todo(fs): use golang.org/x/xerrors
return nil, err
}

// clear the deadline
c.SetReadDeadline(time.Time{})

if uint32(n) != h.MessageSize-hdrlen {
return nil, fmt.Errorf("uacp %d: short read on message. got %d bytes, want %d", c.id, n, h.MessageSize-hdrlen)
}

debug.Printf("uacp %d: recv %s%c with %d bytes", c.id, h.MessageType, h.ChunkType, h.MessageSize)

if h.MessageType == "ERR" {
errf := new(Error)
if _, err := errf.Decode(b[hdrlen:h.MessageSize]); err != nil {
return nil, errors.Errorf("uacp: failed to decode ERRF message: %s", err)
return nil, errors.Errorf("uacp: failed to decode ERRF message: %w", err)
}
return nil, errf
}
Expand Down

0 comments on commit 48b28c6

Please sign in to comment.