Skip to content

Commit

Permalink
Remove invalid hostname characters from nats client id
Browse files Browse the repository at this point in the history
Signed-off-by: Edward Wilde <[email protected]>
  • Loading branch information
ewilde authored and alexellis committed Jan 8, 2019
1 parent 0eae079 commit 4d38388
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 8 deletions.
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ WORKDIR /go/src/github.com/openfaas/nats-queue-worker

COPY vendor vendor
COPY handler handler
COPY nats nats
COPY main.go .
COPY readconfig.go .
COPY readconfig_test.go .
Expand Down
1 change: 1 addition & 0 deletions Dockerfile.armhf
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ WORKDIR /go/src/github.com/openfaas/nats-queue-worker

COPY vendor vendor
COPY handler handler
COPY nats nats
COPY main.go .
COPY readconfig.go .
COPY readconfig_test.go .
Expand Down
3 changes: 2 additions & 1 deletion handler/handler_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package handler

import (
"github.com/openfaas/nats-queue-worker/nats"
"os"
"strings"
"testing"
Expand All @@ -12,7 +13,7 @@ func Test_GetClientID_ContainsHostname(t *testing.T) {
val := c.GetClientID()

hostname, _ := os.Hostname()
encodedHostname := supportedCharacters.ReplaceAllString(hostname, "_")
encodedHostname := nats.GetClientID(hostname)
if !strings.HasSuffix(val, encodedHostname) {
t.Errorf("GetClientID should contain hostname as suffix, got: %s", val)
t.Fail()
Expand Down
7 changes: 3 additions & 4 deletions handler/nats_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ package handler

import (
"os"
"regexp"

"github.com/openfaas/nats-queue-worker/nats"
)

type NatsConfig interface {
Expand All @@ -12,14 +13,12 @@ type NatsConfig interface {
type DefaultNatsConfig struct {
}

var supportedCharacters, _ = regexp.Compile("[^a-zA-Z0-9-_]+")

// GetClientID returns the ClientID assigned to this producer/consumer.
func (DefaultNatsConfig) GetClientID() string {
val, _ := os.Hostname()
return getClientID(val)
}

func getClientID(hostname string) string {
return "faas-publisher-" + supportedCharacters.ReplaceAllString(hostname, "_")
return "faas-publisher-" + nats.GetClientID(hostname)
}
6 changes: 3 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/nats-io/go-nats-streaming"
"github.com/openfaas/faas-provider/auth"
"github.com/openfaas/faas/gateway/queue"
"github.com/openfaas/nats-queue-worker/nats"
)

// AsyncReport is the report from a function executed on a queue worker.
Expand Down Expand Up @@ -51,12 +52,11 @@ func makeClient() http.Client {
func main() {
readConfig := ReadConfig{}
config := readConfig.Read()

log.SetFlags(0)

clusterID := "faas-cluster"
val, _ := os.Hostname()
clientID := "faas-worker-" + val
clientID := "faas-worker-" + nats.GetClientID(val)

var durable string
var qgroup string
Expand All @@ -75,7 +75,7 @@ func main() {
client := makeClient()
sc, err := stan.Connect(clusterID, clientID, stan.NatsURL("nats://"+config.NatsAddress+":4222"))
if err != nil {
log.Fatalf("Can't connect: %v\n", err)
log.Fatalf("Can't connect to %s: %v\n", "nats://"+config.NatsAddress+":4222", err)
}

startOpt := stan.StartWithLastReceived()
Expand Down
8 changes: 8 additions & 0 deletions nats/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package nats

import "regexp"

var supportedCharacters = regexp.MustCompile("[^a-zA-Z0-9-_]+")
func GetClientID(value string) string {
return supportedCharacters.ReplaceAllString(value, "_")
}
23 changes: 23 additions & 0 deletions nats/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package nats

import (
"testing"
)

func TestGetClientID(t *testing.T) {
clientID := GetClientID("computer-a")
want := "computer-a"
if clientID != want {
t.Logf("Want clientID: `%s`, but got: `%s`\n", want, clientID)
t.Fail()
}
}

func TestGetClientIDWhenHostHasUnsupportedCharacters(t *testing.T) {
clientID := GetClientID("computer-a.acme.com")
want := "computer-a_acme_com"
if clientID != want {
t.Logf("Want clientID: `%s`, but got: `%s`\n", want, clientID)
t.Fail()
}
}

0 comments on commit 4d38388

Please sign in to comment.