Skip to content

Commit

Permalink
uacp: add support for write timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
magiconair committed Mar 23, 2023
1 parent 48b28c6 commit beb72fb
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
6 changes: 6 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,12 @@ func ReadTimeout(d time.Duration) Option {
cfg.dialer.ReadTimeout = d
}
}

// WriteTimeout sets the timeout for every write operation.
func WriteTimeout(d time.Duration) Option {
return func(cfg *Config) {
initDialer(cfg)
cfg.dialer.WriteTimeout = d
}
}

Expand Down
19 changes: 17 additions & 2 deletions uacp/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ type Dialer struct {
// ReadTimeout sets a read timeout for reading a full response from the
// underlying network connection. ReadTimeout is ignored if it is <= 0.
ReadTimeout time.Duration

// WriteTimeout sets a write timeout for sending a request on the
// underlying network connection. WriteTimeout is ignored if it is <= 0.
WriteTimeout time.Duration
}

func (d *Dialer) Dial(ctx context.Context, endpoint string) (*Conn, error) {
Expand All @@ -93,6 +97,7 @@ func (d *Dialer) Dial(ctx context.Context, endpoint string) (*Conn, error) {
return nil, err
}
conn.readTimeout = d.ReadTimeout
conn.writeTimeout = d.WriteTimeout

debug.Printf("uacp %d: start HEL/ACK handshake", conn.id)
if err := conn.Handshake(ctx, endpoint); err != nil {
Expand Down Expand Up @@ -179,8 +184,9 @@ type Conn struct {
id uint32
ack *Acknowledge

closeOnce sync.Once
closeOnce sync.Once
readTimeout time.Duration
writeTimeout time.Duration
}

func NewConn(c *net.TCPConn, ack *Acknowledge) (*Conn, error) {
Expand Down Expand Up @@ -419,7 +425,7 @@ func (c *Conn) Send(typ string, msg interface{}) error {

body, err := ua.Encode(msg)
if err != nil {
return errors.Errorf("encode msg failed: %s", err)
return errors.Errorf("encode msg failed: %w", err)
}

h := Header{
Expand All @@ -437,12 +443,21 @@ func (c *Conn) Send(typ string, msg interface{}) error {
return errors.Errorf("encode hdr failed: %s", err)
}

if c.writeTimeout > 0 {
if err := c.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil {
return errors.Errorf("failed to set write timeout: %w", err)
}
}

b := append(hdr, body...)
if _, err := c.Write(b); err != nil {
return errors.Errorf("write failed: %s", err)
}
debug.Printf("uacp %d: sent %s with %d bytes", c.id, typ, len(b))

// clear the deadline
c.SetWriteDeadline(time.Time{})

return nil
}

Expand Down

0 comments on commit beb72fb

Please sign in to comment.