From 782a98b92301a5e44ad7045efdec545908a6bf62 Mon Sep 17 00:00:00 2001 From: Apin Date: Mon, 8 Jan 2024 14:56:02 +0700 Subject: [PATCH] chore: adding validation message size --- kafka.go | 40 ++++++++++-- kafka_test.go | 169 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 205 insertions(+), 4 deletions(-) diff --git a/kafka.go b/kafka.go index f6e0b98..6cacc97 100644 --- a/kafka.go +++ b/kafka.go @@ -42,13 +42,16 @@ const ( auditLogTopicEnvKey = "APP_EVENT_STREAM_AUDIT_LOG_TOPIC" auditLogEnableEnvKey = "APP_EVENT_STREAM_AUDIT_LOG_ENABLED" auditLogTopicDefault = "auditLog" + + messageAdditionalSizeApprox = 2048 // in Byte. Approx data added to message that sent to kafka ) var ( - auditLogTopic = "" - auditEnabled = true - errPubNilEvent = errors.New("unable to publish nil event") - errSubNilEvent = errors.New("unable to subscribe nil event") + auditLogTopic = "" + auditEnabled = true + errPubNilEvent = errors.New("unable to publish nil event") + errSubNilEvent = errors.New("unable to subscribe nil event") + ErrMessageTooLarge = errors.New("message to large") ) // KafkaClient wraps client's functionality for Kafka @@ -217,6 +220,10 @@ func (client *KafkaClient) Publish(publishBuilder *PublishBuilder) error { return fmt.Errorf("unable to construct event : %s , error : %v", publishBuilder.eventName, err) } + if err = client.validateMessageSize(message); err != nil { + return err + } + config := client.configMap if publishBuilder.timeout == 0 { @@ -289,6 +296,10 @@ func (client *KafkaClient) PublishSync(publishBuilder *PublishBuilder) error { return fmt.Errorf("unable to construct event : %s , error : %v", publishBuilder.eventName, err) } + if err = client.validateMessageSize(message); err != nil { + return err + } + config := client.configMap if publishBuilder.timeout == 0 { @@ -305,6 +316,27 @@ func (client *KafkaClient) PublishSync(publishBuilder *PublishBuilder) error { return client.publishEvent(publishBuilder.ctx, topic, publishBuilder.eventName, config, message) } +func (client *KafkaClient) validateMessageSize(msg *kafka.Message) error { + maxSize := 1048576 // default size from kafka in bytes + if client.configMap != nil { + // https://github.com/confluentinc/librdkafka/blob/master/CONFIGURATION.md + if valInterface, ok := (*client.configMap)["message.max.bytes"]; ok { + if intValue, ok := valInterface.(int); ok { + maxSize = intValue + } else if intValue, ok := valInterface.(int32); ok { + maxSize = int(intValue) + } else if intValue, ok := valInterface.(int64); ok { + maxSize = int(intValue) + } + } + } + maxSize -= messageAdditionalSizeApprox + if len(msg.Key)+len(msg.Value) > maxSize { + return ErrMessageTooLarge + } + return nil +} + // Publish send event to a topic func (client *KafkaClient) publishEvent(ctx context.Context, topic, eventName string, config *kafka.ConfigMap, message *kafka.Message) (err error) { diff --git a/kafka_test.go b/kafka_test.go index 8a17ebf..577767f 100644 --- a/kafka_test.go +++ b/kafka_test.go @@ -18,8 +18,11 @@ package eventstream import ( "context" + "math/rand" "testing" + "time" + "github.com/google/uuid" "github.com/stretchr/testify/assert" ) @@ -329,3 +332,169 @@ func TestKafkaSubNilCallback(t *testing.T) { assert.Equal(t, errInvalidCallback, err, "error should be equal") } + +var seededRand *rand.Rand = rand.New( + rand.NewSource(time.Now().UnixNano())) + +func randomString(length int) string { + charset := "abcdefghijklmnopqrstuvwxyz" + + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + b := make([]byte, length) + for i := range b { + b[i] = charset[seededRand.Intn(len(charset))] + } + return string(b) +} + +func makePayload(keyLength, messageLength int) map[string]interface{} { + ret := make(map[string]interface{}) + for i := 0; i < keyLength; i++ { + ret[uuid.NewString()] = randomString(messageLength) + } + return ret +} + +func TestKafkaMaxMessageSize(t *testing.T) { + t.Parallel() + client := createKafkaClient(t) + topicName := constructTopicTest() + + testCases := []struct { + Payload map[string]interface{} + Err error + }{ + {Payload: makePayload(10, 1000), Err: nil}, + {Payload: makePayload(2000, 1000), Err: ErrMessageTooLarge}, + } + + for _, testCase := range testCases { + var mockPayload = testCase.Payload + + mockAdditionalFields := map[string]interface{}{ + "summary": "user:_failed", + } + + mockEvent := &Event{ + EventName: "testEvent", + Namespace: "event", + ClientID: "661a4ac82b854f3ca3ac2e0377d356e4", + TraceID: "5005e27d01064f23b962e8fd2e560a8a", + SpanContext: "test-span-context", + UserID: "661a4ac82b854f3ca3ac2e0377d356e4", + EventID: 3, + EventType: 301, + EventLevel: 3, + ServiceName: "test", + ClientIDs: []string{"7d480ce0e8624b02901bd80d9ba9817c"}, + TargetUserIDs: []string{"1fe7f425a0e049d29d87ca3d32e45b5a"}, + TargetNamespace: "publisher", + Privacy: true, + AdditionalFields: mockAdditionalFields, + Version: 2, + Payload: mockPayload, + } + + err := client.Publish( + NewPublish(). + Topic(topicName). + EventName(mockEvent.EventName). + Namespace(mockEvent.Namespace). + ClientID(mockEvent.ClientID). + UserID(mockEvent.UserID). + SessionID(mockEvent.SessionID). + TraceID(mockEvent.TraceID). + SpanContext(mockEvent.SpanContext). + EventID(mockEvent.EventID). + EventType(mockEvent.EventType). + EventLevel(mockEvent.EventLevel). + ServiceName(mockEvent.ServiceName). + ClientIDs(mockEvent.ClientIDs). + TargetUserIDs(mockEvent.TargetUserIDs). + TargetNamespace(mockEvent.TargetNamespace). + Privacy(mockEvent.Privacy). + AdditionalFields(mockEvent.AdditionalFields). + Version(2). + Context(context.Background()). + Payload(mockPayload)) + + assert.Equal(t, testCase.Err, err) + } +} + +func TestKafkaMaxMessageSizeModified(t *testing.T) { + t.Parallel() + + config := &BrokerConfig{ + CACertFile: "", + StrictValidation: true, + DialTimeout: 2 * time.Second, + BaseConfig: map[string]interface{}{ + "message.max.bytes": 4096, + }, + } + + brokerList := []string{"localhost:9092"} + client, _ := NewClient(prefix, eventStreamKafka, brokerList, config) + topicName := constructTopicTest() + + testCases := []struct { + Payload map[string]interface{} + Err error + }{ + {Payload: makePayload(1, 1000), Err: nil}, + {Payload: makePayload(10, 1000), Err: ErrMessageTooLarge}, + } + + for _, testCase := range testCases { + var mockPayload = testCase.Payload + + mockAdditionalFields := map[string]interface{}{ + "summary": "user:_failed", + } + + mockEvent := &Event{ + EventName: "testEvent", + Namespace: "event", + ClientID: "661a4ac82b854f3ca3ac2e0377d356e4", + TraceID: "5005e27d01064f23b962e8fd2e560a8a", + SpanContext: "test-span-context", + UserID: "661a4ac82b854f3ca3ac2e0377d356e4", + EventID: 3, + EventType: 301, + EventLevel: 3, + ServiceName: "test", + ClientIDs: []string{"7d480ce0e8624b02901bd80d9ba9817c"}, + TargetUserIDs: []string{"1fe7f425a0e049d29d87ca3d32e45b5a"}, + TargetNamespace: "publisher", + Privacy: true, + AdditionalFields: mockAdditionalFields, + Version: 2, + Payload: mockPayload, + } + + err := client.Publish( + NewPublish(). + Topic(topicName). + EventName(mockEvent.EventName). + Namespace(mockEvent.Namespace). + ClientID(mockEvent.ClientID). + UserID(mockEvent.UserID). + SessionID(mockEvent.SessionID). + TraceID(mockEvent.TraceID). + SpanContext(mockEvent.SpanContext). + EventID(mockEvent.EventID). + EventType(mockEvent.EventType). + EventLevel(mockEvent.EventLevel). + ServiceName(mockEvent.ServiceName). + ClientIDs(mockEvent.ClientIDs). + TargetUserIDs(mockEvent.TargetUserIDs). + TargetNamespace(mockEvent.TargetNamespace). + Privacy(mockEvent.Privacy). + AdditionalFields(mockEvent.AdditionalFields). + Version(2). + Context(context.Background()). + Payload(mockPayload)) + + assert.Equal(t, testCase.Err, err) + } +}