diff --git a/pkg/agent/protocol/tcp/delay.go b/pkg/agent/protocol/tcp/delay.go new file mode 100644 index 00000000..214ef772 --- /dev/null +++ b/pkg/agent/protocol/tcp/delay.go @@ -0,0 +1,72 @@ +package tcp + +import ( + "fmt" + "io" + "math/rand" + "time" +) + +type DelayHandler struct { + delay time.Duration + variation float64 + + lastDelayed time.Time +} + +func DelayHandlerBuilder(delay time.Duration, variation float64) HandlerBuilder { + return func(ConnMeta) Handler { + return &DelayHandler{ + delay: delay, + variation: variation, + } + } +} + +func (d *DelayHandler) HandleUpward(client io.Reader, server io.Writer) error { + // Buffer of 2048 bytes for reading a TCP segment. 2048 is the smallest power of two able to hold the most common + // TCP MSS, 1500 - (20 + 20) = 1460 bytes. + buf := make([]byte, 2048) + + for { + // Wait until there is one byte available. + _, err := client.Read(buf[:1]) + if err != nil { + return fmt.Errorf("reading from downstream: %w", err) + } + + // After the first byte reception, introduce delay if we need to. + d.waitIfNeeded() + + // Read the rest of the TCP frame *up to* len(buf). + // Stream-based implementations of io.Reader return early if there is no more data + n, err := client.Read(buf[1:]) + if err != nil { + return fmt.Errorf("reading from downstream: %w", err) + } + + // Write the amount read plus the first byte. + _, err = server.Write(buf[:n+1]) + if err != nil { + return fmt.Errorf("writing to upstream: %w", err) + } + } +} + +func (d *DelayHandler) HandleDownward(server io.Reader, client io.Writer) error { + _, err := io.Copy(client, server) + // Copy dos not return EOF. + return fmt.Errorf("relaying data downstream: %w", err) +} + +func (d *DelayHandler) waitIfNeeded() { + if time.Since(d.lastDelayed) < time.Second { + return + } + + d.lastDelayed = time.Now() + plusMinus := rand.Float64()*d.variation*2 - d.variation + delay := time.Duration(float64(d.delay) * (1 + plusMinus)) + + time.Sleep(delay) +} diff --git a/pkg/agent/protocol/tcp/delay_test.go b/pkg/agent/protocol/tcp/delay_test.go new file mode 100644 index 00000000..92835467 --- /dev/null +++ b/pkg/agent/protocol/tcp/delay_test.go @@ -0,0 +1,104 @@ +package tcp_test + +import ( + "bufio" + "fmt" + "net" + "testing" + "time" + + "github.com/grafana/xk6-disruptor/pkg/agent/protocol/tcp" + "github.com/grafana/xk6-disruptor/pkg/agent/protocol/tcp/testutil" +) + +// Test_Proxy_Forwards tests the tcp.Proxy using tcp.ForwardHandler, ensuring messages are forwarded to and from the +// proxy. +func Test_Delay(t *testing.T) { + t.Parallel() + + const localv4 = "127.0.0.1:0" + upstreamL, err := net.Listen("tcp", localv4) + if err != nil { + t.Fatalf("creating upstream listener: %v", err) + } + + serverCh := make(chan string) + serverErr := make(chan error) + go func() { + serverErr <- testutil.EchoServer(upstreamL, serverCh) + }() + + proxyL, err := net.Listen("tcp", localv4) + if err != nil { + t.Fatalf("creating proxy listener: %v", err) + } + + proxy := tcp.NewProxy(proxyL, upstreamL.Addr(), tcp.DelayHandlerBuilder(500*time.Millisecond, 0)) + go func() { + err := proxy.Start() + if err != nil { + // t.Fatal cannot be used inside a goroutine. + t.Errorf("couldn't start poxy: %v", err) + } + }() + + proxyConn, err := net.Dial("tcp", proxyL.Addr().String()) + if err != nil { + t.Fatalf("dialing proxy address: %v", err) + } + + bufReader := bufio.NewReader(proxyConn) + + // Write a first line. + _, err = fmt.Fprintln(proxyConn, "a line") + if err != nil { + t.Fatalf("writing to proxy conn: %v", err) + } + + // + select { + case <-time.After(100 * time.Millisecond): + case <-serverCh: + t.Fatalf("upstream data way before the delay passed") + } + + // Check the server received the line + select { + case <-time.After(time.Second): + t.Fatalf("upstream did not receive the line before the deadline") + case serverLine := <-serverCh: + if serverLine != "a line\n" { + t.Fatalf("upstream received unexpected data %q", serverLine) + } + } + + // Check we received the echoed data + clientLine, err := bufReader.ReadString('\n') + if err != nil { + t.Fatalf("reading upstream response from proxyconn: %v", err) + } + if clientLine != "a line\n" { + t.Fatalf("downstream received unexpected data %q", clientLine) + } + + // Close the connection to the proxy. + _ = proxyConn.Close() + + select { + case <-time.After(time.Second): + t.Fatalf("upstream connection was not closed") + case line, ok := <-serverCh: + if ok { + t.Fatalf("upstream receive unexpected data: %q", line) + } + } + + select { + case <-time.After(time.Second): + t.Fatalf("server did not terminate") + case err = <-serverErr: + if err != nil { + t.Fatalf("server returned an error: %v", err) + } + } +} diff --git a/pkg/agent/protocol/tcp/handler.go b/pkg/agent/protocol/tcp/handler.go new file mode 100644 index 00000000..73c9e7d9 --- /dev/null +++ b/pkg/agent/protocol/tcp/handler.go @@ -0,0 +1,83 @@ +package tcp + +import ( + "errors" + "hash/crc64" + "io" + "net" + "time" +) + +// HandlerBuilder is a function that returns a Handler. Proxy will use this function to create a new handler for each +// TCP connection. ConnMeta provides information about the TCP connection this handler will handle. +// Handler implementations that require additional parameters can have a builder that returns a HandlerBuilder, such as +type HandlerBuilder func(ConnMeta) Handler + +// Handler is an object capable of acting when TCP messages are either sent or received. +type Handler interface { + // HandleUpward forwards data from the client to the server. Proxy will call each method exactly once for the + // single connection a Handler instance handles. Implementations should consume from client and write to server + // until an error occurs. + // When either HandleUpward or HandleDownward return an error, connections to both server and clients are closed. + // If ErrTerminate is returned, the connection is still closed but no error message is logged. + HandleUpward(client io.Reader, server io.Writer) error + // HandleDownward provides is the equivalent of HandleUpward for data sent from the server to the client. + HandleDownward(server io.Reader, client io.Writer) error +} + +// ErrTerminate may be returned by Handler implementations that wish to willingly terminate a connection. Connection +// will be closed, but no error log will be generated. +var ErrTerminate = errors.New("connection terminated by proxy handler") + +// ConnMeta holds metadata about a TCP connection. +type ConnMeta struct { + Opened time.Time + ClientAddress net.Addr + ServerAddress net.Addr +} + +// Hash returns a semi-unique number to every connection. +// The implementation of Hash is not guaranteed to be stable between updates of this package. +func (c ConnMeta) Hash() uint64 { + // We use CRC64 as this hash does not need to be cryptographically secure, and it's easy to get an uint64 from it. + hash := crc64.New(crc64.MakeTable(crc64.ISO)) + _, _ = hash.Write([]byte(c.Opened.String())) + _, _ = hash.Write([]byte(c.ClientAddress.String())) + _, _ = hash.Write([]byte(c.ServerAddress.String())) + + return hash.Sum64() +} + +// ForwardHandler is a handler that forwards data between client and server without taking any actions. +type ForwardHandler struct{} + +// ForwardHandlerBuilder returns a new instance of a ForwardHandler. +func ForwardHandlerBuilder(_ ConnMeta) Handler { + return ForwardHandler{} +} + +func (ForwardHandler) HandleUpward(client io.Reader, server io.Writer) error { + _, err := io.Copy(server, client) + return err +} + +func (ForwardHandler) HandleDownward(server io.Reader, client io.Writer) error { + _, err := io.Copy(client, server) + return err +} + +// RejectHandler is a handler that closes connections immediately after being opened. +type RejectHandler struct{} + +// RejectHandlerBuilder returns a new instance of a ForwardHandler. +func RejectHandlerBuilder(_ ConnMeta) Handler { + return RejectHandler{} +} + +func (RejectHandler) HandleUpward(client io.Reader, server io.Writer) error { + return ErrTerminate +} + +func (RejectHandler) HandleDownward(server io.Reader, client io.Writer) error { + return ErrTerminate +} diff --git a/pkg/agent/protocol/tcp/meta_tcpip_test.go b/pkg/agent/protocol/tcp/meta_tcpip_test.go new file mode 100644 index 00000000..2c3c17fc --- /dev/null +++ b/pkg/agent/protocol/tcp/meta_tcpip_test.go @@ -0,0 +1,69 @@ +package tcp + +import ( + "errors" + "io" + "net" + "testing" + "time" +) + +// Test_TCPConnShortReads is a meta-test that proves that the TCP/IP stack behaves as we expect it to, where Read() +// returns immediately for a TCP segment. +func Test_TCPConnShortReads(t *testing.T) { + server, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listening server: %v", err) + } + + client, err := net.Dial("tcp", server.Addr().String()) + if err != nil { + t.Fatalf("dialing client: %v", err) + } + + go func() { + twoBytes := make([]byte, 2) + _, cErr := client.Write(twoBytes) + if cErr != nil { + t.Errorf("writing twoBytes: %v", err) + } + + time.Sleep(200 * time.Millisecond) + + _, cErr = client.Write(twoBytes) + if cErr != nil { + t.Errorf("writing twoBytes: %v", err) + } + + _ = client.Close() + }() + + conn, err := server.Accept() + if err != nil { + t.Fatalf("accepting server conn: %v", err) + } + + buf := make([]byte, 4) + n, err := conn.Read(buf) + if err != nil { + t.Fatalf("reading from server conn: %v", err) + } + + if n != 2 { + t.Fatalf("expected to read 2 bytes") + } + + n, err = conn.Read(buf) + if err != nil { + t.Fatalf("reading from server conn: %v", err) + } + + if n != 2 { + t.Fatalf("expected to read 2 bytes") + } + + _, err = conn.Read(buf) + if !errors.Is(err, io.EOF) { + t.Fatalf("expected EOF: %v", err) + } +} diff --git a/pkg/agent/protocol/tcp/proxy.go b/pkg/agent/protocol/tcp/proxy.go new file mode 100644 index 00000000..db004f39 --- /dev/null +++ b/pkg/agent/protocol/tcp/proxy.go @@ -0,0 +1,92 @@ +package tcp + +import ( + "errors" + "fmt" + "log" + "net" + "time" +) + +// Proxy implements a TCP transparent proxy between a client and a server. +type Proxy struct { + l net.Listener + upstream net.Addr + handlerBuilder HandlerBuilder +} + +func NewProxy(l net.Listener, upstream net.Addr, handlerBuilder HandlerBuilder) *Proxy { + return &Proxy{ + l: l, + upstream: upstream, + handlerBuilder: handlerBuilder, + } +} + +func (p *Proxy) Start() error { + for { + conn, err := p.l.Accept() + if err != nil { + return err + } + + go func() { + err := p.handleConn(conn) + // TODO: Better error handling + log.Printf("handling connection: %v", err) + }() + } +} + +func (p *Proxy) Stop() error { + // TODO: Harvest open connections and close them. + return nil +} + +func (p *Proxy) handleConn(downstreamConn net.Conn) error { + defer func() { + _ = downstreamConn.Close() + }() + + upstreamConn, err := net.Dial("tcp", p.upstream.String()) + if err != nil { + return fmt.Errorf("opening upstream connection: %w", err) + } + + defer func() { + _ = upstreamConn.Close() + }() + + metadata := ConnMeta{ + Opened: time.Now(), + ClientAddress: downstreamConn.RemoteAddr(), + ServerAddress: upstreamConn.RemoteAddr(), + } + + handler := p.handlerBuilder(metadata) + + errChan := make(chan error, 2) + go func() { + errChan <- func() error { + err := handler.HandleUpward(downstreamConn, upstreamConn) + if err != nil && !errors.Is(err, ErrTerminate) { + return err + } + + return nil + }() + }() + go func() { + errChan <- func() error { + err := handler.HandleDownward(upstreamConn, downstreamConn) + if err != nil && !errors.Is(err, ErrTerminate) { + return err + } + + return nil + }() + }() + + err = <-errChan + return fmt.Errorf("forwarding data: %w", err) +} diff --git a/pkg/agent/protocol/tcp/proxy_test.go b/pkg/agent/protocol/tcp/proxy_test.go new file mode 100644 index 00000000..a30030c3 --- /dev/null +++ b/pkg/agent/protocol/tcp/proxy_test.go @@ -0,0 +1,189 @@ +package tcp_test + +import ( + "bufio" + "fmt" + "net" + "testing" + "time" + + "github.com/grafana/xk6-disruptor/pkg/agent/protocol/tcp" + "github.com/grafana/xk6-disruptor/pkg/agent/protocol/tcp/testutil" +) + +// Test_Proxy_Forwards tests the tcp.Proxy using tcp.ForwardHandler, ensuring messages are forwarded to and from the +// proxy. +func Test_Proxy_Forwards(t *testing.T) { + t.Parallel() + + const localv4 = "127.0.0.1:0" + upstreamL, err := net.Listen("tcp", localv4) + if err != nil { + t.Fatalf("creating upstream listener: %v", err) + } + + serverCh := make(chan string) + serverErr := make(chan error) + go func() { + serverErr <- testutil.EchoServer(upstreamL, serverCh) + }() + + proxyL, err := net.Listen("tcp", localv4) + if err != nil { + t.Fatalf("creating proxy listener: %v", err) + } + + proxy := tcp.NewProxy(proxyL, upstreamL.Addr(), tcp.ForwardHandlerBuilder) + go func() { + err := proxy.Start() + if err != nil { + // t.Fatal cannot be used inside a goroutine. + t.Errorf("couldn't start poxy: %v", err) + } + }() + + proxyConn, err := net.Dial("tcp", proxyL.Addr().String()) + if err != nil { + t.Fatalf("dialing proxy address: %v", err) + } + + bufReader := bufio.NewReader(proxyConn) + + // Write a first line. + _, err = fmt.Fprintln(proxyConn, "a line") + if err != nil { + t.Fatalf("writing to proxy conn: %v", err) + } + + // Check the server received the line + select { + case <-time.After(time.Second): + t.Fatalf("upstream did not receive the line before the deadline") + case serverLine := <-serverCh: + if serverLine != "a line\n" { + t.Fatalf("upstream received unexpected data %q", serverLine) + } + } + + // Check we received the echoed data + clientLine, err := bufReader.ReadString('\n') + if err != nil { + t.Fatalf("reading upstream response from proxyconn: %v", err) + } + if clientLine != "a line\n" { + t.Fatalf("downstream received unexpected data %q", clientLine) + } + + // Write a second line. + _, err = fmt.Fprintln(proxyConn, "another line") + if err != nil { + t.Fatalf("writing to proxy conn: %v", err) + } + + // Check the server received the line + select { + case <-time.After(time.Second): + t.Fatalf("upstream did not receive the line before the deadline") + case serverLine := <-serverCh: + if serverLine != "another line\n" { + t.Fatalf("upstream received unexpected data %q", serverLine) + } + } + + // Check we received the echoed data + clientLine, err = bufReader.ReadString('\n') + if err != nil { + t.Fatalf("reading upstream response from proxyconn: %v", err) + } + if clientLine != "another line\n" { + t.Fatalf("downstream received unexpected data %q", clientLine) + } + + // Close the connection to the proxy. + _ = proxyConn.Close() + + select { + case <-time.After(time.Second): + t.Fatalf("upstream connection was not closed") + case line, ok := <-serverCh: + if ok { + t.Fatalf("upstream receive unexpected data: %q", line) + } + } + + select { + case <-time.After(time.Second): + t.Fatalf("server did not terminate") + case err = <-serverErr: + if err != nil { + t.Fatalf("server returned an error: %v", err) + } + } +} + +// Test_Proxy_Forwards tests the tcp.Proxy using tcp.RejectHandler, ensuring both client and server connections are +// closed properly and cleanly when handlers return errors. +func Test_Proxy_Rejects(t *testing.T) { + t.Parallel() + + const localv4 = "127.0.0.1:0" + upstreamL, err := net.Listen("tcp", localv4) + if err != nil { + t.Fatalf("creating upstream listener: %v", err) + } + + serverCh := make(chan string) + serverErr := make(chan error) + go func() { + serverErr <- testutil.EchoServer(upstreamL, serverCh) + }() + + proxyL, err := net.Listen("tcp", localv4) + if err != nil { + t.Fatalf("creating proxy listener: %v", err) + } + + proxy := tcp.NewProxy(proxyL, upstreamL.Addr(), tcp.RejectHandlerBuilder) + go func() { + err := proxy.Start() + if err != nil { + // t.Fatal cannot be used inside a goroutine. + t.Errorf("couldn't start poxy: %v", err) + } + }() + + proxyConn, err := net.Dial("tcp", proxyL.Addr().String()) + if err != nil { + t.Fatalf("dialing proxy address: %v", err) + } + + // Attempt to write a first line. + _, err = fmt.Fprintln(proxyConn, "a line") + if err != nil { + t.Fatalf("error writing data: %v", err) + } + + singleByte := make([]byte, 1) + _, err = proxyConn.Read(singleByte) + if err == nil { + t.Fatalf("expected connection to be closed by rejectHandler: %v", err) + } + + select { + case <-time.After(time.Second): + t.Fatalf("upstream connection was not closed") + case line, ok := <-serverCh: + if ok { + t.Fatalf("upstream receive unexpected data: %q", line) + } + } + + select { + case <-time.After(time.Second): + t.Fatalf("server did not terminate") + case err = <-serverErr: + if err != nil { + t.Fatalf("server returned an error: %v", err) + } + } +} diff --git a/pkg/agent/protocol/tcp/testutil/echoserver.go b/pkg/agent/protocol/tcp/testutil/echoserver.go new file mode 100644 index 00000000..462a396a --- /dev/null +++ b/pkg/agent/protocol/tcp/testutil/echoserver.go @@ -0,0 +1,44 @@ +package testutil + +import ( + "bufio" + "errors" + "fmt" + "io" + "net" + "time" +) + +// EchoServer is a helper function for testing that accepts a single connection from the given listener, and pushes +// each received line to lineCh. When the connection is closed, it also closes lineCh. +func EchoServer(l net.Listener, lineCh chan string) error { + defer close(lineCh) + + conn, err := l.Accept() + if err != nil { + return fmt.Errorf("accepting conn: %w", err) + } + + reader := bufio.NewReader(conn) + for { + line, err := reader.ReadString('\n') + if errors.Is(err, io.EOF) { + return nil + } + if err != nil { + return fmt.Errorf("reading from conn: %w", err) + } + + _, err = conn.Write([]byte(line)) + if err != nil { + return fmt.Errorf("echoing back to conn: %w", err) + } + + select { + case lineCh <- line: + continue + case <-time.After(time.Second): + return fmt.Errorf("reader did not consume line %q", line) + } + } +}