Skip to content

Commit

Permalink
Added RPC tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alpinskiy committed Dec 25, 2024
1 parent 757f598 commit 9f8116d
Show file tree
Hide file tree
Showing 2 changed files with 291 additions and 20 deletions.
92 changes: 72 additions & 20 deletions internal/vkgo/rpc/statshouse.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,30 @@ func (pc *PacketConn) KeyID() [4]byte {

func ForwardPackets(ctx context.Context, dst, src *PacketConn) (res ForwardPacketsResult) {
for i := 0; ctx.Err() == nil && res.Error() == nil; i++ {
var header packetHeader
src.readMu.Lock()
_, isBuiltin, _, err := src.readPacketHeaderUnlocked(&header, DefaultPacketTimeout)
src.readMu.Unlock()
if err != nil {
res.ReadErr = err
return res
}
if isBuiltin {
res.WriteErr = src.WritePacketBuiltin(time.Duration(0))
} else {
switch header.tip {
case tl.RpcClientWantsFin{}.TLTag():
res.ClientWantsFin = true
case tl.RpcServerWantsFin{}.TLTag():
res.ServerWantsFin = true
}
res.ReadWriteError = forwardPacket(dst, src, header)
res = ForwardPacket(ctx, dst, src)
}
return res
}

func ForwardPacket(ctx context.Context, dst, src *PacketConn) (res ForwardPacketsResult) {
var header packetHeader
src.readMu.Lock()
_, isBuiltin, _, err := src.readPacketHeaderUnlocked(&header, DefaultPacketTimeout)
src.readMu.Unlock()
if err != nil {
res.ReadErr = err
return res
}
if isBuiltin {
res.WriteErr = src.WritePacketBuiltin(time.Duration(0))
} else {
switch header.tip {
case tl.RpcClientWantsFin{}.TLTag():
res.ClientWantsFin = true
case tl.RpcServerWantsFin{}.TLTag():
res.ServerWantsFin = true
}
res.ReadWriteError = forwardPacket(dst, src, header)
}
return res
}
Expand Down Expand Up @@ -81,7 +86,7 @@ func copyBody(dst, src *PacketConn, headerLen uint32) (res ReadWriteError) {
defer src.readMu.Unlock()
// copy body
bodySize := int(headerLen) - packetOverhead
res = cryptoCopy(dst, src, bodySize)
res = packetConnCopy(dst, src, bodySize)
if res.Error() != nil {
return res
}
Expand Down Expand Up @@ -165,7 +170,7 @@ func (hctx *HandlerContext) writeReponseUnlocked(conn *PacketConn) error {
return nil
}

func cryptoCopy(dst, src *PacketConn, n int) (res ReadWriteError) {
func packetConnCopy(dst, src *PacketConn, n int) (res ReadWriteError) {
if n == 0 {
return ReadWriteError{}
}
Expand Down Expand Up @@ -210,6 +215,53 @@ func cryptoCopy(dst, src *PacketConn, n int) (res ReadWriteError) {
}
}

func cryptoCopy(dst *cryptoWriter, src *cryptoReader, n int, cb func([]byte)) (res ReadWriteError) {
if n == 0 {
return ReadWriteError{}
}
for {
if m := src.end - src.begin; m > 0 {
if m > n {
m = n
}
s := src.buf[src.begin : src.begin+m]
m, res.WriteErr = dst.w.Write(s)
if res.WriteErr != nil {
return res
}
if cb != nil {
cb(s)
}
src.begin += m
n -= m
if n == 0 {
return ReadWriteError{}
}
}
if res.ReadErr != nil {
return res
}
buf := src.buf[:cap(src.buf)]
bufSize := copy(buf, src.buf[src.end:])
var read int
read, res.ReadErr = src.r.Read(buf[bufSize:])
bufSize += read
src.buf = buf[:bufSize]
src.begin = 0
if src.enc != nil {
decrypt := roundDownPow2(bufSize, src.blockSize)
src.enc.CryptBlocks(buf[:decrypt], buf[:decrypt])
src.end = decrypt
} else {
src.end = bufSize
}
if read <= 0 { // infinite loop guard
res.ReadErr = errZeroRead
return res
}
}
}

func (src *cryptoReader) discard(n int) error {
if n == 0 {
return nil
Expand Down
219 changes: 219 additions & 0 deletions internal/vkgo/rpc/statshouse_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
package rpc

import (
"bytes"
"context"
"net"
"testing"
"time"

"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"pgregory.net/rapid"
)

type cryptoPipelineMachine struct {
r *cryptoReader
w *cryptoWriter
rb *bytes.Buffer
actual *bytes.Buffer
expected []byte
offset int
fatalf func(format string, args ...any)
}

func (c *cryptoPipelineMachine) Write(t *rapid.T) {
s := rapid.SliceOf(rapid.Byte()).Draw(t, "slice")
c.expected = append(c.expected, s...)
for len(s) != 0 {
n, err := c.rb.Write(s)
if err != nil {
c.fatalf("write failed: %v", err)
}
s = s[n:]
}
}

func (c *cryptoPipelineMachine) Discard(t *rapid.T) {
n := rapid.IntRange(0, c.rb.Len()).Draw(t, "n")
if err := c.r.discard(n); err != nil {
c.fatalf("discard failed: %v", err)
}
c.expected = append(c.expected[:c.offset], c.expected[c.offset+n:]...)
}

func (c *cryptoPipelineMachine) ReadDiscard(t *rapid.T) {
n := rapid.IntRange(0, c.rb.Len()).Draw(t, "n")
m, err := c.r.Read(make([]byte, n))
if err != nil {
c.fatalf("Read failed: %v", err)
}
c.expected = append(c.expected[:c.offset], c.expected[c.offset+m:]...)
}

func (c *cryptoPipelineMachine) Copy(t *rapid.T) {
n := rapid.IntRange(0, c.rb.Len()).Draw(t, "n")
rwe := cryptoCopy(c.w, c.r, n, nil)
if rwe.Error() != nil {
c.fatalf("copy failed: %v, %v", rwe.WriteErr, rwe.ReadErr)
}
if err := c.w.Flush(); err != nil {
c.fatalf("flush failed: %v", err)
}
c.offset += n
}

func (c *cryptoPipelineMachine) Check(_ *rapid.T) {
if len(c.expected) < c.actual.Len() {
c.fatalf("expected %v bytes, actual %v bytes", len(c.expected), c.actual.Len())
}
expected := c.expected[:c.actual.Len()]
actual := c.actual.Bytes()
if !bytes.Equal(expected, actual) {
c.fatalf("expected %q, actual %q", expected, actual)
}
}

func TestCryptoPipeline(t *testing.T) {
rapid.Check(t, func(t *rapid.T) {
rb := &bytes.Buffer{}
actual := &bytes.Buffer{}
c := &cryptoPipelineMachine{
rb: rb,
actual: actual,
r: newCryptoReader(rb, rapid.IntRange(0, 1024).Draw(t, "read_buffer_size")),
w: newCryptoWriter(actual, rapid.IntRange(0, 1024).Draw(t, "write_buffer_size")),
fatalf: t.Fatalf,
}
t.Repeat(rapid.StateMachineActions(c))
_ = cryptoCopy(c.w, c.r, c.rb.Len()+cap(c.r.buf), nil)
_ = c.w.Flush()
if !bytes.Equal(c.expected, c.actual.Bytes()) {
c.fatalf("expected %q, actual %q", c.expected, c.actual.Bytes())
}
})
}

type forwardPacketMachine struct {
client *PacketConn
server *PacketConn
proxyClient *PacketConn
proxyServer *PacketConn
protocolVersion uint32
}

func newForwardPacketsMachine(t *rapid.T) (res forwardPacketMachine) {
var (
network = "tcp"
proxyAddr = "127.0.0.1:7000"
upstreamAddr = "127.0.0.1:7001"
startTime = uint32(time.Now().Unix())
forceEncryption = rapid.Bool().Draw(t, "forceEncryption")
protocolVersion = func() uint32 {
if rapid.Bool().Draw(t, "protocolVersion") {
return LatestProtocolVersion
} else {
return DefaultProtocolVersion
}
}()
cryptoKey = rapid.StringN(1, 1024, 1024).Draw(t, "cryptoKey")
)
var (
listen = func(addr string) net.Listener {
res, err := Listen(network, addr, false)
require.NoError(t, err)
return res
}
accept = func(ln net.Listener) (*PacketConn, error) {
defer ln.Close()
conn, err := ln.Accept()
if err != nil {
return nil, err
}
res := NewPacketConn(conn, DefaultServerRequestBufSize, DefaultServerResponseBufSize)
_, _, err = res.HandshakeServer([]string{cryptoKey}, nil, forceEncryption, startTime, 0)
if err != nil {
res.Close()
return nil, err
}
return res, nil
}
connect = func(addr string) (*PacketConn, error) {
conn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
res := NewPacketConn(conn, DefaultServerRequestBufSize, DefaultServerResponseBufSize)
err = res.HandshakeClient(cryptoKey, nil, forceEncryption, startTime, 0, 0, protocolVersion)
if err != nil {
res.Close()
return nil, err
}
return res, nil
}
)
var group errgroup.Group
upstream := listen(upstreamAddr)
group.Go(func() (err error) {
res.server, err = accept(upstream)
return err
})
proxy := listen(proxyAddr)
group.Go(func() (err error) {
if res.proxyServer, err = accept(proxy); err == nil {
res.proxyClient, err = connect(upstreamAddr)
}
return err
})
var err error
res.client, err = connect(proxyAddr)
require.NoError(t, err)
require.NoError(t, group.Wait())
res.protocolVersion = protocolVersion
return res
}

func (m *forwardPacketMachine) close() {
m.client.Close()
m.proxyClient.Close()
m.server.Close()
m.proxyServer.Close()
}

func TestForwardPacket(t *testing.T) {
type message struct {
tip uint32
body []byte
}
rapid.Check(t, func(t *rapid.T) {
m := newForwardPacketsMachine(t)
defer m.close()
minBodyLen := 1
if m.protocolVersion == DefaultProtocolVersion {
minBodyLen = 4
}
for i := 0; i < 512; i++ {
var group errgroup.Group
group.Go(func() error {
res := ForwardPacket(context.Background(), m.proxyClient, m.proxyServer)
return res.Error()
})
sent := message{
tip: rapid.Uint32().Draw(t, "tip"),
body: rapid.SliceOfN(rapid.Byte(), minBodyLen, 1024).Draw(t, "body"),
}
if m.protocolVersion == DefaultProtocolVersion {
sent.body = sent.body[:len(sent.body)-len(sent.body)%4]
}
err := m.client.WritePacket(sent.tip, sent.body, 0)
require.NoError(t, err)
m.client.Flush()
require.NoError(t, err)
require.NoError(t, group.Wait())
var received message
received.tip, received.body, err = m.server.ReadPacket(nil, 0)
require.NoError(t, err)
require.Equal(t, sent, received)
}
})
}

0 comments on commit 9f8116d

Please sign in to comment.