Fix flaky tests and fix #69 (#67)
Fix #69 
Fix more flaky tests
joao-r-reis authored Oct 20, 2022
1 parent 62a8407 commit ee0181b
Showing 9 changed files with 130 additions and 56 deletions.
1 change: 1 addition & 0 deletions CHANGELOG/CHANGELOG-2.0.md
@@ -9,6 +9,7 @@ When cutting a new release, update the `unreleased` heading to the tag being gen
### Bug Fixes

* [#48](https://github.com/datastax/zdm-proxy/issues/48) Fix scheduler shutdown race condition
+ * [#69](https://github.com/datastax/zdm-proxy/issues/69) Client connection can be closed before proxy returns protocol error

## v2.0.0 - 2022-10-17

9 changes: 4 additions & 5 deletions integration-tests/asyncreads_test.go
@@ -299,13 +299,12 @@ func TestAsyncReadsRequestTypes(t *testing.T) {
require.Nil(t, err)
defer proxy.Shutdown()

- client := client.NewCqlClient("127.0.0.1:14002", nil)
- cqlClientConn, err := client.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 0)
- require.Nil(t, err)
- defer cqlClientConn.Close()

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
+ client := client.NewCqlClient("127.0.0.1:14002", nil)
+ cqlClientConn, err := client.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 0)
+ require.Nil(t, err)
+ defer cqlClientConn.Close()
err = testSetup.Origin.DeleteLogs()
require.Nil(t, err)
err = testSetup.Target.DeleteLogs()
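
Note on the fix: the CQL connection is now created inside each subtest instead of being shared across all of them, so whatever one case leaves behind on the wire (a failed request, an abrupt close) can no longer affect the cases that run after it. The pattern in isolation — a minimal sketch reusing the client package, proxy address and protocol version already used in this test:

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            // Fresh connection per subtest: one case's failure cannot
            // poison the connection state seen by the next case.
            cqlClient := client.NewCqlClient("127.0.0.1:14002", nil)
            cqlConn, err := cqlClient.ConnectAndInit(context.Background(), primitive.ProtocolVersion4, 0)
            require.Nil(t, err)
            defer cqlConn.Close()
            // ... subtest body ...
        })
    }
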
4 changes: 2 additions & 2 deletions integration-tests/connect_test.go
@@ -106,9 +106,9 @@ func TestRequestedProtocolVersionUnsupportedByProxy(t *testing.T) {

oldLevel := log.GetLevel()
oldZeroLogLevel := zerolog.GlobalLevel()
- log.SetLevel(log.WarnLevel)
+ log.SetLevel(log.TraceLevel)
defer log.SetLevel(oldLevel)
- zerolog.SetGlobalLevel(zerolog.WarnLevel)
+ zerolog.SetGlobalLevel(zerolog.TraceLevel)
defer zerolog.SetGlobalLevel(oldZeroLogLevel)

cfg := setup.NewTestConfig("127.0.1.1", "127.0.1.2")
52 changes: 36 additions & 16 deletions integration-tests/metrics_test.go
@@ -180,6 +180,17 @@ func testMetrics(t *testing.T, metricsHandler *httpzdmproxy.HandlerWithFallback)
}
}

+ func requireEventuallyContainsLine(t *testing.T, lines []string, line string) {
+ utils.RequireWithRetries(t,
+ func() (err error, fatal bool) {
+ if containsLine(lines, line) {
+ return nil, false
+ }
+ return fmt.Errorf("%v does not contain %v", lines, line), false
+ },
+ 25, 200)
+ }

func containsLine(lines []string, line string) bool {
for i := 0; i < len(lines); i++ {
if lines[i] == line {
@@ -305,7 +316,7 @@ func checkMetrics(

if checkNodeMetrics {
if asyncEnabled {
- require.Contains(t, lines, fmt.Sprintf("%v %v", getPrometheusNameWithNodeLabel(prefix, metrics.InFlightRequestsAsync, asyncHost), 0))
+ requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v %v", getPrometheusNameWithNodeLabel(prefix, metrics.InFlightRequestsAsync, asyncHost), 0))
} else {
require.NotContains(t, lines, fmt.Sprintf("%v", getPrometheusName(prefix, metrics.InFlightRequestsAsync)))
}
@@ -314,16 +325,16 @@
require.Contains(t, lines, fmt.Sprintf("%v{node=\"%v\"} %v", getPrometheusName(prefix, metrics.OpenTargetConnections), targetHost, openTargetConns))

if asyncEnabled {
- require.Contains(t, lines, fmt.Sprintf("%v{node=\"%v\"} %v", getPrometheusName(prefix, metrics.OpenAsyncConnections), asyncHost, openAsyncConns))
- require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncReadTimeouts, asyncHost)))
- require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncWriteTimeouts, asyncHost)))
- require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncOtherErrors, asyncHost)))
- require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncClientTimeouts, asyncHost)))
- require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncUnavailableErrors, asyncHost)))
- require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncReadFailures, asyncHost)))
- require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncWriteFailures, asyncHost)))
- require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncOverloadedErrors, asyncHost)))
- require.Contains(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncUnpreparedErrors, asyncHost)))
+ requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v{node=\"%v\"} %v", getPrometheusName(prefix, metrics.OpenAsyncConnections), asyncHost, openAsyncConns))
+ requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncReadTimeouts, asyncHost)))
+ requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncWriteTimeouts, asyncHost)))
+ requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncOtherErrors, asyncHost)))
+ requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncClientTimeouts, asyncHost)))
+ requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncUnavailableErrors, asyncHost)))
+ requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncReadFailures, asyncHost)))
+ requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncWriteFailures, asyncHost)))
+ requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncOverloadedErrors, asyncHost)))
+ requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v 0", getPrometheusNameWithNodeLabel(prefix, metrics.AsyncUnpreparedErrors, asyncHost)))
} else {
require.NotContains(t, lines, fmt.Sprintf("%v", getPrometheusName(prefix, metrics.OpenAsyncConnections)))
require.NotContains(t, lines, fmt.Sprintf("%v", getPrometheusName(prefix, metrics.AsyncReadTimeouts)))
@@ -379,20 +390,29 @@

if successAsync == 0 || !asyncEnabled {
if asyncEnabled {
- require.Contains(t, lines, fmt.Sprintf("%v{node=\"%v\"} 0", getPrometheusNameWithSuffix(prefix, metrics.AsyncRequestDuration, "sum"), asyncHost))
+ requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v{node=\"%v\"} 0", getPrometheusNameWithSuffix(prefix, metrics.AsyncRequestDuration, "sum"), asyncHost))
} else {
require.NotContains(t, lines, fmt.Sprintf("%v", getPrometheusNameWithSuffix(prefix, metrics.AsyncRequestDuration, "sum")))
}
} else {
- value, err := findMetricValue(lines, fmt.Sprintf("%v{node=\"%v\"} ", getPrometheusNameWithSuffix(prefix, metrics.AsyncRequestDuration, "sum"), asyncHost))
- require.Nil(t, err)
- require.Greater(t, value, 0.0)
+ utils.RequireWithRetries(t, func() (err error, fatal bool) {
+ prefix := fmt.Sprintf("%v{node=\"%v\"} ", getPrometheusNameWithSuffix(prefix, metrics.AsyncRequestDuration, "sum"), asyncHost)
+ value, err := findMetricValue(lines, prefix)
+ if err != nil {
+ return err, false
+ }
+ if value <= 0.0 {
+ return fmt.Errorf("%v expected greater than 0.0 but was %v", prefix, value), false
+ }

+ return nil, false
+ }, 25, 200)
}

require.Contains(t, lines, fmt.Sprintf("%v{node=\"%v\"} %v", getPrometheusNameWithSuffix(prefix, metrics.TargetRequestDuration, "count"), targetHost, successTarget+successBoth))
require.Contains(t, lines, fmt.Sprintf("%v{node=\"%v\"} %v", getPrometheusNameWithSuffix(prefix, metrics.OriginRequestDuration, "count"), originHost, successOrigin+successBoth))
if asyncEnabled {
- require.Contains(t, lines, fmt.Sprintf("%v{node=\"%v\"} %v", getPrometheusNameWithSuffix(prefix, metrics.AsyncRequestDuration, "count"), asyncHost, successAsync))
+ requireEventuallyContainsLine(t, lines, fmt.Sprintf("%v{node=\"%v\"} %v", getPrometheusNameWithSuffix(prefix, metrics.AsyncRequestDuration, "count"), asyncHost, successAsync))
} else {
require.NotContains(t, lines, fmt.Sprintf("%v", getPrometheusNameWithSuffix(prefix, metrics.OriginRequestDuration, "count")))
}
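
Note on the fix: the assertions that moved from require.Contains to requireEventuallyContainsLine all cover async-connector metrics, which are updated on a separate goroutine and can lag behind the scrape; polling — apparently 25 attempts at 200 ms intervals, so roughly 5 s — removes that race. utils.RequireWithRetries itself is not part of this diff; a plausible shape for it, assuming the (attempts, delay-in-milliseconds) signature used above — an illustrative sketch, not the proxy's actual helper:

    func RequireWithRetries(t *testing.T, fn func() (err error, fatal bool), maxAttempts int, delayPerAttemptMs int) {
        var err error
        var fatal bool
        for i := 0; i < maxAttempts; i++ {
            err, fatal = fn()
            if err == nil {
                return // condition met
            }
            if fatal {
                break // the callback signals that retrying cannot help
            }
            time.Sleep(time.Duration(delayPerAttemptMs) * time.Millisecond)
        }
        require.FailNow(t, "condition not met after retries", "%v", err)
    }
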
12 changes: 10 additions & 2 deletions integration-tests/prepared_statements_test.go
@@ -418,6 +418,7 @@ func TestPreparedIdReplacement(t *testing.T) {
}

expectedTargetPrepares := 1
+ expectedMaxTargetPrepares := 1
expectedTargetExecutes := 0
expectedTargetBatches := 0
expectedOriginPrepares := 1
@@ -427,20 +428,27 @@
expectedTargetExecutes += 1
}
if dualReadsEnabled {
+ // depending on goroutine scheduling, async cluster connector might receive an UNPREPARED and send a PREPARE on its own or not
+ // so with async reads we will assert greater or equal instead of equal
expectedTargetPrepares += 1
+ expectedMaxTargetPrepares += 2
}
if test.batchQuery != "" {
expectedTargetBatches += 1
expectedTargetPrepares += 1
+ expectedMaxTargetPrepares += 1
expectedOriginBatches += 1
expectedOriginPrepares += 1
}

utils.RequireWithRetries(t, func() (err error, fatal bool) {
targetLock.Lock()
defer targetLock.Unlock()
- if expectedTargetPrepares != len(targetPrepareMessages) {
- return fmt.Errorf("expectedTargetPrepares %v != %v", expectedTargetPrepares, len(targetPrepareMessages)), false
+ if len(targetPrepareMessages) < expectedTargetPrepares {
+ return fmt.Errorf("expectedTargetPrepares %v < %v", len(targetPrepareMessages), expectedTargetPrepares), false
}
+ if len(targetPrepareMessages) > expectedMaxTargetPrepares {
+ return fmt.Errorf("expectedMaxTargetPrepares %v > %v", len(targetPrepareMessages), expectedMaxTargetPrepares), false
+ }
if expectedTargetExecutes != len(targetExecuteMessages) {
return fmt.Errorf("expectedTargetExecutes %v != %v", expectedTargetExecutes, len(targetExecuteMessages)), false
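
Note on the fix: exact prepare counts are inherently racy with async reads, as the new comment explains — depending on goroutine scheduling, the async cluster connector may or may not receive an UNPREPARED response and issue a PREPARE of its own. The assertion therefore becomes a range check. Worked example with dual reads and a batch query both enabled: the minimum is 1 (initial prepare) + 1 (dual reads) + 1 (batch) = 3 prepares, while the maximum is 1 + 2 + 1 = 4, so the retry loop accepts any observed count in [3, 4] and keeps failing outside those bounds until the retries are exhausted.
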
47 changes: 30 additions & 17 deletions integration-tests/shutdown_test.go
@@ -67,6 +67,9 @@ func TestShutdownInFlightRequests(t *testing.T) {
testSetup.Origin.Prime(
simulacron.WhenQuery("SELECT * FROM test2", simulacron.NewWhenQueryOptions()).
ThenSuccess().WithDelay(3 * time.Second))
+ testSetup.Origin.Prime(
+ simulacron.WhenQuery("SELECT * FROM test3", simulacron.NewWhenQueryOptions()).
+ ThenSuccess().WithDelay(4 * time.Second))

queryMsg1 := &message.Query{
Query: "SELECT * FROM test1",
@@ -78,21 +81,24 @@
Options: nil,
}

+ queryMsg3 := &message.Query{
+ Query: "SELECT * FROM test3",
+ Options: nil,
+ }

beginTimestamp := time.Now()

reqFrame := frame.NewFrame(primitive.ProtocolVersion4, 2, queryMsg1)
inflightRequest, err := cqlConn.Send(reqFrame)

- if err != nil {
- t.Fatalf("unexpected error: %v", err)
- }
+ require.Nil(t, err)

reqFrame2 := frame.NewFrame(primitive.ProtocolVersion4, 3, queryMsg2)
inflightRequest2, err := cqlConn.Send(reqFrame2)
+ require.Nil(t, err)

- if err != nil {
- t.Fatalf("unexpected error: %v", err)
- }
+ reqFrame3 := frame.NewFrame(primitive.ProtocolVersion4, 4, queryMsg3)
+ inflightRequest3, err := cqlConn.Send(reqFrame3)
+ require.Nil(t, err)

time.Sleep(1 * time.Second)

@@ -119,15 +125,12 @@
default:
}

- reqFrame3 := frame.NewFrame(primitive.ProtocolVersion4, 4, queryMsg1)
- inflightRequest3, err := cqlConn.Send(reqFrame3)

- if err != nil {
- t.Fatalf("unexpected error: %v", err)
- }
+ reqFrame4 := frame.NewFrame(primitive.ProtocolVersion4, 5, queryMsg1)
+ inflightRequest4, err := cqlConn.Send(reqFrame4)
+ require.Nil(t, err)

select {
- case rsp := <-inflightRequest3.Incoming():
+ case rsp := <-inflightRequest4.Incoming():
require.Equal(t, primitive.OpCodeError, rsp.Header.OpCode)
_, ok := rsp.Body.Message.(*message.Overloaded)
require.True(t, ok)
@@ -136,14 +139,24 @@
}

select {
- case rsp := <-inflightRequest2.Incoming():
+ case rsp, ok := <-inflightRequest2.Incoming():
+ require.True(t, ok)
require.Equal(t, primitive.OpCodeResult, rsp.Header.OpCode)
case <-time.After(15 * time.Second):
t.Fatalf("test timed out after 15 seconds")
}

- // 2 seconds instead of 3 just in case there is a time precision issue
- require.GreaterOrEqual(t, time.Now().Sub(beginTimestamp).Nanoseconds(), (2 * time.Second).Nanoseconds())
+ select {
+ case rsp, ok := <-inflightRequest3.Incoming():
+ if ok { // ignore if last request's channel is closed before we read from it
+ require.Equal(t, primitive.OpCodeResult, rsp.Header.OpCode)
+ }
+ case <-time.After(15 * time.Second):
+ t.Fatalf("test timed out after 15 seconds")
+ }

+ // 3 seconds instead of 4 just in case there is a time precision issue
+ require.GreaterOrEqual(t, time.Now().Sub(beginTimestamp).Nanoseconds(), (3 * time.Second).Nanoseconds())

select {
case <-shutdownComplete:
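
Note on the fix: the test now pins down both halves of the drain behavior — requests already in flight (the 3 s and 4 s queries) complete normally, with the elapsed-time check proving the proxy actually waited for the slowest of them, while a request sent after shutdown begins is rejected with an Overloaded error. A minimal sketch of that gate, as an assumption about the behavior under test rather than the proxy's actual code (frame and message types are the ones this test already imports):

    // drainGate sketches the admission check exercised here: once draining
    // starts, new requests get an Overloaded error while requests already
    // in flight are left to complete.
    type drainGate struct {
        draining atomic.Bool // set when shutdown begins
    }

    // admit returns an error frame for requests arriving after draining
    // started; nil means the request may be forwarded as usual.
    func (d *drainGate) admit(f *frame.Frame) *frame.Frame {
        if d.draining.Load() {
            return frame.NewFrame(f.Header.Version, f.Header.StreamId,
                &message.Overloaded{ErrorMessage: "proxy is shutting down"})
        }
        return nil
    }
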
37 changes: 33 additions & 4 deletions integration-tests/tls_test.go
@@ -12,6 +12,7 @@ import (
"github.com/datastax/zdm-proxy/integration-tests/env"
"github.com/datastax/zdm-proxy/integration-tests/setup"
"github.com/datastax/zdm-proxy/proxy/pkg/config"
+ zerologger "github.com/rs/zerolog/log"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"io/ioutil"
@@ -410,7 +411,7 @@ func TestTls_OneWayOrigin_MutualTarget(t *testing.T) {
serverName: "",
errExpected: true,
errWarningsExpected: []string{"tls: client didn't provide a certificate"},
errMsgExpected: "",
errMsgExpected: "remote error: tls: bad certificate",
},
{
name: "Proxy: Mutual TLS and SNI on Client, Mutual TLS on Listener, One-way TLS on Origin, mutual TLS on Target",
@@ -1056,16 +1057,36 @@ func testProxyClientTls(t *testing.T, ccmSetup *setup.CcmTestSetup,

logMessages := buffer.String()

- for _, errWarnExpected := range proxyTlsConfig.errWarningsExpected {
- require.True(t, strings.Contains(logMessages, errWarnExpected), "%v not found", errWarnExpected)
+ warningAssertionFailed := false
+ warningExpected := false
+ for _, expectedWarningMsg := range proxyTlsConfig.errWarningsExpected {
+ warningExpected = true
+ if !strings.Contains(logMessages, expectedWarningMsg) {
+ t.Logf("%v not found in %v", expectedWarningMsg, logMessages)
+ warningAssertionFailed = true
+ }
}

if proxyTlsConfig.errExpected {
require.NotNil(t, err, "Did not get expected error %s", proxyTlsConfig.errMsgExpected)
+ errorAssertionFailed := false
+ errorExpected := false
if proxyTlsConfig.errMsgExpected != "" {
- require.True(t, strings.Contains(err.Error(), proxyTlsConfig.errMsgExpected), err.Error())
+ errorExpected = true
+ if !strings.Contains(err.Error(), proxyTlsConfig.errMsgExpected) {
+ errorAssertionFailed = true
+ t.Logf("%v not found in %v", err.Error(), proxyTlsConfig.errMsgExpected)
+ }
}
+ if errorExpected && warningExpected {
+ require.False(t, errorAssertionFailed && warningAssertionFailed) // only 1 check needs to pass in this scenario
+ } else if errorExpected {
+ require.False(t, errorAssertionFailed)
+ } else if warningExpected {
+ require.False(t, warningAssertionFailed)
+ }
} else {
+ require.False(t, warningAssertionFailed)
require.Nil(t, err, "testClient setup failed: %v", err)
// create schema on clusters through the proxy
sendRequest(cqlConn, "CREATE KEYSPACE IF NOT EXISTS testks "+
@@ -1340,6 +1361,14 @@ func getClientSideVerifyConnectionCallback(rootCAs *x509.CertPool) func(cs tls.C
}
}
_, err := cs.PeerCertificates[0].Verify(opts)
+ if err != nil {
+ // use zerolog to avoid interacting with proxy's log messages that are used for assertions in the test
+ zerologger.Warn().
+ Interface("cs", cs).
+ Interface("verifyopts", opts).
+ Interface("peercertificates", cs.PeerCertificates).
+ Msgf("client side verify callback error: %v", err)
+ }
return err
}
}
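
Note on the fix: the TLS assertions are relaxed so that when both an expected error message and expected log warnings are configured, only one of the two has to match — presumably because a failed handshake can surface either as a proxy-side warning or as a client-side error depending on which side observes the broken connection first. The decision logic, restated as a standalone predicate for clarity (a sketch, not code from the test):

    // tlsAssertionOK mirrors the branch structure added above: with both an
    // error and warnings expected, one matching check suffices; otherwise
    // the single expected kind must match.
    func tlsAssertionOK(errorExpected, errorFailed, warningExpected, warningFailed bool) bool {
        switch {
        case errorExpected && warningExpected:
            return !(errorFailed && warningFailed) // only one check needs to pass
        case errorExpected:
            return !errorFailed
        case warningExpected:
            return !warningFailed
        default:
            return true // nothing expected, nothing to assert
        }
    }
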
19 changes: 11 additions & 8 deletions proxy/pkg/zdmproxy/clientconn.go
@@ -37,6 +37,7 @@ type ClientConnector struct {
clientHandlerContext context.Context
clientHandlerCancelFunc context.CancelFunc

+ // not used atm but should be used when a protocol error occurs after #68 has been addressed
clientHandlerShutdownRequestCancelFn context.CancelFunc

writeCoalescer *writeCoalescer
@@ -165,6 +166,7 @@ func (cc *ClientConnector) listenForRequests() {
bufferedReader := bufio.NewReaderSize(cc.connection, cc.conf.RequestWriteBufferSizeBytes)
connectionAddr := cc.connection.RemoteAddr().String()
protocolErrOccurred := false
+ var alreadySentProtocolErr *frame.RawFrame
for cc.clientHandlerContext.Err() == nil {
f, err := readRawFrame(bufferedReader, connectionAddr, cc.clientHandlerContext)

@@ -174,14 +176,15 @@
err, cc.clientHandlerContext, cc.clientHandlerCancelFunc, ClientConnectorLogPrefix, "reading", connectionAddr)
break
} else if protocolErrResponseFrame != nil {
- f = protocolErrResponseFrame
- if !protocolErrOccurred {
- protocolErrOccurred = true
- cc.sendResponseToClient(protocolErrResponseFrame)
- cc.clientHandlerShutdownRequestCancelFn()
- setDrainModeNowFunc()
- continue
- }
+ alreadySentProtocolErr = protocolErrResponseFrame
+ protocolErrOccurred = true
+ cc.sendResponseToClient(protocolErrResponseFrame)
+ continue
+ } else if alreadySentProtocolErr != nil {
+ clonedProtocolErr := alreadySentProtocolErr.Clone()
+ clonedProtocolErr.Header.StreamId = f.Header.StreamId
+ cc.sendResponseToClient(clonedProtocolErr)
+ continue
}

wg.Add(1)
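
Note on the fix (the heart of #69): previously, a protocol error triggered an immediate client-handler shutdown request, which could tear down the TCP connection before the error frame had been delivered to the client. Now the error response is cached and the connection keeps serving reads: every further frame from the client is answered with a clone of the cached error, rewritten to carry that frame's stream id — necessary because CQL multiplexes many requests over one connection and the client matches responses to requests by stream id. The replay step in isolation:

    // Replaying the cached protocol error for a late client frame f.
    if alreadySentProtocolErr != nil {
        clonedProtocolErr := alreadySentProtocolErr.Clone()
        // Rewrite the stream id so the client can correlate this response
        // with its own request frame.
        clonedProtocolErr.Header.StreamId = f.Header.StreamId
        cc.sendResponseToClient(clonedProtocolErr)
        continue
    }
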
5 changes: 3 additions & 2 deletions proxy/pkg/zdmproxy/clienthandler.go
@@ -114,8 +114,10 @@ type ClientHandler struct {
parameterModifier *ParameterModifier
timeUuidGenerator TimeUuidGenerator

+ // not used atm but should be used when a protocol error occurs after #68 has been addressed
clientHandlerShutdownRequestCancelFn context.CancelFunc
- clientHandlerShutdownRequestContext context.Context

+ clientHandlerShutdownRequestContext  context.Context
}

func NewClientHandler(
@@ -643,7 +645,6 @@ func (ch *ClientHandler) tryProcessProtocolError(response *Response, protocolErr
errMsg, response.connectorType)
}
ch.clientConnector.sendResponseToClient(response.responseFrame)
- ch.clientHandlerShutdownRequestCancelFn()
}
return true
}
