Skip to content

Commit

Permalink
chore: better handling of reconnects
Browse files Browse the repository at this point in the history
  • Loading branch information
fritterhoff committed Mar 9, 2024
1 parent ab57a5c commit b236139
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 109 deletions.
98 changes: 49 additions & 49 deletions acme/mqtt/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,56 @@ func Connect(acmeDB acme.DB, host, user, password, organization string) (validat
opts.OnConnectionLost = func(cl mqtt.Client, err error) {
logrus.Println("mqtt connection lost")
}
opts.OnConnect = func(mqtt.Client) {
opts.OnConnect = func(cl mqtt.Client) {
logrus.Println("mqtt connection established")
go func() {
cl.Subscribe(fmt.Sprintf("%s/data", organization), 1, func(client mqtt.Client, msg mqtt.Message) {
logrus.Printf("Received message on topic: %s\nMessage: %s\n", msg.Topic(), msg.Payload())
ctx := context.Background()
data := msg.Payload()
var payload validation.ValidationResponse
err := json.Unmarshal(data, &payload)
if err != nil {
logrus.Errorf("error unmarshalling payload: %v", err)
return
}

ch, err := acmeDB.GetChallenge(ctx, payload.Challenge, payload.Authz)
if err != nil {
logrus.Errorf("error getting challenge: %v", err)
return
}

acc, err := acmeDB.GetAccount(ctx, ch.AccountID)
if err != nil {
logrus.Errorf("error getting account: %v", err)
return
}
expected, err := acme.KeyAuthorization(ch.Token, acc.Key)

if payload.Content != expected || err != nil {
logrus.Errorf("invalid key authorization: %v", err)
return
}
u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
logrus.Infof("challenge %s validated using mqtt", u.String())

if ch.Status != acme.StatusPending && ch.Status != acme.StatusValid {
return
}

ch.Status = acme.StatusValid
ch.Error = nil
ch.ValidatedAt = clock.Now().Format(time.RFC3339)

if err = acmeDB.UpdateChallenge(ctx, ch); err != nil {
logrus.Errorf("error updating challenge: %v", err)
} else {
logrus.Infof("challenge %s updated to valid", u.String())
}

})
}()
}
opts.OnReconnecting = func(mqtt.Client, *mqtt.ClientOptions) {
logrus.Println("mqtt attempting to reconnect")
Expand All @@ -47,54 +95,6 @@ func Connect(acmeDB acme.DB, host, user, password, organization string) (validat
return nil, token.Error()
}

go func() {
client.Subscribe(fmt.Sprintf("%s/data", organization), 1, func(client mqtt.Client, msg mqtt.Message) {
logrus.Printf("Received message on topic: %s\nMessage: %s\n", msg.Topic(), msg.Payload())
ctx := context.Background()
data := msg.Payload()
var payload validation.ValidationResponse
err := json.Unmarshal(data, &payload)
if err != nil {
logrus.Errorf("error unmarshalling payload: %v", err)
return
}

ch, err := acmeDB.GetChallenge(ctx, payload.Challenge, payload.Authz)
if err != nil {
logrus.Errorf("error getting challenge: %v", err)
return
}

acc, err := acmeDB.GetAccount(ctx, ch.AccountID)
if err != nil {
logrus.Errorf("error getting account: %v", err)
return
}
expected, err := acme.KeyAuthorization(ch.Token, acc.Key)

if payload.Content != expected || err != nil {
logrus.Errorf("invalid key authorization: %v", err)
return
}
u := &url.URL{Scheme: "http", Host: ch.Value, Path: fmt.Sprintf("/.well-known/acme-challenge/%s", ch.Token)}
logrus.Infof("challenge %s validated using mqtt", u.String())

if ch.Status != acme.StatusPending && ch.Status != acme.StatusValid {
return
}

ch.Status = acme.StatusValid
ch.Error = nil
ch.ValidatedAt = clock.Now().Format(time.RFC3339)

if err = acmeDB.UpdateChallenge(ctx, ch); err != nil {
logrus.Errorf("error updating challenge: %v", err)
} else {
logrus.Infof("challenge %s updated to valid", u.String())
}

})
}()
connection := validation.BrokerConnection{Client: client, Organization: organization}
return connection, nil
}
120 changes: 60 additions & 60 deletions cmd/step-agent/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,68 @@ var agent = cli.Command{
options.OnConnectionLost = func(cl mqtt.Client, err error) {
logrus.Println("mqtt connection lost")
}
options.OnConnect = func(mqtt.Client) {
options.OnConnect = func(cl mqtt.Client) {
logrus.Println("mqtt connection established")
// Subscribe to topic
token := cl.Subscribe(fmt.Sprintf("%s/jobs", c.String("organization")), 0, func(client mqtt.Client, msg mqtt.Message) {
logrus.Infof("received message on topic %s", msg.Topic())
logrus.Infof("message: %s", msg.Payload())

var data validation.ValidationRequest

req := msg.Payload()
json.Unmarshal(req, &data)

logger := logrus.WithField("authz", data.Authz).WithField("target", data.Target).WithField("account", data.Challenge)

http := acme.NewClient()
resp, err := http.Get(data.Target)
if err != nil {
logger.WithError(err).Warn("validating failed")
return
}

defer resp.Body.Close()
if resp.StatusCode >= 400 {
logger.Warnf("validation for %s failed with error: %s", data.Target, resp.Status)
return
}

body, err := io.ReadAll(resp.Body)
if err != nil {
logger.WithError(err).Warn("parsing body failed")
return
}

keyAuth := strings.TrimSpace(string(body))
logger.Infof("keyAuth: %s", keyAuth)

json, err := json.Marshal(&validation.ValidationResponse{
Authz: data.Authz,
Challenge: data.Challenge,
Content: keyAuth,
})
if err != nil {
logger.WithError(err).Warn("marshalling failed")
return
}
// Publish to topic
token := cl.Publish(fmt.Sprintf("%s/data", c.String("organization")), 0, false, json)
if token.WaitTimeout(30*time.Second) && token.Error() != nil {
logger.WithError(token.Error()).Warn("publishing failed")
} else {
logger.Infof("published to topic %s", fmt.Sprintf("%s/data", c.String("organization")))
}

})

if token.WaitTimeout(30*time.Second) && token.Error() != nil {
logrus.WithError(token.Error()).Warn("subscribing failed")
} else {
logrus.Infof("subscribed to topic %s", fmt.Sprintf("%s/jobs", c.String("organization")))
}
}

options.OnReconnecting = func(mqtt.Client, *mqtt.ClientOptions) {
logrus.Println("mqtt reconnecting")
}
Expand All @@ -72,65 +131,6 @@ var agent = cli.Command{
logrus.Warn(token.Error())
}

// Subscribe to topic
token := client.Subscribe(fmt.Sprintf("%s/jobs", c.String("organization")), 0, func(client mqtt.Client, msg mqtt.Message) {
logrus.Infof("received message on topic %s", msg.Topic())
logrus.Infof("message: %s", msg.Payload())

var data validation.ValidationRequest

req := msg.Payload()
json.Unmarshal(req, &data)

logger := logrus.WithField("authz", data.Authz).WithField("target", data.Target).WithField("account", data.Challenge)

http := acme.NewClient()
resp, err := http.Get(data.Target)
if err != nil {
logger.WithError(err).Warn("validating failed")
return
}

defer resp.Body.Close()
if resp.StatusCode >= 400 {
logger.Warnf("validation for %s failed with error: %s", data.Target, resp.Status)
return
}

body, err := io.ReadAll(resp.Body)
if err != nil {
logger.WithError(err).Warn("parsing body failed")
return
}

keyAuth := strings.TrimSpace(string(body))
logger.Infof("keyAuth: %s", keyAuth)

json, err := json.Marshal(&validation.ValidationResponse{
Authz: data.Authz,
Challenge: data.Challenge,
Content: keyAuth,
})
if err != nil {
logger.WithError(err).Warn("marshalling failed")
return
}
// Publish to topic
token := client.Publish(fmt.Sprintf("%s/data", c.String("organization")), 0, false, json)
if token.WaitTimeout(30*time.Second) && token.Error() != nil {
logger.WithError(token.Error()).Warn("publishing failed")
} else {
logger.Infof("published to topic %s", fmt.Sprintf("%s/data", c.String("organization")))
}

})

if token.WaitTimeout(30*time.Second) && token.Error() != nil {
logrus.WithError(token.Error()).Warn("subscribing failed")
} else {
logrus.Infof("subscribed to topic %s", fmt.Sprintf("%s/jobs", c.String("organization")))
}

return nil
},
}
Expand Down

0 comments on commit b236139

Please sign in to comment.