diff --git a/cmd/main.go b/cmd/main.go
index bd49d13..0f9950a 100644
--- a/cmd/main.go
+++ b/cmd/main.go
@@ -137,6 +137,63 @@ func assignAddress(c context.Context, log *logrus.Entry, client kubernetes.Inter
 	return "", errors.New("reached maximum number of retries")
 }
 
+func waitForAddressToBeReported(c context.Context, log *logrus.Entry, explorer nd.Explorer, node *types.Node, assignedAddress string, cfg *config.Config) error {
+	ctx, cancel := context.WithCancel(c)
+	defer cancel()
+
+	// ticker for retry interval
+	ticker := time.NewTicker(cfg.RetryInterval)
+	defer ticker.Stop()
+
+	for retryCounter := 0; retryCounter <= cfg.RetryAttempts; retryCounter++ {
+		log.WithFields(logrus.Fields{
+			"node":           node.Name,
+			"instance":       node.Instance,
+			"address":        assignedAddress,
+			"retry-counter":  retryCounter,
+			"retry-attempts": cfg.RetryAttempts,
+		}).Debug("Waiting for node to report assigned address")
+
+		nodeInfo, err := explorer.GetNode(ctx, node.Name)
+		if err == nil {
+			for _, ip := range nodeInfo.ExternalIPs {
+				if ip.String() == assignedAddress {
+					log.WithFields(logrus.Fields{
+						"node":           node.Name,
+						"instance":       node.Instance,
+						"address":        assignedAddress,
+						"retry-counter":  retryCounter,
+						"retry-attempts": cfg.RetryAttempts,
+					}).Info("Node is reporting assigned address")
+					return nil
+				}
+			}
+			log.WithFields(logrus.Fields{
+				"node":     node.Name,
+				"instance": node.Instance,
+				"address":  assignedAddress,
+			}).Warn("Node is not yet reporting the assigned address")
+		} else {
+			log.WithError(err).WithFields(logrus.Fields{
+				"node":     node.Name,
+				"instance": node.Instance,
+				"address":  assignedAddress,
+			}).Error("failed to check if node is reporting the assigned address")
+		}
+
+		log.Infof("retrying after %v", cfg.RetryInterval)
+
+		select {
+		case <-ticker.C:
+			continue
+		case <-ctx.Done():
+			// If the context is done, return an error indicating that the operation was cancelled
+			return errors.Wrap(ctx.Err(), "context cancelled while waiting for node to report assigned address")
+		}
+	}
+	return errors.New("reached maximum number of retries")
+}
+
 func run(c context.Context, log *logrus.Entry, cfg *config.Config) error {
 	ctx, cancel := context.WithCancel(c)
 	defer cancel()
@@ -170,12 +227,16 @@ func run(c context.Context, log *logrus.Entry, cfg *config.Config) error {
 		return errors.Wrap(err, "initializing assigner")
 	}
 
-	_, err = assignAddress(ctx, log, clientset, assigner, n, cfg)
+	assignedAddress, err := assignAddress(ctx, log, clientset, assigner, n, cfg)
 	if err != nil {
 		return errors.Wrap(err, "assigning static public IP address")
 	}
 
 	if cfg.TaintKey != "" {
+		if err := waitForAddressToBeReported(ctx, log, explorer, n, assignedAddress, cfg); err != nil {
+			return errors.Wrap(err, "waiting for node to report assigned address")
+		}
+
 		logger := log.WithField("taint-key", cfg.TaintKey)
 		tainter := nd.NewTainter(clientset)
 
diff --git a/cmd/main_test.go b/cmd/main_test.go
index 57982f1..b948dc7 100644
--- a/cmd/main_test.go
+++ b/cmd/main_test.go
@@ -2,13 +2,16 @@ package main
 
 import (
 	"context"
+	"net"
 	"testing"
 	"time"
 
 	"github.com/doitintl/kubeip/internal/address"
 	"github.com/doitintl/kubeip/internal/config"
+	"github.com/doitintl/kubeip/internal/node"
 	"github.com/doitintl/kubeip/internal/types"
 	mocks "github.com/doitintl/kubeip/mocks/address"
+	nodeMocks "github.com/doitintl/kubeip/mocks/node"
 	"github.com/pkg/errors"
 	tmock "github.com/stretchr/testify/mock"
 	"k8s.io/client-go/kubernetes/fake"
@@ -180,3 +183,200 @@ func Test_assignAddress(t *testing.T) {
 		})
 	}
 }
