From b6312f0faba3b3b15ca4cd672ea069019a7f6ec3 Mon Sep 17 00:00:00 2001 From: Pavel Gabriel Date: Fri, 21 Jul 2023 21:29:12 +0200 Subject: [PATCH] make writes of the data from Write sequential with other writes (#55) * make writes of the data from Write sequential with other writes * update the comment for Write method --- connection.go | 74 ++++++++++++++++++++++++------------ connection_test.go | 93 +++++++++++++++++++++++++++------------------- 2 files changed, 105 insertions(+), 62 deletions(-) diff --git a/connection.go b/connection.go index 40ee71f..f7fc25f 100644 --- a/connection.go +++ b/connection.go @@ -2,7 +2,6 @@ package connection import ( "bufio" - "bytes" "crypto/tls" "errors" "fmt" @@ -40,6 +39,12 @@ const ( StatusUnknown Status = "" ) +// directWrite is used to write data directly to the connection +type directWrite struct { + data []byte + errCh chan error +} + // Connection represents an ISO 8583 Connection. Connection may be used // by multiple goroutines simultaneously. type Connection struct { @@ -48,6 +53,7 @@ type Connection struct { conn io.ReadWriteCloser requestsCh chan request readResponseCh chan *iso8583.Message + directWriteCh chan directWrite done chan struct{} // spec that will be used to unpack received messages @@ -93,6 +99,7 @@ func New(addr string, spec *iso8583.MessageSpec, mlReader MessageLengthReader, m Opts: opts, requestsCh: make(chan request), readResponseCh: make(chan *iso8583.Message), + directWriteCh: make(chan directWrite), done: make(chan struct{}), respMap: make(map[string]response), spec: spec, @@ -169,13 +176,25 @@ func (c *Connection) Connect() error { return nil } -// Write writes data directly to the connection. Writes are atomic for -// net.TCPConn and tls.Conn and can be called simultaneously from multiple -// goroutines. But you should write whole message (including its header, etc.) -// at once, don't split one message into multiple Write calls. -// It's the caller's responsibility to handle the error returned from Write. +// Write writes data directly to the connection. It is crucial to note that the +// Write operation is atomic in nature, meaning it completes in a single +// uninterrupted step. +// When writing data, the entire message—including its header and any other +// components—should be written in one go. Splitting a single message into +// multiple Write calls is dangerous, as it could lead to unexpected behavior +// or errors. func (c *Connection) Write(p []byte) (int, error) { - return c.conn.Write(p) + dw := directWrite{ + data: p, + errCh: make(chan error, 1), + } + + select { + case c.directWriteCh <- dw: + return len(p), <-dw.errCh + case <-c.done: + return 0, ErrConnectionClosed + } } // run starts read and write loops in goroutines @@ -380,34 +399,30 @@ func (c *Connection) Send(message *iso8583.Message) (*iso8583.Message, error) { func (c *Connection) writeMessage(w io.Writer, message *iso8583.Message) error { if c.Opts.MessageWriter != nil { - return c.Opts.MessageWriter.WriteMessage(w, message) + err := c.Opts.MessageWriter.WriteMessage(c.conn, message) + if err != nil { + return fmt.Errorf("writing message: %w", err) + } + + return nil } - // default message writer + // if no custom message writer is set, use default one + packed, err := message.Pack() if err != nil { return fmt.Errorf("packing message: %w", err) } - // create buffer for header and packed message so we can write it to - // the connection as a single write - buf := &bytes.Buffer{} - // create header - _, err = c.writeMessageLength(buf, len(packed)) - if err != nil { - return fmt.Errorf("writing message header to buffer: %w", err) - } - - _, err = buf.Write(packed) + _, err = c.writeMessageLength(c.conn, len(packed)) if err != nil { - return fmt.Errorf("writing packed message to buffer: %w", err) + return fmt.Errorf("writing message length: %w", err) } - // write buffer to the connection as a single write (atomic) - _, err = buf.WriteTo(w) + _, err = c.conn.Write(packed) if err != nil { - return fmt.Errorf("writing buffer to the connection: %w", err) + return fmt.Errorf("writing packed message: %w", err) } return nil @@ -533,6 +548,19 @@ func (c *Connection) writeLoop() { if req.replyCh == nil { req.errCh <- nil } + + case dw := <-c.directWriteCh: + _, err = c.conn.Write(dw.data) + if err != nil { + c.handleError(fmt.Errorf("writing data: %w", err)) + dw.errCh <- err + + // we can't continue to write other messages or data when we failed to write + // one of them + break + } + dw.errCh <- nil + case <-idleTimeTimer.C: // if no message was sent during idle time, we have to send ping message if c.Opts.PingHandler != nil { diff --git a/connection_test.go b/connection_test.go index 990b6ac..ef91a30 100644 --- a/connection_test.go +++ b/connection_test.go @@ -178,58 +178,73 @@ func TestClient_Write(t *testing.T) { require.NoError(t, err) defer server.Close() - var called atomic.Int32 + t.Run("write into the connection", func(t *testing.T) { + var called atomic.Int32 + inboundMessageHandler := func(c *connection.Connection, message *iso8583.Message) { + called.Add(1) - inboundMessageHandler := func(c *connection.Connection, message *iso8583.Message) { - called.Add(1) + mti, err := message.GetMTI() + require.NoError(t, err) + require.Equal(t, "0810", mti) + } - mti, err := message.GetMTI() + // we should be able to write any bytes to the server + c, err := connection.New(server.Addr, testSpec, readMessageLength, writeMessageLength, connection.InboundMessageHandler(inboundMessageHandler)) require.NoError(t, err) - require.Equal(t, "0810", mti) - } + defer c.Close() - // we should be able to write any bytes to the server - c, err := connection.New(server.Addr, testSpec, readMessageLength, writeMessageLength, connection.InboundMessageHandler(inboundMessageHandler)) - require.NoError(t, err) - defer c.Close() + err = c.Connect() + require.NoError(t, err) - err = c.Connect() - require.NoError(t, err) + // let's create data to write to the server, we will prepare header and + // packed message + + // network management message + message := iso8583.NewMessage(testSpec) + err = message.Marshal(baseFields{ + MTI: field.NewStringValue("0800"), + TestCaseCode: field.NewStringValue(TestCaseReply), + STAN: field.NewStringValue(getSTAN()), + }) + require.NoError(t, err) - // let's create data to write to the server, we will prepare header and - // packed message + packed, err := message.Pack() + require.NoError(t, err) - // network management message - message := iso8583.NewMessage(testSpec) - err = message.Marshal(baseFields{ - MTI: field.NewStringValue("0800"), - TestCaseCode: field.NewStringValue(TestCaseReply), - STAN: field.NewStringValue(getSTAN()), - }) - require.NoError(t, err) + // prepare header + header := &bytes.Buffer{} + _, err = writeMessageLength(header, len(packed)) + require.NoError(t, err) - packed, err := message.Pack() - require.NoError(t, err) + // combine header and message + data := append(header.Bytes(), packed...) - // prepare header - header := &bytes.Buffer{} - _, err = writeMessageLength(header, len(packed)) - require.NoError(t, err) + // write the data directly to the connection + n, err := c.Write(data) + + require.NoError(t, err) + require.Equal(t, len(data), n) + + // we should expect to get reply, but as we are not using Send method, + // the reply will be handled by InboundMessageHandler + require.Eventually(t, func() bool { + return called.Load() == 1 + }, 100*time.Millisecond, 20*time.Millisecond, "inboundMessageHandler should be called") + }) - // combine header and message - data := append(header.Bytes(), packed...) + t.Run("write into the closed connection ", func(t *testing.T) { + c, err := connection.New(server.Addr, testSpec, readMessageLength, writeMessageLength) + require.NoError(t, err) + defer c.Close() - // write the data directly to the connection - n, err := c.Write(data) + err = c.Connect() + require.NoError(t, err) - require.NoError(t, err) - require.Equal(t, len(data), n) + c.Close() - // we should expect to get reply, but as we are not using Send method, - // the reply will be handled by InboundMessageHandler - require.Eventually(t, func() bool { - return called.Load() == 1 - }, 100*time.Millisecond, 20*time.Millisecond, "inboundMessageHandler should be called") + _, err = c.Write([]byte("hello")) + require.Error(t, err) + }) } func TestClient_Send(t *testing.T) {