diff --git a/queue.go b/queue.go index 2409bf3..a027a48 100644 --- a/queue.go +++ b/queue.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "os" "strconv" "github.com/aws/aws-sdk-go/aws" @@ -17,7 +18,16 @@ import ( const ( longPollingWaitTimeSeconds = 20 - queuePolicy = ` +) + +var queuePolicy = loadQueuePolicy() + +func loadQueuePolicy() string { + filePath := "/etc/lifecycled/queue_policy.json" + policy, err := os.ReadFile(filePath) + if err != nil { + // Handle the error, e.g., by providing a default policy + return ` { "Version":"2012-10-17", "Statement":[ @@ -35,7 +45,9 @@ const ( ] } ` -) + } + return string(policy) +} // SQSClient for testing purposes //go:generate mockgen -destination=mocks/mock_sqs_client.go -package=mocks github.com/buildkite/lifecycled SQSClient @@ -69,16 +81,24 @@ func NewQueue(queueName, topicArn string, sqsClient SQSClient, snsClient SNSClie // Create the SQS queue. func (q *Queue) Create() error { + attributes := map[string]*string{ + "Policy": aws.String(fmt.Sprintf(queuePolicy, q.topicArn)), + "ReceiveMessageWaitTimeSeconds": aws.String(strconv.Itoa(longPollingWaitTimeSeconds)), + } + + kmsMasterKeyID := os.Getenv("KMS_MASTER_KEY_ID") + if kmsMasterKeyID != "" { + attributes["KMSMasterKeyId"] = aws.String(kmsMasterKeyID) + } + out, err := q.sqsClient.CreateQueue(&sqs.CreateQueueInput{ - QueueName: aws.String(q.name), - Attributes: map[string]*string{ - "Policy": aws.String(fmt.Sprintf(queuePolicy, q.topicArn)), - "ReceiveMessageWaitTimeSeconds": aws.String(strconv.Itoa(longPollingWaitTimeSeconds)), - }, + QueueName: aws.String(q.name), + Attributes: attributes, }) if err != nil { return err } + q.url = aws.StringValue(out.QueueUrl) return nil }