+
+func Test_waitForAddressToBeReported(t *testing.T) {
+	type args struct {
+		c          context.Context
+		explorerFn func(t *testing.T) node.Explorer
+		node       *types.Node
+		address    string
+		cfg        *config.Config
+	}
+	tests := []struct {
+		name    string
+		args    args
+		wantErr bool
+	}{
+		{
+			name: "address reported with no retries",
+			args: args{
+				c:       context.Background(),
+				address: "1.1.1.1",
+				explorerFn: func(t *testing.T) node.Explorer {
+					mock := nodeMocks.NewExplorer(t)
+					mock.EXPECT().GetNode(tmock.Anything, "test-node").Return(
+						&types.Node{
+							Name:        "test-node",
+							Instance:    "test-instance",
+							Region:      "test-region",
+							Zone:        "test-zone",
+							ExternalIPs: []net.IP{net.IPv4(1, 1, 1, 1)},
+						},
+						nil,
+					)
+					return mock
+				},
+				node: &types.Node{
+					Name:     "test-node",
+					Instance: "test-instance",
+					Region:   "test-region",
+					Zone:     "test-zone",
+				},
+				cfg: &config.Config{
+					Filter:        []string{"test-filter"},
+					OrderBy:       "test-order-by",
+					RetryAttempts: 3,
+					RetryInterval: time.Millisecond,
+					LeaseDuration: 1,
+				},
+			},
+		},
+		{
+			name: "address reported after a few retries",
+			args: args{
+				c:       context.Background(),
+				address: "1.1.1.1",
+				explorerFn: func(t *testing.T) node.Explorer {
+					mock := nodeMocks.NewExplorer(t)
+					mock.EXPECT().GetNode(tmock.Anything, "test-node").Return(&types.Node{
+						Name:        "test-node",
+						Instance:    "test-instance",
+						Region:      "test-region",
+						Zone:        "test-zone",
+						ExternalIPs: []net.IP{net.IPv4(9, 9, 9, 9)},
+					}, nil).Times(3)
+					mock.EXPECT().GetNode(tmock.Anything, "test-node").Return(&types.Node{
+						Name:        "test-node",
+						Instance:    "test-instance",
+						Region:      "test-region",
+						Zone:        "test-zone",
+						ExternalIPs: []net.IP{net.IPv4(1, 1, 1, 1)},
+					}, nil).Once()
+					return mock
+				},
+				node: &types.Node{
+					Name:     "test-node",
+					Instance: "test-instance",
+					Region:   "test-region",
+					Zone:     "test-zone",
+				},
+				cfg: &config.Config{
+					Filter:        []string{"test-filter"},
+					OrderBy:       "test-order-by",
+					RetryAttempts: 3,
+					RetryInterval: time.Millisecond,
+					LeaseDuration: 1,
+				},
+			},
+		},
+		{
+			name: "error after a few retries and reached maximum number of retries",
+			args: args{
+				c: context.Background(),
+				explorerFn: func(t *testing.T) node.Explorer {
+					mock := nodeMocks.NewExplorer(t)
+					mock.EXPECT().GetNode(tmock.Anything, "test-node").Return(&types.Node{
+						Name:        "test-node",
+						Instance:    "test-instance",
+						Region:      "test-region",
+						Zone:        "test-zone",
+						ExternalIPs: []net.IP{net.IPv4(9, 9, 9, 9)},
+					}, nil).Times(4)
+					mock.EXPECT().GetNode(tmock.Anything, "test-node").Return(&types.Node{
+						Name:        "test-node",
+						Instance:    "test-instance",
+						Region:      "test-region",
+						Zone:        "test-zone",
+						ExternalIPs: []net.IP{net.IPv4(1, 1, 1, 1)},
+					}, nil).Times(0)
+					return mock
+				},
+				node: &types.Node{
+					Name:     "test-node",
+					Instance: "test-instance",
+					Region:   "test-region",
+					Zone:     "test-zone",
+				},
+				cfg: &config.Config{
+					Filter:        []string{"test-filter"},
+					OrderBy:       "test-order-by",
+					RetryAttempts: 3,
+					RetryInterval: time.Millisecond,
+					LeaseDuration: 1,
+				},
+			},
+			wantErr: true,
+		},
+		{
+			name: "context cancelled while waiting for address to be reported",
+			args: args{
+				c: func() context.Context {
+					ctx, cancel := context.WithCancel(context.Background())
+					go func() {
+						// Simulate a shutdown signal being received after a short delay
+						time.Sleep(20 * time.Millisecond)
+						cancel()
+					}()
+					return ctx
+				}(),
+				explorerFn: func(t *testing.T) node.Explorer {
+					mock := nodeMocks.NewExplorer(t)
+					mock.EXPECT().GetNode(tmock.Anything, "test-node").Return(nil, errors.New("error")).Maybe()
+					return mock
+				},
+				node: &types.Node{
Name: "test-node", + Instance: "test-instance", + Region: "test-region", + Zone: "test-zone", + }, + cfg: &config.Config{ + Filter: []string{"test-filter"}, + OrderBy: "test-order-by", + RetryAttempts: 10, + RetryInterval: 5 * time.Millisecond, + LeaseDuration: 1, + }, + }, + wantErr: true, + }, + { + name: "error after a few retries and context is done", + args: args{ + c: func() context.Context { + ctx, _ := context.WithTimeout(context.Background(), 10*time.Millisecond) //nolint:govet + return ctx + }(), + explorerFn: func(t *testing.T) node.Explorer { + mock := nodeMocks.NewExplorer(t) + mock.EXPECT().GetNode(tmock.Anything, "test-node").Return(nil, errors.New("error")).Maybe() + return mock + }, + node: &types.Node{ + Name: "test-node", + Instance: "test-instance", + Region: "test-region", + Zone: "test-zone", + }, + cfg: &config.Config{ + Filter: []string{"test-filter"}, + OrderBy: "test-order-by", + RetryAttempts: 3, + RetryInterval: 15 * time.Millisecond, + LeaseDuration: 1, + }, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + log := prepareLogger("debug", false) + explorer := tt.args.explorerFn(t) + err := waitForAddressToBeReported(tt.args.c, log, explorer, tt.args.node, tt.args.address, tt.args.cfg) + if err != nil != tt.wantErr { + t.Errorf("waitForAddressToBeReported() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}