Skip to content

Commit

Permalink
make writes of the data from Write sequential with other writes (#55)
Browse files Browse the repository at this point in the history
* make writes of the data from Write sequential with other writes

* update the comment for Write method
  • Loading branch information
alovak authored Jul 21, 2023
1 parent 28e94e2 commit b6312f0
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 62 deletions.
74 changes: 51 additions & 23 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package connection

import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"fmt"
Expand Down Expand Up @@ -40,6 +39,12 @@ const (
StatusUnknown Status = ""
)

// directWrite is used to write data directly to the connection
type directWrite struct {
data []byte
errCh chan error
}

// Connection represents an ISO 8583 Connection. Connection may be used
// by multiple goroutines simultaneously.
type Connection struct {
Expand All @@ -48,6 +53,7 @@ type Connection struct {
conn io.ReadWriteCloser
requestsCh chan request
readResponseCh chan *iso8583.Message
directWriteCh chan directWrite
done chan struct{}

// spec that will be used to unpack received messages
Expand Down Expand Up @@ -93,6 +99,7 @@ func New(addr string, spec *iso8583.MessageSpec, mlReader MessageLengthReader, m
Opts: opts,
requestsCh: make(chan request),
readResponseCh: make(chan *iso8583.Message),
directWriteCh: make(chan directWrite),
done: make(chan struct{}),
respMap: make(map[string]response),
spec: spec,
Expand Down Expand Up @@ -169,13 +176,25 @@ func (c *Connection) Connect() error {
return nil
}

// Write writes data directly to the connection. Writes are atomic for
// net.TCPConn and tls.Conn and can be called simultaneously from multiple
// goroutines. But you should write whole message (including its header, etc.)
// at once, don't split one message into multiple Write calls.
// It's the caller's responsibility to handle the error returned from Write.
// Write writes data directly to the connection. It is crucial to note that the
// Write operation is atomic in nature, meaning it completes in a single
// uninterrupted step.
// When writing data, the entire message—including its header and any other
// components—should be written in one go. Splitting a single message into
// multiple Write calls is dangerous, as it could lead to unexpected behavior
// or errors.
func (c *Connection) Write(p []byte) (int, error) {
return c.conn.Write(p)
dw := directWrite{
data: p,
errCh: make(chan error, 1),
}

select {
case c.directWriteCh <- dw:
return len(p), <-dw.errCh
case <-c.done:
return 0, ErrConnectionClosed
}
}

// run starts read and write loops in goroutines
Expand Down Expand Up @@ -380,34 +399,30 @@ func (c *Connection) Send(message *iso8583.Message) (*iso8583.Message, error) {

func (c *Connection) writeMessage(w io.Writer, message *iso8583.Message) error {
if c.Opts.MessageWriter != nil {
return c.Opts.MessageWriter.WriteMessage(w, message)
err := c.Opts.MessageWriter.WriteMessage(c.conn, message)
if err != nil {
return fmt.Errorf("writing message: %w", err)
}

return nil
}

// default message writer
// if no custom message writer is set, use default one

packed, err := message.Pack()
if err != nil {
return fmt.Errorf("packing message: %w", err)
}

// create buffer for header and packed message so we can write it to
// the connection as a single write
buf := &bytes.Buffer{}

// create header
_, err = c.writeMessageLength(buf, len(packed))
if err != nil {
return fmt.Errorf("writing message header to buffer: %w", err)
}

_, err = buf.Write(packed)
_, err = c.writeMessageLength(c.conn, len(packed))
if err != nil {
return fmt.Errorf("writing packed message to buffer: %w", err)
return fmt.Errorf("writing message length: %w", err)
}

// write buffer to the connection as a single write (atomic)
_, err = buf.WriteTo(w)
_, err = c.conn.Write(packed)
if err != nil {
return fmt.Errorf("writing buffer to the connection: %w", err)
return fmt.Errorf("writing packed message: %w", err)
}

return nil
Expand Down Expand Up @@ -533,6 +548,19 @@ func (c *Connection) writeLoop() {
if req.replyCh == nil {
req.errCh <- nil
}

case dw := <-c.directWriteCh:
_, err = c.conn.Write(dw.data)
if err != nil {
c.handleError(fmt.Errorf("writing data: %w", err))
dw.errCh <- err

// we can't continue to write other messages or data when we failed to write
// one of them
break
}
dw.errCh <- nil

case <-idleTimeTimer.C:
// if no message was sent during idle time, we have to send ping message
if c.Opts.PingHandler != nil {
Expand Down
93 changes: 54 additions & 39 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,58 +178,73 @@ func TestClient_Write(t *testing.T) {
require.NoError(t, err)
defer server.Close()

var called atomic.Int32
t.Run("write into the connection", func(t *testing.T) {
var called atomic.Int32
inboundMessageHandler := func(c *connection.Connection, message *iso8583.Message) {
called.Add(1)

inboundMessageHandler := func(c *connection.Connection, message *iso8583.Message) {
called.Add(1)
mti, err := message.GetMTI()
require.NoError(t, err)
require.Equal(t, "0810", mti)
}

mti, err := message.GetMTI()
// we should be able to write any bytes to the server
c, err := connection.New(server.Addr, testSpec, readMessageLength, writeMessageLength, connection.InboundMessageHandler(inboundMessageHandler))
require.NoError(t, err)
require.Equal(t, "0810", mti)
}
defer c.Close()

// we should be able to write any bytes to the server
c, err := connection.New(server.Addr, testSpec, readMessageLength, writeMessageLength, connection.InboundMessageHandler(inboundMessageHandler))
require.NoError(t, err)
defer c.Close()
err = c.Connect()
require.NoError(t, err)

err = c.Connect()
require.NoError(t, err)
// let's create data to write to the server, we will prepare header and
// packed message

// network management message
message := iso8583.NewMessage(testSpec)
err = message.Marshal(baseFields{
MTI: field.NewStringValue("0800"),
TestCaseCode: field.NewStringValue(TestCaseReply),
STAN: field.NewStringValue(getSTAN()),
})
require.NoError(t, err)

// let's create data to write to the server, we will prepare header and
// packed message
packed, err := message.Pack()
require.NoError(t, err)

// network management message
message := iso8583.NewMessage(testSpec)
err = message.Marshal(baseFields{
MTI: field.NewStringValue("0800"),
TestCaseCode: field.NewStringValue(TestCaseReply),
STAN: field.NewStringValue(getSTAN()),
})
require.NoError(t, err)
// prepare header
header := &bytes.Buffer{}
_, err = writeMessageLength(header, len(packed))
require.NoError(t, err)

packed, err := message.Pack()
require.NoError(t, err)
// combine header and message
data := append(header.Bytes(), packed...)

// prepare header
header := &bytes.Buffer{}
_, err = writeMessageLength(header, len(packed))
require.NoError(t, err)
// write the data directly to the connection
n, err := c.Write(data)

require.NoError(t, err)
require.Equal(t, len(data), n)

// we should expect to get reply, but as we are not using Send method,
// the reply will be handled by InboundMessageHandler
require.Eventually(t, func() bool {
return called.Load() == 1
}, 100*time.Millisecond, 20*time.Millisecond, "inboundMessageHandler should be called")
})

// combine header and message
data := append(header.Bytes(), packed...)
t.Run("write into the closed connection ", func(t *testing.T) {
c, err := connection.New(server.Addr, testSpec, readMessageLength, writeMessageLength)
require.NoError(t, err)
defer c.Close()

// write the data directly to the connection
n, err := c.Write(data)
err = c.Connect()
require.NoError(t, err)

require.NoError(t, err)
require.Equal(t, len(data), n)
c.Close()

// we should expect to get reply, but as we are not using Send method,
// the reply will be handled by InboundMessageHandler
require.Eventually(t, func() bool {
return called.Load() == 1
}, 100*time.Millisecond, 20*time.Millisecond, "inboundMessageHandler should be called")
_, err = c.Write([]byte("hello"))
require.Error(t, err)
})
}

func TestClient_Send(t *testing.T) {
Expand Down

0 comments on commit b6312f0

Please sign in to comment.