From ee0181b69e4683611f68126d519ecc80f09d9fd5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20Reis?=
Date: Thu, 20 Oct 2022 16:06:37 +0100
Subject: [PATCH] Fix flaky tests and fix #69 (#67)

Fix #69

Fix more flaky tests
---
 CHANGELOG/CHANGELOG-2.0.md                    |  1 +
 integration-tests/asyncreads_test.go          |  9 ++--
 integration-tests/connect_test.go             |  4 +-
 integration-tests/metrics_test.go             | 52 +++++++++++++------
 integration-tests/prepared_statements_test.go | 12 ++++-
 integration-tests/shutdown_test.go            | 47 +++++++++++------
 integration-tests/tls_test.go                 | 37 +++++++++++--
 proxy/pkg/zdmproxy/clientconn.go              | 19 ++++---
 proxy/pkg/zdmproxy/clienthandler.go           |  5 +-
 9 files changed, 130 insertions(+), 56 deletions(-)

diff --git a/CHANGELOG/CHANGELOG-2.0.md b/CHANGELOG/CHANGELOG-2.0.md
index edc0b842..f2695d35 100644
--- a/CHANGELOG/CHANGELOG-2.0.md
+++ b/CHANGELOG/CHANGELOG-2.0.md
@@ -9,6 +9,7 @@ When cutting a new release, update the `unreleased` heading to the tag being gen
 ### Bug Fixes
 
 * [#48](https://github.com/datastax/zdm-proxy/issues/48) Fix scheduler shutdown race condition
+* [#69](https://github.com/datastax/zdm-proxy/issues/69) Client connection can be closed before proxy returns protocol error
 
 ## v2.0.0 - 2022-10-17
 
diff --git a/integration-tests/asyncreads_test.go b/integration-tests/asyncreads_test.go
index d62ed3fc..3510d0b2 100644
--- a/integration-tests/asyncreads_test.go
+++ b/integration-tests/asyncreads_test.go
@@ -299,13 +299,12 @@ func TestAsyncReadsRequestTypes(t *testing.T) {
     require.Nil(t, err)
     defer proxy.Shutdown()
 
-    client := client.NewCqlClient("127.0.0.1:14002", nil)
-    cqlClientConn, err := client.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 0)
-    require.Nil(t, err)
-    defer cqlClientConn.Close()
-
     for _, tt := range tests {
         t.Run(tt.name, func(t *testing.T) {
+            client := client.NewCqlClient("127.0.0.1:14002", nil)
+            cqlClientConn, err := client.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 0)
+            require.Nil(t, err)
+            defer cqlClientConn.Close()
             err = testSetup.Origin.DeleteLogs()
             require.Nil(t, err)
             err = testSetup.Target.DeleteLogs()
diff --git a/integration-tests/connect_test.go b/integration-tests/connect_test.go
index 716a5b6e..b7df5f71 100644
--- a/integration-tests/connect_test.go
+++ b/integration-tests/connect_test.go
@@ -106,9 +106,9 @@ func TestRequestedProtocolVersionUnsupportedByProxy(t *testing.T) {
 
     oldLevel := log.GetLevel()
     oldZeroLogLevel := zerolog.GlobalLevel()
-    log.SetLevel(log.WarnLevel)
+    log.SetLevel(log.TraceLevel)
     defer log.SetLevel(oldLevel)
-    zerolog.SetGlobalLevel(zerolog.WarnLevel)
+    zerolog.SetGlobalLevel(zerolog.TraceLevel)
     defer zerolog.SetGlobalLevel(oldZeroLogLevel)
 
     cfg := setup.NewTestConfig("127.0.1.1", "127.0.1.2")
diff --git a/integration-tests/metrics_test.go b/integration-tests/metrics_test.go
index 6d9ba632..30f99ff8 100644
--- a/integration-tests/metrics_test.go
+++ b/integration-tests/metrics_test.go
@@ -180,6 +180,17 @@ func testMetrics(t *testing.T, metricsHandler *httpzdmproxy.HandlerWithFallback)
     }
 }
 
+func requireEventuallyContainsLine(t *testing.T, lines []string, line string) {
+    utils.RequireWithRetries(t,
+        func() (err error, fatal bool) {
+            if containsLine(lines, line) {
+                return nil, false
+            }
+            return fmt.Errorf("%v does not contain %v", lines, line), false
+        },
+        25, 200)
+}
+
 func containsLine(lines []string, line string) bool {
     for i := 0; i < len(lines); i++ {
         if lines[i] == line {
@@ -305,7 +316,7 @@ func checkMetrics(
 
     if checkNodeMetrics {
         if asyncEnabled {
-            require.Contains(t, lines, fmt.Sprintf("%v %v", getPrometheusNameWithNodeLabel(prefix, metrics.InFlightRequestsAsync, asyncHost), 0))
+            requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v %v", getPrometheusNameWithNodeLabel(prefix, metrics.InFlightRequestsAsync, asyncHost), 0))
         } else {
             require.NotContains(t, lines, fmt.Sprintf("%v", getPrometheusName(prefix, metrics.InFlightRequestsAsync)))
         }
@@ -314,16 +325,16 @@ func checkMetrics(
 
     require.Contains(t, lines, fmt.Sprintf("%v{node=\"%v\"} %v", getPrometheusName(prefix, metrics.OpenTargetConnections), targetHost, openTargetConns))
     if asyncEnabled {
-        require.Contains(t, lines, fmt.Sprintf("%v{node=\"%v\"} %v", getPrometheusName(prefix, metrics.OpenAsyncConnections), asyncHost, openAsyncConns))
-        require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncReadTimeouts, asyncHost)))
-        require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncWriteTimeouts, asyncHost)))
-        require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncOtherErrors, asyncHost)))
-        require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncClientTimeouts, asyncHost)))
-        require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncUnavailableErrors, asyncHost)))
-        require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncReadFailures, asyncHost)))
-        require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncWriteFailures, asyncHost)))
-        require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncOverloadedErrors, asyncHost)))
-        require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncUnpreparedErrors, asyncHost)))
+        requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v{node=\"%v\"} %v", getPrometheusName(prefix, metrics.OpenAsyncConnections), asyncHost, openAsyncConns))
+        requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncReadTimeouts, asyncHost)))
+        requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncWriteTimeouts, asyncHost)))
+        requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncOtherErrors, asyncHost)))
+        requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncClientTimeouts, asyncHost)))
+        requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncUnavailableErrors, asyncHost)))
+        requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncReadFailures, asyncHost)))
+        requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncWriteFailures, asyncHost)))
+        requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncOverloadedErrors, asyncHost)))
+        requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncUnpreparedErrors, asyncHost)))
     } else {
         require.NotContains(t, lines, fmt.Sprintf("%v", getPrometheusName(prefix, metrics.OpenAsyncConnections)))
         require.NotContains(t, lines, fmt.Sprintf("%v", getPrometheusName(prefix, metrics.AsyncReadTimeouts)))
@@ -379,20 +390,29 @@ func checkMetrics(
 
     if successAsync == 0 || !asyncEnabled {
         if asyncEnabled {
-            require.Contains(t, lines, fmt.Sprintf("%v{node=\"%v\"} 0", getPrometheusNameWithSuffix(prefix, metrics.AsyncRequestDuration, "sum"), asyncHost))
+            requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v{node=\"%v\"} 0", getPrometheusNameWithSuffix(prefix, metrics.AsyncRequestDuration, "sum"), asyncHost))
         } else {
             require.NotContains(t, lines, fmt.Sprintf("%v", getPrometheusNameWithSuffix(prefix, metrics.AsyncRequestDuration, "sum")))
         }
     } else {
-        value, err := findMetricValue(lines, fmt.Sprintf("%v{node=\"%v\"} ", getPrometheusNameWithSuffix(prefix, metrics.AsyncRequestDuration, "sum"), asyncHost))
-        require.Nil(t, err)
-        require.Greater(t, value, 0.0)
+        utils.RequireWithRetries(t, func() (err error, fatal bool) {
+            prefix := fmt.Sprintf("%v{node=\"%v\"} ", getPrometheusNameWithSuffix(prefix, metrics.AsyncRequestDuration, "sum"), asyncHost)
+            value, err := findMetricValue(lines, prefix)
+            if err != nil {
+                return err, false
+            }
+            if value <= 0.0 {
+                return fmt.Errorf("%v expected greater than 0.0 but was %v", prefix, value), false
+            }
+
+            return nil, false
+        }, 25, 200)
     }
 
     require.Contains(t, lines, fmt.Sprintf("%v{node=\"%v\"} %v", getPrometheusNameWithSuffix(prefix, metrics.TargetRequestDuration, "count"), targetHost, successTarget+successBoth))
     require.Contains(t, lines, fmt.Sprintf("%v{node=\"%v\"} %v", getPrometheusNameWithSuffix(prefix, metrics.OriginRequestDuration, "count"), originHost, successOrigin+successBoth))
     if asyncEnabled {
-        require.Contains(t, lines, fmt.Sprintf("%v{node=\"%v\"} %v", getPrometheusNameWithSuffix(prefix, metrics.AsyncRequestDuration, "count"), asyncHost, successAsync))
+        requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v{node=\"%v\"} %v", getPrometheusNameWithSuffix(prefix, metrics.AsyncRequestDuration, "count"), asyncHost, successAsync))
     } else {
         require.NotContains(t, lines, fmt.Sprintf("%v", getPrometheusNameWithSuffix(prefix, metrics.OriginRequestDuration, "count")))
     }
