Skip to content

Commit

Permalink
Merge pull request #30 from buildkite/fix-dangling-queues-on-startup
Browse files Browse the repository at this point in the history
Fix dangling queues on startup (closes #25).
  • Loading branch information
lox authored Sep 4, 2018
2 parents ec36eda + c015937 commit ca7559f
Show file tree
Hide file tree
Showing 6 changed files with 466 additions and 65 deletions.
19 changes: 19 additions & 0 deletions daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,26 @@ type Daemon struct {
Signals chan os.Signal
}

// Start the daemon.
func (d *Daemon) Start(ctx context.Context) error {
if err := d.Queue.Create(); err != nil {
return err
}
defer func() {
if err := d.Queue.Delete(); err != nil {
log.WithError(err).Error("Failed to delete queue")
}
}()

if err := d.Queue.Subscribe(); err != nil {
return err
}
defer func() {
if err := d.Queue.Unsubscribe(); err != nil {
log.WithError(err).Error("Failed to unsubscribe from sns topic")
}
}()

ch := make(chan *sqs.Message)

go func() {
Expand Down
12 changes: 3 additions & 9 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,10 @@ func main() {
instanceID = id
}

sess := session.New()
queue, err := CreateQueue(sess, generateQueueName(instanceID), snsTopic)
sess, err := session.NewSession()
if err != nil {
log.Fatal(err)
log.WithError(err).Fatal("Failed to create new session")
}
defer func() {
if err = queue.Delete(); err != nil {
log.Fatalf("Failed to delete queue: %v", err)
}
}()

sigs := make(chan os.Signal, 2)
defer close(sigs)
Expand Down Expand Up @@ -114,7 +108,7 @@ func main() {
AutoScaling: autoscaling.New(sess),
Handler: handler,
Signals: sigs,
Queue: queue,
Queue: NewQueue(sess, generateQueueName(instanceID), snsTopic),
}

return daemon.Start(ctx)
Expand Down
145 changes: 89 additions & 56 deletions queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ import (
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sns"
"github.com/aws/aws-sdk-go/service/sns/snsiface"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/aws/aws-sdk-go/service/sqs/sqsiface"
)

const queuePolicy = `
Expand All @@ -38,77 +40,103 @@ const (
longPollingWaitTimeSeconds = 20
)

// SQSClient for testing purposes (TODO: Gomock).
type SQSClient sqsiface.SQSAPI

// SNSClient for testing purposes (TODO: Gomock).
type SNSClient snsiface.SNSAPI

// Queue manages the SQS queue and SNS subscription.
type Queue struct {
name string
url string
arn string
subscription string
session *session.Session
name string
url string
arn string
topicArn string
subscriptionArn string

sqsClient SQSClient
snsClient SNSClient
}

func CreateQueue(sess *session.Session, queueName string, topicARN string) (*Queue, error) {
sqsAPI := sqs.New(sess)
snsAPI := sns.New(sess)
// NewQueue returns a new... Queue.
func NewQueue(sess *session.Session, queueName, topicArn string) *Queue {
return &Queue{
name: queueName,
topicArn: topicArn,
sqsClient: sqs.New(sess),
snsClient: sns.New(sess),
}
}

log.WithFields(log.Fields{"queue": queueName}).Debug("Creating sqs queue")
resp, err := sqsAPI.CreateQueue(&sqs.CreateQueueInput{
QueueName: aws.String(queueName),
// Create the SQS queue.
func (q *Queue) Create() error {
log.WithFields(log.Fields{"queue": q.name}).Debug("Creating sqs queue")
out, err := q.sqsClient.CreateQueue(&sqs.CreateQueueInput{
QueueName: aws.String(q.name),
Attributes: map[string]*string{
"Policy": aws.String(fmt.Sprintf(queuePolicy, topicARN)),
"Policy": aws.String(fmt.Sprintf(queuePolicy, q.topicArn)),
"ReceiveMessageWaitTimeSeconds": aws.String(strconv.Itoa(longPollingWaitTimeSeconds)),
},
})
if err != nil {
return nil, err
return err
}
q.url = aws.StringValue(out.QueueUrl)
return nil
}

log.WithFields(log.Fields{"queue": queueName}).Debug("Looking up sqs queue url")
attrs, err := sqsAPI.GetQueueAttributes(&sqs.GetQueueAttributesInput{
AttributeNames: aws.StringSlice([]string{"QueueArn"}),
QueueUrl: resp.QueueUrl,
})
if err != nil {
return nil, err
// GetArn for the SQS queue.
func (q *Queue) getArn() (string, error) {
if q.arn == "" {
log.WithFields(log.Fields{"queue": q.name}).Debug("Looking up sqs queue arn")
out, err := q.sqsClient.GetQueueAttributes(&sqs.GetQueueAttributesInput{
AttributeNames: aws.StringSlice([]string{"QueueArn"}),
QueueUrl: aws.String(q.url),
})
if err != nil {
return "", err
}
arn, ok := out.Attributes["QueueArn"]
if !ok {
return "", errors.New("No attribute QueueArn")
}
q.arn = aws.StringValue(arn)
}
return q.arn, nil
}

arn, ok := attrs.Attributes["QueueArn"]
if !ok {
return nil, errors.New("No attribute QueueArn")
}
// Subscribe the queue to an SNS topic
func (q *Queue) Subscribe() error {
log.WithFields(log.Fields{"queue": q.name, "topic": q.topicArn}).Debug("Subscribing queue to sns topic")

log.WithFields(log.Fields{"queue": queueName, "topic": topicARN}).Debug("Subscribing queue to sns topic")
subscr, err := snsAPI.Subscribe(&sns.SubscribeInput{
arn, err := q.getArn()
if err != nil {
return err
}
out, err := q.snsClient.Subscribe(&sns.SubscribeInput{
TopicArn: aws.String(q.topicArn),
Protocol: aws.String("sqs"),
TopicArn: aws.String(topicARN),
Endpoint: arn,
Endpoint: aws.String(arn),
})
if err != nil {
return nil, err
return err
}

return &Queue{
name: queueName,
url: *resp.QueueUrl,
subscription: *subscr.SubscriptionArn,
arn: *arn,
session: sess,
}, nil
q.subscriptionArn = aws.StringValue(out.SubscriptionArn)
return nil
}

// Receive a message from the SQS queue.
func (q *Queue) Receive(ctx context.Context, ch chan *sqs.Message) error {
// Close channel before returning since this is the sending side.
defer close(ch)
log.WithFields(log.Fields{"queueURL": q.url}).Debugf("Polling sqs for messages")

sqsAPI := sqs.New(q.session)
defer close(ch) // Close channel before returning since this is the sending side.

Loop:
for {
select {
case <-ctx.Done():
break Loop
default:
resp, err := sqsAPI.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{
out, err := q.sqsClient.ReceiveMessageWithContext(ctx, &sqs.ReceiveMessageInput{
QueueUrl: aws.String(q.url),
MaxNumberOfMessages: aws.Int64(1),
WaitTimeSeconds: aws.Int64(longPollingWaitTimeSeconds),
Expand All @@ -121,8 +149,8 @@ Loop:
}
return err
}
for _, m := range resp.Messages {
sqsAPI.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{
for _, m := range out.Messages {
q.sqsClient.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{
QueueUrl: aws.String(q.url),
ReceiptHandle: m.ReceiptHandle,
})
Expand All @@ -133,21 +161,26 @@ Loop:
return nil
}

func (q *Queue) Delete() error {
sqsAPI := sqs.New(q.session)
snsAPI := sns.New(q.session)

log.WithFields(log.Fields{"arn": q.subscription}).Debugf("Deleting sns subscription")
_, err := snsAPI.Unsubscribe(&sns.UnsubscribeInput{
SubscriptionArn: aws.String(q.subscription),
// Unsubscribe the queue from the SNS topic.
func (q *Queue) Unsubscribe() error {
log.WithFields(log.Fields{"arn": q.subscriptionArn}).Debugf("Deleting sns subscription")
_, err := q.snsClient.Unsubscribe(&sns.UnsubscribeInput{
SubscriptionArn: aws.String(q.subscriptionArn),
})
if err != nil {
return err
}
return err
}

// Delete the SQS queue.
func (q *Queue) Delete() error {
log.WithFields(log.Fields{"url": q.url}).Debugf("Deleting sqs queue")
_, err = sqsAPI.DeleteQueue(&sqs.DeleteQueueInput{
_, err := q.sqsClient.DeleteQueue(&sqs.DeleteQueueInput{
QueueUrl: aws.String(q.url),
})
return err
if err != nil {
// Ignore error if queue does not exist (which is what we want)
if e, ok := err.(awserr.Error); !ok || e.Code() != sqs.ErrCodeQueueDoesNotExist {
return err
}
}
return nil
}
Loading

0 comments on commit ca7559f

Please sign in to comment.