Skip to content

Commit

Permalink
apply review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
roblaszczak committed Oct 15, 2024
1 parent 8cf4b95 commit cee41b4
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 10 deletions.
15 changes: 8 additions & 7 deletions sns/marshaler.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
package sns

import (
"github.com/ThreeDotsLabs/watermill-amazonsqs/sqs"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sns"
"github.com/aws/aws-sdk-go-v2/service/sns/types"

"github.com/ThreeDotsLabs/watermill/message"
)

// todo: check if it can be renamed
const UUIDAttribute = "UUID"

type Marshaler interface {
Marshal(topicArn TopicArn, msg *message.Message) *sns.PublishInput
}
Expand All @@ -21,9 +19,8 @@ func (d DefaultMarshalerUnmarshaler) Marshal(topicArn TopicArn, msg *message.Mes
// client side uuid
// there is a deduplication id that can be use for
// fifo queues
// todo: check how it works
attributes, deduplicationId, groupId := metadataToAttributes(msg.Metadata)
attributes[UUIDAttribute] = types.MessageAttributeValue{
attributes[sqs.UUIDAttribute] = types.MessageAttributeValue{
StringValue: aws.String(msg.UUID),
DataType: aws.String("String"),
}
Expand All @@ -42,11 +39,11 @@ func metadataToAttributes(meta message.Metadata) (map[string]types.MessageAttrib
var deduplicationId, groupId *string
for k, v := range meta {
// SNS has special attributes for deduplication and group id
if k == "MessageDeduplicationId" {
if k == MessageDeduplicationIdMetadataField {
deduplicationId = aws.String(v)
continue
}
if k == "MessageGroupId" {
if k == MessageGroupIdMetadataField {
groupId = aws.String(v)
continue
}
Expand All @@ -58,3 +55,7 @@ func metadataToAttributes(meta message.Metadata) (map[string]types.MessageAttrib

return attributes, deduplicationId, groupId
}

const MessageDeduplicationIdMetadataField = "MessageDeduplicationId"

const MessageGroupIdMetadataField = "MessageGroupId"
7 changes: 5 additions & 2 deletions sns/subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (s *Subscriber) SubscribeInitializeWithContext(ctx context.Context, topic s
}

if !s.config.DoNotSetQueueAccessPolicy {
if err := s.setSqsQuePolicy(ctx, *sqsQueueArn, snsTopicArn, *sqsURL); err != nil {
if err := s.setSqsQueuePolicy(ctx, *sqsQueueArn, snsTopicArn, *sqsURL); err != nil {
return fmt.Errorf("cannot set queue access policy for topic %s: %w", snsTopicArn, err)
}
}
Expand Down Expand Up @@ -147,12 +147,15 @@ func (s *Subscriber) SubscribeInitializeWithContext(ctx context.Context, topic s
return nil
}

func (s *Subscriber) setSqsQuePolicy(ctx context.Context, sqsQueueArn sqs.QueueArn, snsTopicArn TopicArn, sqsURL sqs.QueueURL) error {
func (s *Subscriber) setSqsQueuePolicy(ctx context.Context, sqsQueueArn sqs.QueueArn, snsTopicArn TopicArn, sqsURL sqs.QueueURL) error {
policy, err := s.config.GenerateQueueAccessPolicy(ctx, GenerateQueueAccessPolicyParams{
SqsQueueArn: sqsQueueArn,
SnsTopicArn: snsTopicArn,
SqsURL: sqsURL,
})
if err != nil {
return fmt.Errorf("cannot generate queue access policy: %w", err)
}

policyJSON, err := json.Marshal(policy)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion sqs/marshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/ThreeDotsLabs/watermill/message"

Check failure on line 7 in sqs/marshaler.go

View workflow job for this annotation

GitHub Actions / ci / build

missing go.sum entry for module providing package github.com/ThreeDotsLabs/watermill/message (imported by github.com/ThreeDotsLabs/watermill-amazonsqs/sns); to add:
)

const UUIDAttribute = "UUID"
const UUIDAttribute = "_watermill_message_uuid"

type Marshaler interface {
Marshal(msg *message.Message) (*types.Message, error)
Expand Down

0 comments on commit cee41b4

Please sign in to comment.