From 535ebd91b0c1bb23e5f1301a304b3a0b45ca932b Mon Sep 17 00:00:00 2001 From: Mikhail Alpinskiy Date: Fri, 10 Jan 2025 18:44:58 +0300 Subject: [PATCH] Check CRC when forwarding RPC packets Dump last 16 packets on error --- .github/workflows/ci-go.yml | 1 + .vscode/launch.json | 17 ++- go.mod | 1 + go.sum | 2 + internal/aggregator/ingress_proxy.go | 66 +++++------ internal/vkgo/rpc/pingpong.go | 12 +- internal/vkgo/rpc/statshouse.go | 163 ++++++++++++++++++++------- internal/vkgo/rpc/statshouse_test.go | 163 +++++++++++++++++++++++++-- 8 files changed, 335 insertions(+), 90 deletions(-) diff --git a/.github/workflows/ci-go.yml b/.github/workflows/ci-go.yml index 364a00e50..d08cc1e69 100644 --- a/.github/workflows/ci-go.yml +++ b/.github/workflows/ci-go.yml @@ -17,6 +17,7 @@ jobs: with: go-version: 1.22.x cache: true + - run: apt-get install libpcap-dev - uses: golangci/golangci-lint-action@v6 with: version: v1.60 diff --git a/.vscode/launch.json b/.vscode/launch.json index 8c2ca8081..35ecbac2e 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -138,7 +138,7 @@ ], }, { - "name": "proxy_rpc_test", + "name": "proxy rpc test", "type": "go", "request": "launch", "mode": "auto", @@ -162,7 +162,7 @@ }, }, { - "name": "TestForwardPacket", + "name": "test forward packet", "type": "go", "request": "launch", "mode": "test", @@ -170,6 +170,19 @@ "args": [ "-test.run=TestForwardPacket", ], + }, + { + "name": "test play PCAP", + "type": "go", + "request": "launch", + "mode": "test", + "program": "${workspaceFolder}/internal/vkgo/rpc/statshouse_test.go", + "args": [ + "-test.run=TestPlayPcap", + ], + "env": { + "STATSHOUSE_TEST_PLAY_PCAP_FILE_PATH":"", + }, } ] } \ No newline at end of file diff --git a/go.mod b/go.mod index 9eb3f0bf7..89ef996d8 100644 --- a/go.mod +++ b/go.mod @@ -112,6 +112,7 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/flatbuffers v23.5.26+incompatible // indirect + github.com/google/gopacket v1.1.19 github.com/grafana/regexp v0.0.0-20220304095617-2e8d9baf4ac2 // indirect github.com/hashicorp/consul/api v1.12.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect diff --git a/go.sum b/go.sum index 89d83c692..b3d4150c0 100644 --- a/go.sum +++ b/go.sum @@ -287,6 +287,8 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= diff --git a/internal/aggregator/ingress_proxy.go b/internal/aggregator/ingress_proxy.go index e0d81e0b7..feffd2ff1 100644 --- a/internal/aggregator/ingress_proxy.go +++ b/internal/aggregator/ingress_proxy.go @@ -395,31 +395,30 @@ func (p *proxyConn) run() { defer p.group.Done() defer p.clientConn.Close() // handshake client - _, _, err := p.clientConn.HandshakeServer(p.serverKeys, p.serverOpts.TrustedSubnetGroups, false, p.startTime, rpc.DefaultPacketTimeout) - if err != nil { - p.logClientError("handshake", err) - return - } - // initialize "__rpc_request_size" tags + _, _, err := p.clientConn.HandshakeServer(p.serverKeys, p.serverOpts.TrustedSubnetGroups, true, p.startTime, rpc.DefaultPacketTimeout) cryptoKeyID := p.clientConn.KeyID() p.clientCryptoKeyID = int32(binary.BigEndian.Uint32(cryptoKeyID[:4])) p.clientProtocolVersion = int32(p.clientConn.ProtocolVersion()) + if err != nil { + p.logClientError("handshake", err, rpc.PacketHeaderCircularBuffer{}) + return + } // read first request to get shardReplica - var req proxyRequest + var firstReq proxyRequest for { - req, err = p.readRequest() + firstReq, err = p.readRequest() if err != nil { return } if p.ctx.Err() != nil { return // server shutdown } - if req.tip == rpcInvokeReqHeaderTLTag { + if firstReq.tip == rpcInvokeReqHeaderTLTag { break } - p.rareLog("Client skip #%d looking for invoke request, addr %v\n", req.tip, p.clientConn.RemoteAddr()) + log.Printf("Client skip #%d looking for invoke request, addr %v\n", firstReq.tip, p.clientConn.RemoteAddr()) } - shardReplica := req.shardReplica(p) + shardReplica := firstReq.shardReplica(p) upstreamAddr := p.agent.GetConfigResult.Addresses[shardReplica] p.rareLog("Connect shard replica %d, addr %v < %v\n", shardReplica, p.clientConn.LocalAddr(), p.clientConn.RemoteAddr()) defer p.rareLog("Disconnect shard replica %d, addr %v < %v\n", shardReplica, p.clientConn.LocalAddr(), p.clientConn.RemoteAddr()) @@ -427,25 +426,25 @@ func (p *proxyConn) run() { upstreamConn, err := net.DialTimeout("tcp", upstreamAddr, rpc.DefaultPacketTimeout) if err != nil { log.Printf("error connect, upstream addr %s: %v\n", upstreamAddr, err) - _ = req.WriteReponseAndFlush(p.clientConn, err) + _ = firstReq.WriteReponseAndFlush(p.clientConn, err) return } defer upstreamConn.Close() p.upstreamConn = rpc.NewPacketConn(upstreamConn, rpc.DefaultClientConnReadBufSize, rpc.DefaultClientConnWriteBufSize) err = p.upstreamConn.HandshakeClient(p.clientOpts.CryptoKey, p.clientOpts.TrustedSubnetGroups, false, p.uniqueStartTime.Dec(), 0, rpc.DefaultPacketTimeout, rpc.LatestProtocolVersion) if err != nil { - p.logUpstreamError("handshake", err) - _ = req.WriteReponseAndFlush(p.clientConn, err) + p.logUpstreamError("handshake", err, rpc.PacketHeaderCircularBuffer{}) + _ = firstReq.WriteReponseAndFlush(p.clientConn, err) return } // process first request - res := req.process(p) - if res.Error() != nil { + firstReqRes := firstReq.process(p) + if firstReqRes.Error() != nil { return } // serve var ctx = p.ctx - var gracefulShutdown bool + gracefulShutdown := firstReqRes.ClientWantsFin for { // two iterations at most, the latter is graceful shutdown var respLoop sync.WaitGroup var respLoopRes rpc.ForwardPacketsResult @@ -456,12 +455,15 @@ func (p *proxyConn) run() { }() reqLoopRes := p.requestLoop(ctx) respLoop.Wait() - if gracefulShutdown || res.ClientWantsFin || reqLoopRes.ClientWantsFin || respLoopRes.ServerWantsFin || reqLoopRes.Error() != nil || respLoopRes.Error() != nil { + if !gracefulShutdown { + gracefulShutdown = reqLoopRes.ClientWantsFin || respLoopRes.ServerWantsFin + } + if gracefulShutdown || reqLoopRes.Error() != nil || respLoopRes.Error() != nil { return // either graceful shutdown already attempted or error occurred } gracefulShutdown = true if err = p.clientConn.WritePacket(rpcServerWantsFinTLTag, nil, rpc.DefaultPacketTimeout); err != nil { - p.logClientError("write fin", err) + p.logClientError("write fin", err, rpc.PacketHeaderCircularBuffer{}) return } // no timeout for connection graceful shutdown (has server level shutdown timeout) @@ -494,10 +496,10 @@ func (p *proxyConn) responseLoop(ctx context.Context) rpc.ForwardPacketsResult { res := rpc.ForwardPackets(ctx, p.clientConn, p.upstreamConn) if err := res.Error(); err != nil { if res.ReadErr != nil { - p.logUpstreamError("read", res.ReadErr) + p.logUpstreamError("read", res.ReadErr, res.PacketHeaderCircularBuffer) } if res.WriteErr != nil { - p.logClientError("write", res.WriteErr) + p.logClientError("write", res.WriteErr, res.PacketHeaderCircularBuffer) } p.clientConn.ShutdownWrite() } @@ -507,7 +509,7 @@ func (p *proxyConn) responseLoop(ctx context.Context) rpc.ForwardPacketsResult { func (p *proxyConn) readRequest() (req proxyRequest, err error) { if req.tip, req.Request, err = p.clientConn.ReadPacket(p.reqBuf[:0], rpc.DefaultPacketTimeout); err != nil { - p.logClientError("read", err) + p.logClientError("read", err, rpc.PacketHeaderCircularBuffer{}) return proxyRequest{}, err } req.size = len(req.Request) @@ -519,7 +521,7 @@ func (p *proxyConn) readRequest() (req proxyRequest, err error) { p.reqBuf = req.Request // buffer reuse } if err = req.ParseInvokeReq(&p.serverOpts); err != nil { - p.logClientError("parse", err) + p.logClientError("parse", err, rpc.PacketHeaderCircularBuffer{}) return proxyRequest{}, err } requestTag := req.RequestTag() @@ -537,11 +539,11 @@ func (p *proxyConn) readRequest() (req proxyRequest, err error) { // pass return req, nil default: - p.logClientError("not supported request", err) + p.logClientError("not supported request", err, rpc.PacketHeaderCircularBuffer{}) return proxyRequest{}, rpc.ErrNoHandler } default: - p.logClientError("not supported packet", err) + p.logClientError("not supported packet", err, rpc.PacketHeaderCircularBuffer{}) return proxyRequest{}, rpc.ErrNoHandler } } @@ -563,7 +565,7 @@ func (p *proxyConn) reportRequestSize(req *proxyRequest) { p.agent.AddValueCounter(&key, float64(req.size), 1, format.BuiltinMetricMetaRPCRequests) } -func (p *proxyConn) logClientError(tag string, err error) { +func (p *proxyConn) logClientError(tag string, err error, lastPackets rpc.PacketHeaderCircularBuffer) { if err == nil || err == io.EOF { return } @@ -571,10 +573,10 @@ func (p *proxyConn) logClientError(tag string, err error) { if p.clientConn != nil { addr = p.clientConn.RemoteAddr() } - log.Printf("error %s, client addr %s, key 0x%X: %v\n", tag, addr, p.clientCryptoKeyID, err) + log.Printf("error %s, client addr %s, version %d, key 0x%X: %v, %s\n", tag, addr, p.clientProtocolVersion, p.clientCryptoKeyID, err, lastPackets.String()) } -func (p *proxyConn) logUpstreamError(tag string, err error) { +func (p *proxyConn) logUpstreamError(tag string, err error, lastPackets rpc.PacketHeaderCircularBuffer) { if err == nil || err == io.EOF { return } @@ -582,7 +584,7 @@ func (p *proxyConn) logUpstreamError(tag string, err error) { if p.upstreamConn != nil { addr = p.upstreamConn.RemoteAddr() } - log.Printf("error %s, upstream addr %s: %v\n", tag, addr, err) + log.Printf("error %s, upstream addr %s: %v, %s\n", tag, addr, err, lastPackets.String()) } func (req *proxyRequest) process(p *proxyConn) (res rpc.ForwardPacketsResult) { @@ -595,13 +597,13 @@ func (req *proxyRequest) process(p *proxyConn) (res rpc.ForwardPacketsResult) { if _, err = args.ReadBoxed(req.Request); err == nil { if args.Cluster != p.cluster { err = fmt.Errorf("statshouse misconfiguration! cluster requested %q does not match actual cluster connected %q", args.Cluster, p.cluster) - p.logClientError("GetConfig2", err) + p.logClientError("GetConfig2", err, rpc.PacketHeaderCircularBuffer{}) } else { req.Response, _ = args.WriteResult(req.Response[:0], p.config) } } if err = req.WriteReponseAndFlush(p.clientConn, err); err != nil { - p.logClientError("write", err) + p.logClientError("write", err, rpc.PacketHeaderCircularBuffer{}) // not an error ("requestLoop" exits on request read-write errors only) } default: @@ -621,7 +623,7 @@ func (req *proxyRequest) process(p *proxyConn) (res rpc.ForwardPacketsResult) { func (req *proxyRequest) forwardAndFlush(p *proxyConn) error { if err := req.ForwardAndFlush(p.upstreamConn, req.tip, rpc.DefaultPacketTimeout); err != nil { - p.logUpstreamError("write", err) + p.logUpstreamError("write", err, rpc.PacketHeaderCircularBuffer{}) return err } if cap(p.reqBuf) < cap(req.Request) { diff --git a/internal/vkgo/rpc/pingpong.go b/internal/vkgo/rpc/pingpong.go index 6cddc0f66..46e012c93 100644 --- a/internal/vkgo/rpc/pingpong.go +++ b/internal/vkgo/rpc/pingpong.go @@ -94,12 +94,12 @@ func (pc *PacketConn) onPong(body []byte) error { } pc.pingMu.Lock() defer pc.pingMu.Unlock() - if !pc.pingSent { - return fmt.Errorf("received unexpected pong %d without sending ping", pingID) - } - if pingID != pc.currentPingID { - return fmt.Errorf("received unexpected pong %d for ping %d", pingID, pc.currentPingID) - } + // if !pc.pingSent { + // return fmt.Errorf("received unexpected pong %d without sending ping", pingID) + // } + // if pingID != pc.currentPingID { + // return fmt.Errorf("received unexpected pong %d for ping %d", pingID, pc.currentPingID) + // } pc.pingSent = false return nil } diff --git a/internal/vkgo/rpc/statshouse.go b/internal/vkgo/rpc/statshouse.go index dad8d7947..55cc85370 100644 --- a/internal/vkgo/rpc/statshouse.go +++ b/internal/vkgo/rpc/statshouse.go @@ -4,6 +4,9 @@ import ( "context" "encoding/binary" "fmt" + "hash/crc32" + "io" + "strings" "time" "github.com/vkcom/statshouse/internal/vkgo/rpc/internal/gen/tl" @@ -17,11 +20,64 @@ type ReadWriteError struct { } type ForwardPacketsResult struct { + PacketHeaderCircularBuffer ReadWriteError ServerWantsFin bool ClientWantsFin bool } +type ForwardPacketResult struct { + packetHeader + ReadWriteError + ServerWantsFin bool + ClientWantsFin bool +} + +type forwardPacketOptions struct { + testEnv bool +} + +type PacketHeaderCircularBuffer struct { + s []packetHeader + x, n int +} + +func (p packetHeader) String() string { + return fmt.Sprintf(`{"len":%d,"seq":%d,"tip":0x%08X}`, p.length, p.seqNum, p.tip) +} + +func (b *PacketHeaderCircularBuffer) add(p packetHeader) { + if len(b.s) == 0 { + b.s = make([]packetHeader, 16) + } + b.s[b.x] = p + b.x = (b.x + 1) % len(b.s) + b.n++ +} + +func (b *PacketHeaderCircularBuffer) String() string { + if b.n == 0 { + return "[]" + } + var sb strings.Builder + sb.WriteString("[") + if b.n < len(b.s) { + sb.WriteString(b.s[0].String()) + for i := 1; i < b.n; i++ { + sb.WriteString(",") + sb.WriteString(b.s[i].String()) + } + } else { + sb.WriteString(b.s[b.x].String()) + for i := (b.x + 1) % len(b.s); i != b.x; i = (i + 1) % len(b.s) { + sb.WriteString(",") + sb.WriteString(b.s[i].String()) + } + } + sb.WriteString("]") + return sb.String() +} + func (r ReadWriteError) Error() error { if r.ReadErr != nil { return r.ReadErr @@ -33,92 +89,118 @@ func (pc *PacketConn) KeyID() [4]byte { return pc.keyID } -func ForwardPackets(ctx context.Context, dst, src *PacketConn) (res ForwardPacketsResult) { - for i := 0; ctx.Err() == nil && res.Error() == nil; i++ { - res = ForwardPacket(ctx, dst, src) +func ForwardPackets(ctx context.Context, dst, src *PacketConn) ForwardPacketsResult { + var res ForwardPacketResult + var buf PacketHeaderCircularBuffer + for { + res = ForwardPacket(dst, src, forwardPacketOptions{}) + buf.add(res.packetHeader) + if res.Error() != nil { + break + } + if ctx.Err() != nil { + break + } + } + return ForwardPacketsResult{ + PacketHeaderCircularBuffer: buf, + ReadWriteError: res.ReadWriteError, + ServerWantsFin: res.ServerWantsFin, + ClientWantsFin: res.ClientWantsFin, } - return res } -func ForwardPacket(ctx context.Context, dst, src *PacketConn) (res ForwardPacketsResult) { - var header packetHeader +func ForwardPacket(dst, src *PacketConn, opt forwardPacketOptions) (res ForwardPacketResult) { src.readMu.Lock() - _, isBuiltin, _, err := src.readPacketHeaderUnlocked(&header, DefaultPacketTimeout) + _, isBuiltin, _, err := src.readPacketHeaderUnlocked(&res.packetHeader, DefaultPacketTimeout) src.readMu.Unlock() if err != nil { res.ReadErr = err return res } + if opt.testEnv { + switch res.packetHeader.tip { + case packetTypeRPCNonce, packetTypeRPCHandshake: + src.table = crc32.IEEETable + default: + src.table = castagnoliTable + } + } if isBuiltin { - res.WriteErr = src.WritePacketBuiltin(time.Duration(0)) + res.WriteErr = src.WritePacketBuiltin(DefaultPacketTimeout) } else { - switch header.tip { + switch res.packetHeader.tip { case tl.RpcClientWantsFin{}.TLTag(): res.ClientWantsFin = true case tl.RpcServerWantsFin{}.TLTag(): res.ServerWantsFin = true } - res.ReadWriteError = forwardPacket(dst, src, header) + res.ReadWriteError = forwardPacket(dst, src, &res.packetHeader) } return res } var forwardPacketTrailer = [][]byte{{}, {0}, {0, 0}, {0, 0, 0}} -func forwardPacket(dst, src *PacketConn, header packetHeader) (res ReadWriteError) { +func forwardPacket(dst, src *PacketConn, header *packetHeader) (res ReadWriteError) { dst.writeMu.Lock() defer dst.writeMu.Unlock() // write header - srcBodySize := header.length - packetOverhead - dstBodySize := srcBodySize - // legacy RPC protocol used to align packet body length to 4 bytes boundary, including + bodySize := header.length - packetOverhead + // legacy RPC protocol used to align packet length to 4 bytes boundary, including // padding into checksum; packet length is guaranteed to be aligned by 4 bytes then var legacyWriteAlignTo4 uint32 if dst.protocolVersion == 0 { - legacyWriteAlignTo4 = -srcBodySize & 3 - dstBodySize += legacyWriteAlignTo4 + legacyWriteAlignTo4 = -bodySize & 3 + bodySize += legacyWriteAlignTo4 } - if err := dst.writePacketHeaderUnlocked(header.tip, int(dstBodySize), time.Duration(0)); err != nil { + if err := dst.writePacketHeaderUnlocked(header.tip, int(bodySize), DefaultPacketTimeout); err != nil { res.WriteErr = err return res } // write body - if res = copyBodySkipCRCAndCryptoPadding(dst, src, srcBodySize); res.Error() != nil { + if res = copyBodyCheckedSkipCryptoPadding(dst, src, header); res.Error() != nil { return res } if 0 < legacyWriteAlignTo4 && legacyWriteAlignTo4 < 4 { - // legacy RPC protocol, write body padding and CRC - trailer := forwardPacketTrailer[legacyWriteAlignTo4] - if _, err := dst.w.Write(trailer); err != nil { + if err := dst.writePacketBodyUnlocked(forwardPacketTrailer[legacyWriteAlignTo4]); err != nil { res.WriteErr = err return res } - dst.updateWriteCRC(trailer) - dst.headerWriteBuf = binary.LittleEndian.AppendUint32(dst.headerWriteBuf, dst.writeCRC) - } else { - // write CRC and padding - dst.writePacketTrailerUnlocked() } + // write CRC and padding + dst.writePacketTrailerUnlocked() res.WriteErr = dst.FlushUnlocked() return res } -func copyBodySkipCRCAndCryptoPadding(dst, src *PacketConn, bodySize uint32) (res ReadWriteError) { +func copyBodyCheckedSkipCryptoPadding(dst, src *PacketConn, header *packetHeader) (res ReadWriteError) { src.readMu.Lock() defer src.readMu.Unlock() // copy body - res = packetConnCopy(dst, src, int(bodySize)) + var crc uint32 + crc = crc32.Update(0, src.table, src.headerReadBuf[:12]) + crc, res = packetConnCopy(dst, src, int(header.length-packetOverhead), crc) if res.Error() != nil { return res } - // skip CRC - if err := src.r.discard(4); err != nil { + // read CRC + if _, err := io.ReadFull(src.r, src.headerReadBuf[:4]); err != nil { res.ReadErr = err return res } + readCRC := binary.LittleEndian.Uint32(src.headerReadBuf[:]) + // check CRC + if crc != readCRC { + res.ReadErr = &tagError{ + tag: "crc_mismatch", + err: fmt.Errorf("CRC mismatch: read 0x%x, expected 0x%x", readCRC, crc), + } + return res + } // skip crypto padding if src.w.isEncrypted() { - res.ReadErr = src.r.discard(int(-bodySize & 3)) + res.ReadErr = src.r.discard(int(-header.length & 3)) } return res } @@ -191,13 +273,13 @@ func (hctx *HandlerContext) writeReponseUnlocked(conn *PacketConn) error { return nil } -func packetConnCopy(dst, src *PacketConn, n int) ReadWriteError { - return cryptoCopy(dst.w, src.r, n, dst.updateWriteCRC) +func packetConnCopy(dst, src *PacketConn, n int, readCRC uint32) (uint32, ReadWriteError) { + return cryptoCopy(dst.w, src.r, n, readCRC, src.table, dst.updateWriteCRC) } -func cryptoCopy(dst *cryptoWriter, src *cryptoReader, n int, cb func([]byte)) (res ReadWriteError) { +func cryptoCopy(dst *cryptoWriter, src *cryptoReader, n int, readCRC uint32, table *crc32.Table, cb func([]byte)) (_ uint32, res ReadWriteError) { if n == 0 { - return ReadWriteError{} + return readCRC, ReadWriteError{} } for { if m := src.end - src.begin; m > 0 { @@ -207,7 +289,10 @@ func cryptoCopy(dst *cryptoWriter, src *cryptoReader, n int, cb func([]byte)) (r s := src.buf[src.begin : src.begin+m] m, res.WriteErr = dst.Write(s) if res.WriteErr != nil { - return res + return readCRC, res + } + if table != nil { + readCRC = crc32.Update(readCRC, table, s) } if cb != nil { cb(s) @@ -215,11 +300,11 @@ func cryptoCopy(dst *cryptoWriter, src *cryptoReader, n int, cb func([]byte)) (r src.begin += m n -= m if n == 0 { - return ReadWriteError{} + return readCRC, ReadWriteError{} } } if res.ReadErr != nil { - return res + return readCRC, res } buf := src.buf[:cap(src.buf)] bufSize := copy(buf, src.buf[src.end:]) @@ -237,7 +322,7 @@ func cryptoCopy(dst *cryptoWriter, src *cryptoReader, n int, cb func([]byte)) (r } if read <= 0 { // infinite loop guard res.ReadErr = errZeroRead - return res + return readCRC, res } } } diff --git a/internal/vkgo/rpc/statshouse_test.go b/internal/vkgo/rpc/statshouse_test.go index 3858689f5..283b9a4ec 100644 --- a/internal/vkgo/rpc/statshouse_test.go +++ b/internal/vkgo/rpc/statshouse_test.go @@ -2,11 +2,18 @@ package rpc import ( "bytes" - "context" + "encoding/binary" + "fmt" + "io" + "log" "net" + "os" "testing" "time" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/google/gopacket/pcap" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" "pgregory.net/rapid" @@ -53,7 +60,7 @@ func (c *cryptoPipelineMachine) ReadDiscard(t *rapid.T) { func (c *cryptoPipelineMachine) Copy(t *rapid.T) { n := rapid.IntRange(0, c.rb.Len()).Draw(t, "n") - rwe := cryptoCopy(c.w, c.r, n, nil) + _, rwe := cryptoCopy(c.w, c.r, n, 0, nil, nil) if rwe.Error() != nil { c.fatalf("copy failed: %v, %v", rwe.WriteErr, rwe.ReadErr) } @@ -86,7 +93,7 @@ func TestCryptoPipeline(t *testing.T) { fatalf: t.Fatalf, } t.Repeat(rapid.StateMachineActions(c)) - _ = cryptoCopy(c.w, c.r, c.rb.Len()+cap(c.r.buf), nil) + _, _ = cryptoCopy(c.w, c.r, c.rb.Len()+cap(c.r.buf), 0, nil, nil) _ = c.w.Flush() if !bytes.Equal(c.expected, c.actual.Bytes()) { c.fatalf("expected %q, actual %q", c.expected, c.actual.Bytes()) @@ -173,9 +180,9 @@ func (m *forwardPacketMachine) run(t *rapid.T) { minBodyLen = 4 } for i := 0; i < 512; i++ { - var group errgroup.Group - group.Go(func() error { - res := ForwardPacket(context.Background(), m.proxyClient, m.proxyServer) + var forward errgroup.Group + forward.Go(func() error { + res := ForwardPacket(m.proxyClient, m.proxyServer, forwardPacketOptions{testEnv: false}) return res.Error() }) sent := message{ @@ -185,14 +192,18 @@ func (m *forwardPacketMachine) run(t *rapid.T) { if legacyProtocol { sent.body = sent.body[:len(sent.body)-len(sent.body)%4] } - err := m.client.WritePacket(sent.tip, sent.body, 0) + err := m.client.WritePacket(sent.tip, sent.body, DefaultPacketTimeout) require.NoError(t, err) - m.client.Flush() + err = m.client.Flush() require.NoError(t, err) + var receive errgroup.Group var received message - received.tip, received.body, err = m.server.ReadPacket(nil, 0) - require.NoError(t, err) - require.NoError(t, group.Wait()) + receive.Go(func() (err error) { + received.tip, received.body, err = m.server.ReadPacket(nil, DefaultPacketTimeout) + return err + }) + require.NoError(t, forward.Wait()) + require.NoError(t, receive.Wait()) if m.protocolServer == 0 { writeAlignTo4 := int(-uint(len(sent.body)) & 3) sent.body = append(sent.body, forwardPacketTrailer[writeAlignTo4]...) @@ -209,3 +220,133 @@ func TestForwardPacket(t *testing.T) { shutdown() }) } + +type pcapEndpoint struct { + host string + port layers.TCPPort +} + +func (e pcapEndpoint) Network() string { + return "ip" +} + +func (e pcapEndpoint) String() string { + return fmt.Sprintf("%s:%d", e.host, e.port) +} + +type testConn struct { + localAddr net.Addr + remoteAddr net.Addr + buffer []byte + offset int +} + +func (c *testConn) Read(b []byte) (n int, err error) { + if c.offset == len(c.buffer) { + return 0, io.EOF + } + n = copy(b, c.buffer[c.offset:]) + c.offset += n + return n, nil +} + +func (c *testConn) Write(b []byte) (int, error) { + return len(b), nil // nop +} + +func (c *testConn) Close() error { + c.offset = len(c.buffer) + return nil +} + +func (c *testConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *testConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *testConn) SetDeadline(_ time.Time) error { + return nil +} + +func (c *testConn) SetReadDeadline(_ time.Time) error { + return nil +} + +func (c *testConn) SetWriteDeadline(_ time.Time) error { + return nil +} + +// NB! remove "received unexpected pong" assertion fot test to pass +func TestPlayPcap(t *testing.T) { + path := os.Getenv("STATSHOUSE_TEST_PLAY_PCAP_FILE_PATH") + if path == "" { + return + } + log.Println("PCAP play", path) + for k, v := range readPCAP(t, path, "") { + playPcap(t, k, v) + } +} + +func playPcap(t *testing.T, k [2]pcapEndpoint, v []byte) { + srcConn := &testConn{ + buffer: v, + localAddr: k[0], + remoteAddr: k[1], + } + src := &PacketConn{ + conn: srcConn, + timeoutAccuracy: DefaultConnTimeoutAccuracy, + r: newCryptoReader(srcConn, DefaultServerRequestBufSize), + w: newCryptoWriter(srcConn, DefaultServerResponseBufSize), + readSeqNum: int64(binary.LittleEndian.Uint32(v[4:])), + } + dstConn := &testConn{ + localAddr: k[0], + remoteAddr: k[1], + } + dst := &PacketConn{ + conn: dstConn, + timeoutAccuracy: DefaultConnTimeoutAccuracy, + r: newCryptoReader(dstConn, DefaultServerRequestBufSize), + w: newCryptoWriter(dstConn, DefaultServerResponseBufSize), + table: castagnoliTable, + } + var buf PacketHeaderCircularBuffer + for { + res := ForwardPacket(dst, src, forwardPacketOptions{testEnv: true}) + buf.add(res.packetHeader) + if res.Error() != nil { + require.ErrorIsf(t, res.ReadErr, io.EOF, "%v %s", k, buf.String()) + require.NoError(t, res.WriteErr) + break + } + } +} + +func readPCAP(t *testing.T, path string, dstHost string) map[[2]pcapEndpoint][]byte { + handle, err := pcap.OpenOffline(path) + require.NoError(t, err) + packetSource := gopacket.NewPacketSource(handle, handle.LinkType()) + m := map[[2]pcapEndpoint][]byte{} + for p := range packetSource.Packets() { + var src, dst pcapEndpoint + ip := p.Layer(layers.LayerTypeIPv4).(*layers.IPv4) + src.host = ip.SrcIP.String() + dst.host = ip.DstIP.String() + if dstHost != "" && dst.host != dstHost { + continue + } + tcp := p.Layer(layers.LayerTypeTCP).(*layers.TCP) + src.port = tcp.SrcPort + dst.port = tcp.DstPort + if appLayer := p.ApplicationLayer(); appLayer != nil { + k := [2]pcapEndpoint{src, dst} + m[k] = append(m[k], appLayer.Payload()...) + } + } + return m +}