diff --git a/cmd/root.go b/cmd/root.go index ae21f95..d534a46 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -3,25 +3,16 @@ package cmd import ( "os" - cometlog "github.com/cometbft/cometbft/libs/log" "github.com/spf13/cobra" - "github.com/strangelove-ventures/horcrux-proxy/privval" - "github.com/strangelove-ventures/horcrux-proxy/signer" ) -type appState struct { - logger cometlog.Logger - loadBalancer *privval.RemoteSignerLoadBalancer - sentries map[string]*signer.ReconnRemoteSigner -} - -func rootCmd(a *appState) *cobra.Command { +func rootCmd() *cobra.Command { cmd := &cobra.Command{ Use: "horcrux-proxy", Short: "A tendermint remote signer proxy", } - cmd.AddCommand(startCmd(a)) + cmd.AddCommand(startCmd()) cmd.AddCommand(versionCmd()) return cmd @@ -30,7 +21,7 @@ func rootCmd(a *appState) *cobra.Command { // Execute adds all child commands to the root command and sets flags appropriately. // This is called by main.main(). It only needs to happen once to the rootCmd. func Execute() { - if err := rootCmd(new(appState)).Execute(); err != nil { + if err := rootCmd().Execute(); err != nil { // Cobra will print the error os.Exit(1) } diff --git a/cmd/start.go b/cmd/start.go index a6c48c4..4557ebc 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -8,7 +8,6 @@ import ( "github.com/spf13/cobra" "github.com/strangelove-ventures/horcrux-proxy/privval" - "github.com/strangelove-ventures/horcrux-proxy/signer" ) const ( @@ -17,7 +16,7 @@ const ( flagAll = "all" ) -func startCmd(a *appState) *cobra.Command { +func startCmd() *cobra.Command { cmd := &cobra.Command{ Use: "start", Short: "Start horcrux-proxy process", @@ -32,31 +31,33 @@ func startCmd(a *appState) *cobra.Command { return fmt.Errorf("failed to parse log level: %w", err) } - a.logger = cometlog.NewFilter(cometlog.NewTMLogger(cometlog.NewSyncWriter(out)), logLevelOpt).With("module", "validator") - - a.logger.Info("Horcrux Proxy") + logger := cometlog.NewFilter(cometlog.NewTMLogger(cometlog.NewSyncWriter(out)), logLevelOpt).With("module", "validator") + logger.Info("Horcrux Proxy") listenAddrs, _ := cmd.Flags().GetStringArray(flagListen) all, _ := cmd.Flags().GetBool(flagAll) listeners := make([]privval.SignerListener, len(listenAddrs)) for i, addr := range listenAddrs { - listeners[i] = privval.NewSignerListener(a.logger, addr) + listeners[i] = privval.NewSignerListener(logger, addr) } - a.loadBalancer = privval.NewRemoteSignerLoadBalancer(a.logger, listeners) - - if err := a.loadBalancer.Start(); err != nil { + loadBalancer := privval.NewRemoteSignerLoadBalancer(logger, listeners) + if err = loadBalancer.Start(); err != nil { return fmt.Errorf("failed to start listener(s): %w", err) } + defer logIfErr(logger, loadBalancer.Stop) - a.sentries = make(map[string]*signer.ReconnRemoteSigner) + ctx := cmd.Context() - if err := watchForChangedSentries(cmd.Context(), a, all); err != nil { + watcher, err := NewSentryWatcher(ctx, logger, all, loadBalancer) + if err != nil { return err } + defer logIfErr(logger, watcher.Stop) + go watcher.Watch(ctx) - waitAndTerminate(a) + waitForSignals(logger) return nil }, @@ -69,18 +70,15 @@ func startCmd(a *appState) *cobra.Command { return cmd } -func waitAndTerminate(a *appState) { +func logIfErr(logger cometlog.Logger, fn func() error) { + if err := fn(); err != nil { + logger.Error("Error", "err", err) + } +} + +func waitForSignals(logger cometlog.Logger) { done := make(chan struct{}) - cometos.TrapSignal(a.logger, func() { - for _, s := range a.sentries { - err := s.Stop() - if err != nil { - panic(err) - } - } - if err := a.loadBalancer.Stop(); err != nil { - panic(err) - } + cometos.TrapSignal(logger, func() { close(done) }) <-done diff --git a/cmd/watcher.go b/cmd/watcher.go index 3769525..46b53f4 100644 --- a/cmd/watcher.go +++ b/cmd/watcher.go @@ -2,11 +2,14 @@ package cmd import ( "context" + "errors" "fmt" "net" "os" "time" + cometlog "github.com/cometbft/cometbft/libs/log" + "github.com/strangelove-ventures/horcrux-proxy/privval" "github.com/strangelove-ventures/horcrux-proxy/signer" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -20,69 +23,105 @@ const ( labelCosmosSentry = "app.kubernetes.io/component=cosmos-sentry" ) -func watchForChangedSentries( +type SentryWatcher struct { + all bool + client *kubernetes.Clientset + lb *privval.RemoteSignerLoadBalancer + log cometlog.Logger + node string + sentries map[string]*signer.ReconnRemoteSigner + + stop chan struct{} + done chan struct{} +} + +func NewSentryWatcher( ctx context.Context, - a *appState, + logger cometlog.Logger, all bool, // should we connect to sentries on all nodes, or just this node? -) error { + lb *privval.RemoteSignerLoadBalancer, +) (*SentryWatcher, error) { config, err := rest.InClusterConfig() if err != nil { - return fmt.Errorf("failed to get in cluster config: %w", err) + return nil, fmt.Errorf("failed to get in cluster config: %w", err) } // creates the clientset clientset, err := kubernetes.NewForConfig(config) if err != nil { - return fmt.Errorf("failed to create kube clientset: %w", err) + return nil, fmt.Errorf("failed to create kube clientset: %w", err) } - thisNode := "" + var thisNode string if !all { // need to determine which node this pod is on so we can only connect to sentries on this node nsbz, err := os.ReadFile(namespaceFile) if err != nil { - return fmt.Errorf("failed to read namespace from service account: %w", err) + return nil, fmt.Errorf("failed to read namespace from service account: %w", err) } ns := string(nsbz) thisPod, err := clientset.CoreV1().Pods(ns).Get(ctx, os.Getenv("HOSTNAME"), metav1.GetOptions{}) if err != nil { - return fmt.Errorf("failed to get this pod: %w", err) + return nil, fmt.Errorf("failed to get this pod: %w", err) } thisNode = thisPod.Spec.NodeName } - t := time.NewTimer(30 * time.Second) + return &SentryWatcher{ + all: all, + client: clientset, + done: make(chan struct{}), + lb: lb, + log: logger, + node: thisNode, + sentries: make(map[string]*signer.ReconnRemoteSigner), + stop: make(chan struct{}), + }, nil +} - go func() { - defer t.Stop() - for { - if err := reconcileSentries(ctx, a, thisNode, clientset, all); err != nil { - a.logger.Error("Failed to reconcile sentries with kube api", "error", err) - } - select { - case <-ctx.Done(): - return - case <-t.C: - t.Reset(30 * time.Second) - } +// Watch will reconcile the sentries with the kube api at a reasonable interval. +// It must be called only once. +func (w *SentryWatcher) Watch(ctx context.Context) { + defer close(w.done) + const interval = 30 * time.Second + timer := time.NewTimer(interval) + defer timer.Stop() + + for { + if err := w.reconcileSentries(ctx); err != nil { + w.log.Error("Failed to reconcile sentries with kube api", "error", err) } - }() + select { + case <-w.stop: + return + case <-ctx.Done(): + return + case <-timer.C: + timer.Reset(interval) + } + } +} - return nil +// Stop cleans up the sentries and stops the watcher. It must be called only once. +func (w *SentryWatcher) Stop() error { + // The dual channel synchronization ensures w.sentries is only read/mutated by one goroutine. + close(w.stop) + <-w.done + var err error + for _, sentry := range w.sentries { + err = errors.Join(err, sentry.Stop()) + } + return err } -func reconcileSentries( +func (w *SentryWatcher) reconcileSentries( ctx context.Context, - a *appState, - thisNode string, - clientset *kubernetes.Clientset, - all bool, // should we connect to sentries on all nodes, or just this node? ) error { configNodes := make([]string, 0) - services, err := clientset.CoreV1().Services("").List(ctx, metav1.ListOptions{ + services, err := w.client.CoreV1().Services("").List(ctx, metav1.ListOptions{ LabelSelector: labelCosmosSentry, }) @@ -97,7 +136,7 @@ func reconcileSentries( set := labels.Set(s.Spec.Selector) - pods, err := clientset.CoreV1().Pods(s.Namespace).List(ctx, metav1.ListOptions{LabelSelector: set.AsSelector().String()}) + pods, err := w.client.CoreV1().Pods(s.Namespace).List(ctx, metav1.ListOptions{LabelSelector: set.AsSelector().String()}) if err != nil { return fmt.Errorf("failed to list pods in namespace %s for service %s: %w", s.Namespace, s.Name, err) } @@ -106,7 +145,7 @@ func reconcileSentries( continue } - if !all && pods.Items[0].Spec.NodeName != thisNode { + if !w.all && pods.Items[0].Spec.NodeName != w.node { continue } @@ -118,21 +157,21 @@ func reconcileSentries( for _, newConfigSentry := range configNodes { foundNewConfigSentry := false - for existingSentry := range a.sentries { + for existingSentry := range w.sentries { if existingSentry == newConfigSentry { foundNewConfigSentry = true break } } if !foundNewConfigSentry { - a.logger.Info("Will add new sentry", "address", newConfigSentry) + w.log.Info("Will add new sentry", "address", newConfigSentry) newSentries = append(newSentries, newConfigSentry) } } removedSentries := make([]string, 0) - for existingSentry := range a.sentries { + for existingSentry := range w.sentries { foundExistingSentry := false for _, newConfigSentry := range configNodes { if existingSentry == newConfigSentry { @@ -141,26 +180,26 @@ func reconcileSentries( } } if !foundExistingSentry { - a.logger.Info("Will remove existing sentry", "address", existingSentry) + w.log.Info("Will remove existing sentry", "address", existingSentry) removedSentries = append(removedSentries, existingSentry) } } for _, s := range removedSentries { - if err := a.sentries[s].Stop(); err != nil { + if err := w.sentries[s].Stop(); err != nil { return fmt.Errorf("failed to stop remote signer: %w", err) } - delete(a.sentries, s) + delete(w.sentries, s) } for _, newSentry := range newSentries { dialer := net.Dialer{Timeout: 2 * time.Second} - s := signer.NewReconnRemoteSigner(newSentry, a.logger, a.loadBalancer, dialer) + s := signer.NewReconnRemoteSigner(newSentry, w.log, w.lb, dialer) if err := s.Start(); err != nil { return fmt.Errorf("failed to start new remote signer(s): %w", err) } - a.sentries[newSentry] = s + w.sentries[newSentry] = s } return nil diff --git a/privval/signer_listener_endpoint.go b/privval/signer_listener_endpoint.go index 9b6b033..78a47f7 100644 --- a/privval/signer_listener_endpoint.go +++ b/privval/signer_listener_endpoint.go @@ -38,7 +38,7 @@ type SignerListenerEndpoint struct { pingTimer *time.Ticker pingInterval time.Duration - instanceMtx cmtsync.Mutex // Ensures instance public methods access, i.e. SendRequest + mu cmtsync.Mutex // Ensures instance public methods access, i.e. SendRequest } // NewSignerListenerEndpoint returns an instance of SignerListenerEndpoint. @@ -81,8 +81,8 @@ func (sl *SignerListenerEndpoint) OnStart() error { // OnStop implements service.Service func (sl *SignerListenerEndpoint) OnStop() { - sl.instanceMtx.Lock() - defer sl.instanceMtx.Unlock() + sl.mu.Lock() + defer sl.mu.Unlock() _ = sl.Close() // Stop listening @@ -98,15 +98,15 @@ func (sl *SignerListenerEndpoint) OnStop() { // WaitForConnection waits maxWait for a connection or returns a timeout error func (sl *SignerListenerEndpoint) WaitForConnection(maxWait time.Duration) error { - sl.instanceMtx.Lock() - defer sl.instanceMtx.Unlock() + sl.mu.Lock() + defer sl.mu.Unlock() return sl.ensureConnection(maxWait) } // SendRequest ensures there is a connection, sends a request and waits for a response func (sl *SignerListenerEndpoint) SendRequest(request privvalproto.Message) (*privvalproto.Message, error) { - sl.instanceMtx.Lock() - defer sl.instanceMtx.Unlock() + sl.mu.Lock() + defer sl.mu.Unlock() err := sl.ensureConnection(sl.timeoutAccept) if err != nil { @@ -209,12 +209,10 @@ func (sl *SignerListenerEndpoint) pingLoop() { for { select { case <-sl.pingTimer.C: - { - _, err := sl.SendRequest(mustWrapMsg(&privvalproto.PingRequest{})) - if err != nil { - sl.Logger.Error("SignerListener: Ping timeout") - sl.triggerReconnect() - } + _, err := sl.SendRequest(mustWrapMsg(&privvalproto.PingRequest{})) + if err != nil { + sl.Logger.Error("SignerListener: Ping timeout") + sl.triggerReconnect() } case <-sl.Quit(): return