diff --git a/integration-tests/prepared_statements_test.go b/integration-tests/prepared_statements_test.go
index a8523473..3775cc12 100644
--- a/integration-tests/prepared_statements_test.go
+++ b/integration-tests/prepared_statements_test.go
@@ -418,6 +418,7 @@ func TestPreparedIdReplacement(t *testing.T) {
     }
 
     expectedTargetPrepares := 1
+    expectedMaxTargetPrepares := 1
     expectedTargetExecutes := 0
     expectedTargetBatches := 0
     expectedOriginPrepares := 1
@@ -427,11 +428,15 @@
         expectedTargetExecutes += 1
     }
     if dualReadsEnabled {
+        // depending on goroutine scheduling, async cluster connector might receive an UNPREPARED and send a PREPARE on its own or not
+        // so with async reads we will assert greater or equal instead of equal
         expectedTargetPrepares += 1
+        expectedMaxTargetPrepares += 2
     }
     if test.batchQuery != "" {
         expectedTargetBatches += 1
         expectedTargetPrepares += 1
+        expectedMaxTargetPrepares += 1
         expectedOriginBatches += 1
         expectedOriginPrepares += 1
     }
@@ -439,8 +444,11 @@
     utils.RequireWithRetries(t, func() (err error, fatal bool) {
         targetLock.Lock()
         defer targetLock.Unlock()
-        if expectedTargetPrepares != len(targetPrepareMessages) {
-            return fmt.Errorf("expectedTargetPrepares %v != %v", expectedTargetPrepares, len(targetPrepareMessages)), false
+        if len(targetPrepareMessages) < expectedTargetPrepares {
+            return fmt.Errorf("expectedTargetPrepares %v < %v", len(targetPrepareMessages), expectedTargetPrepares), false
+        }
+        if len(targetPrepareMessages) > expectedMaxTargetPrepares {
+            return fmt.Errorf("expectedMaxTargetPrepares %v > %v", len(targetPrepareMessages), expectedMaxTargetPrepares), false
         }
         if expectedTargetExecutes != len(targetExecuteMessages) {
             return fmt.Errorf("expectedTargetExecutes %v != %v", expectedTargetExecutes, len(targetExecuteMessages)), false
diff --git a/integration-tests/shutdown_test.go b/integration-tests/shutdown_test.go
index 6b2db410..b6b44cda 100644
--- a/integration-tests/shutdown_test.go
+++ b/integration-tests/shutdown_test.go
@@ -67,6 +67,9 @@ func TestShutdownInFlightRequests(t *testing.T) {
     testSetup.Origin.Prime(
         simulacron.WhenQuery("SELECT * FROM test2", simulacron.NewWhenQueryOptions()).
             ThenSuccess().WithDelay(3 * time.Second))
+    testSetup.Origin.Prime(
+        simulacron.WhenQuery("SELECT * FROM test3", simulacron.NewWhenQueryOptions()).
+            ThenSuccess().WithDelay(4 * time.Second))
 
     queryMsg1 := &message.Query{
         Query:   "SELECT * FROM test1",
@@ -78,21 +81,24 @@ func TestShutdownInFlightRequests(t *testing.T) {
         Options: nil,
     }
 
+    queryMsg3 := &message.Query{
+        Query:   "SELECT * FROM test3",
+        Options: nil,
+    }
+
     beginTimestamp := time.Now()
     reqFrame := frame.NewFrame(primitive.ProtocolVersion4, 2, queryMsg1)
     inflightRequest, err := cqlConn.Send(reqFrame)
-
-    if err != nil {
-        t.Fatalf("unexpected error: %v", err)
-    }
+    require.Nil(t, err)
 
     reqFrame2 := frame.NewFrame(primitive.ProtocolVersion4, 3, queryMsg2)
     inflightRequest2, err := cqlConn.Send(reqFrame2)
+    require.Nil(t, err)
 
-    if err != nil {
-        t.Fatalf("unexpected error: %v", err)
-    }
+    reqFrame3 := frame.NewFrame(primitive.ProtocolVersion4, 4, queryMsg3)
+    inflightRequest3, err := cqlConn.Send(reqFrame3)
+    require.Nil(t, err)
 
     time.Sleep(1 * time.Second)
 
@@ -119,15 +125,12 @@ func TestShutdownInFlightRequests(t *testing.T) {
     default:
     }
 
-    reqFrame3 := frame.NewFrame(primitive.ProtocolVersion4, 4, queryMsg1)
-    inflightRequest3, err := cqlConn.Send(reqFrame3)
-
-    if err != nil {
-        t.Fatalf("unexpected error: %v", err)
-    }
+    reqFrame4 := frame.NewFrame(primitive.ProtocolVersion4, 5, queryMsg1)
+    inflightRequest4, err := cqlConn.Send(reqFrame4)
+    require.Nil(t, err)
 
     select {
-    case rsp := <-inflightRequest3.Incoming():
+    case rsp := <-inflightRequest4.Incoming():
         require.Equal(t, primitive.OpCodeError, rsp.Header.OpCode)
         _, ok := rsp.Body.Message.(*message.Overloaded)
         require.True(t, ok)
@@ -136,14 +139,24 @@ func TestShutdownInFlightRequests(t *testing.T) {
     }
 
     select {
-    case rsp := <-inflightRequest2.Incoming():
+    case rsp, ok := <-inflightRequest2.Incoming():
+        require.True(t, ok)
         require.Equal(t, primitive.OpCodeResult, rsp.Header.OpCode)
     case <-time.After(15 * time.Second):
         t.Fatalf("test timed out after 15 seconds")
     }
 
-    // 2 seconds instead of 3 just in case there is a time precision issue
-    require.GreaterOrEqual(t, time.Now().Sub(beginTimestamp).Nanoseconds(), (2 * time.Second).Nanoseconds())
+    select {
+    case rsp, ok := <-inflightRequest3.Incoming():
+        if ok { // ignore if last request's channel is closed before we read from it
+            require.Equal(t, primitive.OpCodeResult, rsp.Header.OpCode)
+        }
+    case <-time.After(15 * time.Second):
+        t.Fatalf("test timed out after 15 seconds")
+    }
+
+    // 3 seconds instead of 4 just in case there is a time precision issue
+    require.GreaterOrEqual(t, time.Now().Sub(beginTimestamp).Nanoseconds(), (3 * time.Second).Nanoseconds())
 
     select {
     case <-shutdownComplete:
diff --git a/integration-tests/tls_test.go b/integration-tests/tls_test.go
index ecacdaa8..6e656489 100644
--- a/integration-tests/tls_test.go
+++ b/integration-tests/tls_test.go
@@ -12,6 +12,7 @@ import (
     "github.com/datastax/zdm-proxy/integration-tests/env"
     "github.com/datastax/zdm-proxy/integration-tests/setup"
     "github.com/datastax/zdm-proxy/proxy/pkg/config"
+    zerologger "github.com/rs/zerolog/log"
     log "github.com/sirupsen/logrus"
     "github.com/stretchr/testify/require"
     "io/ioutil"
@@ -410,7 +411,7 @@ func TestTls_OneWayOrigin_MutualTarget(t *testing.T) {
             serverName:          "",
             errExpected:         true,
             errWarningsExpected: []string{"tls: client didn't provide a certificate"},
-            errMsgExpected:      "",
+            errMsgExpected:      "remote error: tls: bad certificate",
         },
         {
             name: "Proxy: Mutual TLS and SNI on Client, Mutual TLS on Listener, One-way TLS on Origin, mutual TLS on Target",
@@ -1056,16 +1057,36 @@ func testProxyClientTls(t *testing.T, ccmSetup *setup.CcmTestSetup,
     logMessages := buffer.String()
 
-    for _, errWarnExpected := range proxyTlsConfig.errWarningsExpected {
-        require.True(t, strings.Contains(logMessages, errWarnExpected), "%v not found", errWarnExpected)
+    warningAssertionFailed := false
+    warningExpected := false
+    for _, expectedWarningMsg := range proxyTlsConfig.errWarningsExpected {
+        warningExpected = true
+        if !strings.Contains(logMessages, expectedWarningMsg) {
+            t.Logf("%v not found in %v", expectedWarningMsg, logMessages)
+            warningAssertionFailed = true
+        }
     }
 
     if proxyTlsConfig.errExpected {
         require.NotNil(t, err, "Did not get expected error %s", proxyTlsConfig.errMsgExpected)
+        errorAssertionFailed := false
+        errorExpected := false
         if proxyTlsConfig.errMsgExpected != "" {
-            require.True(t, strings.Contains(err.Error(), proxyTlsConfig.errMsgExpected), err.Error())
+            errorExpected = true
+            if !strings.Contains(err.Error(), proxyTlsConfig.errMsgExpected) {
+                errorAssertionFailed = true
+                t.Logf("%v not found in %v", err.Error(), proxyTlsConfig.errMsgExpected)
+            }
+        }
+        if errorExpected && warningExpected {
+            require.False(t, errorAssertionFailed && warningAssertionFailed) // only 1 check needs to pass in this scenario
+        } else if errorExpected {
+            require.False(t, errorAssertionFailed)
+        } else if warningExpected {
+            require.False(t, warningAssertionFailed)
         }
     } else {
+        require.False(t, warningAssertionFailed)
         require.Nil(t, err, "testClient setup failed: %v", err)
         // create schema on clusters through the proxy
         sendRequest(cqlConn, "CREATE KEYSPACE IF NOT EXISTS testks "+
@@ -1340,6 +1361,14 @@ func getClientSideVerifyConnectionCallback(rootCAs *x509.CertPool) func(cs tls.C
             }
         }
         _, err := cs.PeerCertificates[0].Verify(opts)
+        if err != nil {
+            // use zerolog to avoid interacting with proxy's log messages that are used for assertions in the test
+            zerologger.Warn().
+                Interface("cs", cs).
+                Interface("verifyopts", opts).
+                Interface("peercertificates", cs.PeerCertificates).
+                Msgf("client side verify callback error: %v", err)
+        }
         return err
     }
 }
diff --git a/proxy/pkg/zdmproxy/clientconn.go b/proxy/pkg/zdmproxy/clientconn.go
index b5039643..37797ce8 100644
--- a/proxy/pkg/zdmproxy/clientconn.go
+++ b/proxy/pkg/zdmproxy/clientconn.go
@@ -37,6 +37,7 @@ type ClientConnector struct {
     clientHandlerContext    context.Context
     clientHandlerCancelFunc context.CancelFunc
 
+    // not used atm but should be used when a protocol error occurs after #68 has been addressed
     clientHandlerShutdownRequestCancelFn context.CancelFunc
 
     writeCoalescer *writeCoalescer
@@ -165,6 +166,7 @@ func (cc *ClientConnector) listenForRequests() {
     bufferedReader := bufio.NewReaderSize(cc.connection, cc.conf.RequestWriteBufferSizeBytes)
     connectionAddr := cc.connection.RemoteAddr().String()
     protocolErrOccurred := false
+    var alreadySentProtocolErr *frame.RawFrame
 
     for cc.clientHandlerContext.Err() == nil {
         f, err := readRawFrame(bufferedReader, connectionAddr, cc.clientHandlerContext)
@@ -174,14 +176,15 @@
                 err, cc.clientHandlerContext, cc.clientHandlerCancelFunc, ClientConnectorLogPrefix, "reading", connectionAddr)
             break
         } else if protocolErrResponseFrame != nil {
-            f = protocolErrResponseFrame
-            if !protocolErrOccurred {
-                protocolErrOccurred = true
-                cc.sendResponseToClient(protocolErrResponseFrame)
-                cc.clientHandlerShutdownRequestCancelFn()
-                setDrainModeNowFunc()
-                continue
-            }
+            alreadySentProtocolErr = protocolErrResponseFrame
+            protocolErrOccurred = true
+            cc.sendResponseToClient(protocolErrResponseFrame)
+            continue
+        } else if alreadySentProtocolErr != nil {
+            clonedProtocolErr := alreadySentProtocolErr.Clone()
+            clonedProtocolErr.Header.StreamId = f.Header.StreamId
+            cc.sendResponseToClient(clonedProtocolErr)
+            continue
         }
 
         wg.Add(1)
diff --git a/proxy/pkg/zdmproxy/clienthandler.go b/proxy/pkg/zdmproxy/clienthandler.go
index 85e9a17f..162bb235 100644
--- a/proxy/pkg/zdmproxy/clienthandler.go
+++ b/proxy/pkg/zdmproxy/clienthandler.go
@@ -114,8 +114,10 @@ type ClientHandler struct {
     parameterModifier *ParameterModifier
     timeUuidGenerator TimeUuidGenerator
 
+    // not used atm but should be used when a protocol error occurs after #68 has been addressed
     clientHandlerShutdownRequestCancelFn context.CancelFunc
-    clientHandlerShutdownRequestContext  context.Context
+
+    clientHandlerShutdownRequestContext context.Context
 }
 
 func NewClientHandler(
@@ -643,7 +645,6 @@ func (ch *ClientHandler) tryProcessProtocolError(response *Response, protocolErr
                 errMsg, response.connectorType)
         }
         ch.clientConnector.sendResponseToClient(response.responseFrame)
-        ch.clientHandlerShutdownRequestCancelFn()
     }
     return true
 }