diff --git a/conn.go b/conn.go
index 7138665..ffe1ae0 100644
--- a/conn.go
+++ b/conn.go
@@ -147,6 +147,46 @@ func (c *Conn) Execute(m Message) ([]Message, error) {
 	return res, nil
 }
 
+// ExecuteFunc sends a single Message to netlink using Send and passes one or more
+// replies obtained by Receive to the provided callback func after checking each
+// reply for validity using Validate.
+//
+// ExecuteFunc acquires a lock for the duration of the function call which blocks
+// concurrent calls to Send, SendMessages, and Receive, in order to ensure
+// consistency between netlink request/reply messages - do not call methods on this
+// Conn from the provided callback func or your application may deadlock.
+//
+// See the documentation of Send, Receive, and Validate for details about
+// each function.
+func (c *Conn) ExecuteFunc(m Message, cb func(Message)) error {
+	// Acquire the write lock and invoke the internal implementations of Send
+	// and Receive which require the lock already be held.
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	req, err := c.lockedSend(m)
+	if err != nil {
+		return err
+	}
+
+	var validateErr error
+	err = c.lockedReceiveEach(func(m Message) {
+		if err := Validate(req, []Message{m}); err != nil {
+			validateErr = err
+			return
+		}
+		cb(m)
+	})
+	if err != nil {
+		return err
+	}
+	if validateErr != nil {
+		return validateErr
+	}
+
+	return nil
+}
+
 // SendMessages sends multiple Messages to netlink. The handling of
 // a Header's Length, Sequence and PID fields is the same as when
 // calling Send.
@@ -231,11 +271,27 @@ func (c *Conn) Receive() ([]Message, error) {
 	return c.lockedReceive()
 }
 
+// ReceiveEach receives one or more messages from netlink. Multi-part messages are
+// handled transparently and a callback is invoked for each message, with the
+// final empty "multi-part done" message removed.
+//
+// If any of the messages indicate a netlink error, that error will be returned.
+func (c *Conn) ReceiveEach(cb func(Message)) error {
+	// Wait for any concurrent calls to Execute to finish before proceeding.
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	return c.lockedReceiveEach(cb)
+}
+
 // lockedReceive implements Receive, but must be called with c.mu acquired for reading.
 // We rely on the kernel to deal with concurrent reads and writes to the netlink
 // socket itself.
 func (c *Conn) lockedReceive() ([]Message, error) {
-	msgs, err := c.receive()
+	var msgs []Message
+	err := c.receiveEach(func(m Message) {
+		msgs = append(msgs, m)
+	})
 	if err != nil {
 		c.debug(func(d *debugger) {
 			d.debugf(1, "recv: err: %v", err)
@@ -250,23 +306,34 @@
 		}
 	})
 
-	// When using nltest, it's possible for zero messages to be returned by receive.
-	if len(msgs) == 0 {
-		return msgs, nil
-	}
+	return msgs, nil
+}
+
+// lockedReceiveEach implements ReceiveEach, but must be called with c.mu acquired for reading.
+// We rely on the kernel to deal with concurrent reads and writes to the netlink
+// socket itself.
+func (c *Conn) lockedReceiveEach(cb func(Message)) error {
+	err := c.receiveEach(func(m Message) {
+		cb(m)
+
+		c.debug(func(d *debugger) {
+			d.debugf(1, "recv: %+v", m)
+		})
+	})
+	if err != nil {
+		c.debug(func(d *debugger) {
+			d.debugf(1, "recv: err: %v", err)
+		})
 
-	// Trim the final message with multi-part done indicator if
-	// present.
-	if m := msgs[len(msgs)-1]; m.Header.Flags&Multi != 0 && m.Header.Type == Done {
-		return msgs[:len(msgs)-1], nil
+		return err
 	}
 
-	return msgs, nil
+	return nil
 }
 
 // receive is the internal implementation of Conn.Receive, which can be called
 // recursively to handle multi-part messages.
-func (c *Conn) receive() ([]Message, error) {
+func (c *Conn) receiveEach(cb func(Message)) error {
 	// NB: All non-nil errors returned from this function *must* be of type
 	// OpError in order to maintain the appropriate contract with callers of
 	// this package.
@@ -274,40 +341,38 @@
 	// This contract also applies to functions called within this function,
 	// such as checkMessage.
 
-	var res []Message
-	for {
+	for more := true; more; {
 		msgs, err := c.sock.Receive()
 		if err != nil {
-			return nil, newOpError("receive", err)
+			return newOpError("receive", err)
 		}
 
-		// If this message is multi-part, we will need to continue looping to
-		// drain all the messages from the socket.
-		var multi bool
-
+		more = false
+		multipartDone := false
 		for _, m := range msgs {
 			if err := checkMessage(m); err != nil {
-				return nil, err
+				return err
 			}
 
 			// Does this message indicate a multi-part message?
-			if m.Header.Flags&Multi == 0 {
-				// No, check the next messages.
-				continue
+			if m.Header.Flags&Multi != 0 {
+				multipartDone = m.Header.Type == Done
+				more = !multipartDone
 			}
-
-			// Does this message indicate the last message in a series of
-			// multi-part messages from a single read?
-			multi = m.Header.Type != Done
 		}
 
-		res = append(res, msgs...)
+		// Trim the final message with multi-part done indicator if
+		// present.
+		if multipartDone {
+			msgs = msgs[:len(msgs)-1]
+		}
 
-		if !multi {
-			// No more messages coming.
-			return res, nil
+		for _, m := range msgs {
+			cb(m)
 		}
 	}
+
+	return nil
 }
 
 // A groupJoinLeaver is a Socket that supports joining and leaving
diff --git a/conn_linux_integration_test.go b/conn_linux_integration_test.go
index 7442f92..0cee24d 100644
--- a/conn_linux_integration_test.go
+++ b/conn_linux_integration_test.go
@@ -90,6 +90,76 @@ func TestIntegrationConn(t *testing.T) {
 	}
 }
 
+func TestIntegrationConnFunc(t *testing.T) {
+	t.Parallel()
+
+	c, err := netlink.Dial(unix.NETLINK_GENERIC, nil)
+	if err != nil {
+		t.Fatalf("failed to dial netlink: %v", err)
+	}
+
+	// Ask to send us an acknowledgement, which will contain an
+	// error code (or success) and a copy of the payload we sent in
+	req := netlink.Message{
+		Header: netlink.Header{
+			Flags: netlink.Request | netlink.Acknowledge,
+		},
+	}
+
+	// Perform a request using ExecuteFunc, receive replies, and validate the replies
+	var msgs []netlink.Message
+	err = c.ExecuteFunc(req, func(m netlink.Message) {
+		msgs = append(msgs, m)
+	})
+	if err != nil {
+		t.Fatalf("failed to execute request: %v", err)
+	}
+	if want, got := 1, len(msgs); want != got {
+		t.Fatalf("unexpected message count from netlink:\n- want: %v\n- got: %v",
+			want, got)
+	}
+
+	if err := c.Close(); err != nil {
+		t.Fatalf("error closing netlink connection: %v", err)
+	}
+
+	m := msgs[0]
+
+	if want, got := 0, int(nlenc.Uint32(m.Data[0:4])); want != got {
+		t.Fatalf("unexpected error code:\n- want: %v\n- got: %v", want, got)
+	}
+
+	if want, got := 36, int(m.Header.Length); want != got {
+		t.Fatalf("unexpected header length:\n- want: %v\n- got: %v", want, got)
+	}
+	if want, got := netlink.Error, m.Header.Type; want != got {
+		t.Fatalf("unexpected header type:\n- want: %v\n- got: %v", want, got)
+	}
+	// Recent kernel versions (> 4.14) return a 256 here instead of a 0
+	if want, wantAlt, got := 0, 256, int(m.Header.Flags); want != got && wantAlt != got {
+		t.Fatalf("unexpected header flags:\n- want: %v or %v\n- got: %v", want, wantAlt, got)
+	}
+
+	// Sequence number is not checked because we assign one at random when
+	// a Conn is created. PID is not checked because running tests in parallel
+	// results in only the first socket getting assigned the process's PID as
+	// its netlink PID.
+
+	// Unmarshal the copy of the request sent back to us, skipping the
+	// error code at bytes 0-4.
+	var reply netlink.Message
+	if err := (&reply).UnmarshalBinary(m.Data[4:]); err != nil {
+		t.Fatalf("failed to unmarshal reply: %v", err)
+	}
+
+	if want, got := req.Header.Flags, reply.Header.Flags; want != got {
+		t.Fatalf("unexpected copy header flags:\n- want: %v\n- got: %v", want, got)
+	}
+	if want, got := len(req.Data), len(reply.Data); want != got {
+		t.Fatalf("unexpected copy header data length:\n- want: %v\n- got: %v", want, got)
+	}
+}
+
 func TestIntegrationConnConcurrentManyConns(t *testing.T) {
 	t.Parallel()
 	skipShort(t)
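
For context, a minimal usage sketch of the callback-based API added above (not part of the patch): it assumes the usual module import path github.com/mdlayher/netlink and mirrors the generic netlink acknowledgement request used in the integration test; a real caller would pick whichever protocol and request it needs. Per the ExecuteFunc documentation, the callback must not call methods on the Conn.

package main

import (
	"fmt"
	"log"

	"github.com/mdlayher/netlink"
	"golang.org/x/sys/unix"
)

func main() {
	// Dial generic netlink, as in the integration test above.
	c, err := netlink.Dial(unix.NETLINK_GENERIC, nil)
	if err != nil {
		log.Fatalf("failed to dial netlink: %v", err)
	}
	defer c.Close()

	// Request an acknowledgement so the kernel echoes our request back.
	req := netlink.Message{
		Header: netlink.Header{
			Flags: netlink.Request | netlink.Acknowledge,
		},
	}

	// Each validated reply is handed to the callback as it is received,
	// instead of being accumulated into a []Message as Execute does.
	// Do not call methods on c from inside the callback.
	if err := c.ExecuteFunc(req, func(m netlink.Message) {
		fmt.Printf("reply: type=%d, length=%d\n", m.Header.Type, m.Header.Length)
	}); err != nil {
		log.Fatalf("failed to execute request: %v", err)
	}
}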