diff --git a/broker_test.go b/broker_test.go index 71cf90f..b6be231 100644 --- a/broker_test.go +++ b/broker_test.go @@ -10,6 +10,15 @@ import ( "github.com/smartystreets/goconvey/convey" ) +func newMockBroker() *Broker { + return &Broker{ + address: "localhost:9092", + nodeID: 0, + config: &BrokerConfig{}, + conn: &MockConn{}, + } +} + func TestNewBroker(t *testing.T) { mockey.PatchConvey("TestNewBroker", t, func() { mockey.Mock((*Broker).requestAPIVersions).Return(APIVersionsResponse{}, nil).Build() diff --git a/simple_consumer.go b/simple_consumer.go index 33c1bbd..09cc856 100644 --- a/simple_consumer.go +++ b/simple_consumer.go @@ -4,7 +4,10 @@ import ( "context" "errors" "fmt" + "io" + "os" "sync" + "syscall" "time" ) @@ -352,7 +355,7 @@ func (c *SimpleConsumer) Consume(offset int64, messageChan chan *FullMessage) (< for !c.stop { if err = c.getLeaderBroker(); err != nil { - logger.Error(err, "get leader broker of [%s/%d] error: %s", "topic", c.topic, "partitionID", c.partitionID) + logger.Error(err, "get leader broker error", "topic", c.topic, "partitionID", c.partitionID) time.Sleep(time.Millisecond * time.Duration(c.config.RetryBackOffMS)) } else { break @@ -464,7 +467,9 @@ func (c *SimpleConsumer) consumeMessages(innerMessages chan *FullMessage, messag } if message.Error != nil { logger.Error(message.Error, "message error", "topic", c.topic, "partitionID", c.partitionID) - if message.Error == &maxBytesTooSmall { + if os.IsTimeout(message.Error) || errors.Is(message.Error, io.EOF) || errors.Is(message.Error, syscall.EPIPE) { + c.leaderBroker.Close() + } else if message.Error == &maxBytesTooSmall { c.config.FetchMaxBytes *= 2 logger.Info("fetch.max.bytes is too small, double it", "new FetchMaxBytes", c.config.FetchMaxBytes) } diff --git a/simple_consumer_test.go b/simple_consumer_test.go index 552fcc3..c5417f6 100644 --- a/simple_consumer_test.go +++ b/simple_consumer_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "net" "testing" "github.com/bytedance/mockey" @@ -359,8 +360,8 @@ func TestOffsetOutofRangeConsume(t *testing.T) { }) } -func TestConsumeEOF(t *testing.T) { - mockey.PatchConvey("eof during consumping", t, func() { +func TestConsumeEOFInRequest(t *testing.T) { + mockey.PatchConvey("eof during consumping, eof during request(before response)", t, func() { topic := "testTopic" partitionID := 1 config := map[string]interface{}{ @@ -388,7 +389,7 @@ func TestConsumeEOF(t *testing.T) { To(func(fetchRequest *FetchRequest) (r io.Reader, responseLength uint32, err error) { if failCount == 0 { t.Log("mock requestFetchStreamingly") - failCount++ + failCount = 1 return nil, 0, io.EOF } else { return nil, 0, nil @@ -462,3 +463,148 @@ func TestConsumeEOF(t *testing.T) { } }) } + +type mockErrTimeout struct{} + +func (e *mockErrTimeout) Timeout() bool { return true } +func (e *mockErrTimeout) Error() string { return "mock timeout error" } + +// Encountered EOF when reading response(first time and ok afterwards). ensure broker is closed and reopen. +func TestConsumeEOFInResponse(t *testing.T) { + mockey.PatchConvey("eof during consumping, eof during response", t, func() { + topic := "testTopic" + partitionID := 1 + config := map[string]interface{}{ + "bootstrap.servers": "localhost:9092", + "client.id": "healer-test", + "retry.backoff.ms": 0, + } + + mockey.Mock((*net.Dialer).Dial).Return(&MockConn{}, nil).Build() + + mockey.Mock(NewBrokersWithConfig).Return(&Brokers{}, nil).Build() + mockey.Mock(NewBroker).Return(newMockBroker(), nil).Build() + mockey.Mock((*SimpleConsumer).refreshPartiton).Return(nil).Build() + mockey.Mock((*SimpleConsumer).getLeaderBroker).To(func(s *SimpleConsumer) error { + s.leaderBroker = newMockBroker() + return nil + }).Build() + mockey.Mock((*SimpleConsumer).initOffset).Return().Build() + + brokerCloseOrigin := (*Broker).Close + brokerClose := mockey.Mock((*Broker).Close).To(func(broker *Broker) { (brokerCloseOrigin)(broker) }).Origin(&brokerCloseOrigin).Build() + createConn := mockey.Mock((*Broker).createConnAndAuth).To(func(broker *Broker) error { + broker.conn = &MockConn{} + return nil + }).Build() + + mockey.Mock((*Broker).requestStreamingly).Return(&MockConn{}, 0, nil).Build() // ensureOpen before requestStreamingly + + hasFailed := false + mockey.Mock((*fetchResponseStreamDecoder).streamDecode).To(func(decoder *fetchResponseStreamDecoder, ctx context.Context, startOffset int64) error { + if hasFailed { + for i := 0; i < 5; i++ { + select { + case <-decoder.ctx.Done(): + return nil + case decoder.messages <- &FullMessage{ + TopicName: topic, + PartitionID: int32(partitionID), + Error: nil, + Message: &Message{ + Offset: 1, + MessageSize: 10, + Crc: 1, + MagicByte: 1, + Attributes: 1, + Timestamp: 1, + Key: []byte("test"), + Value: []byte(fmt.Sprintf("test-%d", i)), + }, + }: + } + } + } else { + for i := 0; i < 3; i++ { + select { + case <-decoder.ctx.Done(): + return nil + case decoder.messages <- &FullMessage{ + TopicName: topic, + PartitionID: int32(partitionID), + Error: nil, + Message: &Message{ + Offset: 1, + MessageSize: 10, + Crc: 1, + MagicByte: 1, + Attributes: 1, + Timestamp: 1, + Key: []byte("test"), + Value: []byte(fmt.Sprintf("test-%d", i)), + }, + }: + } + } + + hasFailed = true + decoder.messages <- &FullMessage{ + TopicName: topic, + PartitionID: int32(partitionID), + Error: &mockErrTimeout{}, + Message: &Message{ + Offset: 1, + MessageSize: 10, + Crc: 1, + MagicByte: 1, + Attributes: 1, + Timestamp: 1, + Key: []byte("eof"), + Value: []byte("eof"), + }, + } + } + return nil + }).Build() + + type testCase struct { + messageChanLength int + maxMessage int + requestFetchStreaminglyCount []int + streamDecodeCount []int + } + for _, tc := range []testCase{ + { + messageChanLength: 0, + maxMessage: 10, + requestFetchStreaminglyCount: []int{2, 2}, + streamDecodeCount: []int{1, 1}, + }, + } { + t.Logf("test case: %+v", tc) + simpleConsumer, err := NewSimpleConsumer(topic, int32(partitionID), config) + convey.So(err, convey.ShouldBeNil) + convey.So(simpleConsumer, convey.ShouldNotBeNil) + convey.So(createConn.Times(), convey.ShouldEqual, 0) + + messages := make(chan *FullMessage, tc.messageChanLength) + msg, err := simpleConsumer.Consume(-2, messages) + convey.So(err, convey.ShouldBeNil) + convey.So(msg, convey.ShouldNotBeNil) + convey.So(createConn.Times(), convey.ShouldEqual, 0) + + count := 0 + for count < tc.maxMessage { + m := <-msg + t.Logf("msg: %s", string(m.Message.Value)) + count++ + } + simpleConsumer.Stop() + t.Log("stopped") + + convey.So(count, convey.ShouldEqual, tc.maxMessage) + convey.So(brokerClose.Times(), convey.ShouldEqual, 2) + convey.So(createConn.Times(), convey.ShouldEqual, 1) + } + }) +}