Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Data race in sentries #5

Merged
merged 6 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 3 additions & 12 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,16 @@ package cmd
import (
"os"

cometlog "github.com/cometbft/cometbft/libs/log"
"github.com/spf13/cobra"
"github.com/strangelove-ventures/horcrux-proxy/privval"
"github.com/strangelove-ventures/horcrux-proxy/signer"
)

type appState struct {
logger cometlog.Logger
loadBalancer *privval.RemoteSignerLoadBalancer
sentries map[string]*signer.ReconnRemoteSigner
}

func rootCmd(a *appState) *cobra.Command {
func rootCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "horcrux-proxy",
Short: "A tendermint remote signer proxy",
}

cmd.AddCommand(startCmd(a))
cmd.AddCommand(startCmd())
cmd.AddCommand(versionCmd())

return cmd
Expand All @@ -30,7 +21,7 @@ func rootCmd(a *appState) *cobra.Command {
// Execute adds all child commands to the root command and sets flags appropriately.
// This is called by main.main(). It only needs to happen once to the rootCmd.
func Execute() {
if err := rootCmd(new(appState)).Execute(); err != nil {
if err := rootCmd().Execute(); err != nil {
// Cobra will print the error
os.Exit(1)
}
Expand Down
44 changes: 21 additions & 23 deletions cmd/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/spf13/cobra"

"github.com/strangelove-ventures/horcrux-proxy/privval"
"github.com/strangelove-ventures/horcrux-proxy/signer"
)

const (
Expand All @@ -17,7 +16,7 @@ const (
flagAll = "all"
)

func startCmd(a *appState) *cobra.Command {
func startCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "start",
Short: "Start horcrux-proxy process",
Expand All @@ -32,31 +31,33 @@ func startCmd(a *appState) *cobra.Command {
return fmt.Errorf("failed to parse log level: %w", err)
}

a.logger = cometlog.NewFilter(cometlog.NewTMLogger(cometlog.NewSyncWriter(out)), logLevelOpt).With("module", "validator")

a.logger.Info("Horcrux Proxy")
logger := cometlog.NewFilter(cometlog.NewTMLogger(cometlog.NewSyncWriter(out)), logLevelOpt).With("module", "validator")
logger.Info("Horcrux Proxy")

listenAddrs, _ := cmd.Flags().GetStringArray(flagListen)
all, _ := cmd.Flags().GetBool(flagAll)

listeners := make([]privval.SignerListener, len(listenAddrs))
for i, addr := range listenAddrs {
listeners[i] = privval.NewSignerListener(a.logger, addr)
listeners[i] = privval.NewSignerListener(logger, addr)
}

a.loadBalancer = privval.NewRemoteSignerLoadBalancer(a.logger, listeners)

if err := a.loadBalancer.Start(); err != nil {
loadBalancer := privval.NewRemoteSignerLoadBalancer(logger, listeners)
if err = loadBalancer.Start(); err != nil {
return fmt.Errorf("failed to start listener(s): %w", err)
}
defer logIfErr(logger, loadBalancer.Stop)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously loadBalancer would not have been stopped if watching sentries failed.


a.sentries = make(map[string]*signer.ReconnRemoteSigner)
ctx := cmd.Context()

if err := watchForChangedSentries(cmd.Context(), a, all); err != nil {
watcher, err := NewSentryWatcher(ctx, logger, all, loadBalancer)
if err != nil {
return err
}
defer logIfErr(logger, watcher.Stop)
go watcher.Watch(ctx)

waitAndTerminate(a)
waitForSignals(logger)

return nil
},
Expand All @@ -69,18 +70,15 @@ func startCmd(a *appState) *cobra.Command {
return cmd
}

func waitAndTerminate(a *appState) {
// logIfErr runs fn and, if it reports a non-nil error, records that
// error on the supplied logger instead of propagating it.
func logIfErr(logger cometlog.Logger, fn func() error) {
err := fn()
if err == nil {
return
}
logger.Error("Error", "err", err)
}

func waitForSignals(logger cometlog.Logger) {
done := make(chan struct{})
cometos.TrapSignal(a.logger, func() {
for _, s := range a.sentries {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One part of the race is we read from the map here.

err := s.Stop()
if err != nil {
panic(err)
}
}
if err := a.loadBalancer.Stop(); err != nil {
panic(err)
}
cometos.TrapSignal(logger, func() {
close(done)
})
<-done
Expand Down
117 changes: 78 additions & 39 deletions cmd/watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package cmd

import (
"context"
"errors"
"fmt"
"net"
"os"
"time"

cometlog "github.com/cometbft/cometbft/libs/log"
"github.com/strangelove-ventures/horcrux-proxy/privval"
"github.com/strangelove-ventures/horcrux-proxy/signer"

metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand All @@ -20,69 +23,105 @@ const (
labelCosmosSentry = "app.kubernetes.io/component=cosmos-sentry"
)

func watchForChangedSentries(
// SentryWatcher discovers sentry pods via the Kubernetes API and keeps
// the set of remote-signer connections in sync with them. The sentries
// map is read and mutated only from the watch goroutine; Stop/Watch
// coordinate through the stop/done channels so no concurrent access
// occurs (this is the data-race fix).
type SentryWatcher struct {
all bool // connect to sentries on all nodes, not only this pod's node
client *kubernetes.Clientset
lb *privval.RemoteSignerLoadBalancer
log cometlog.Logger
node string // node this pod runs on; used to filter sentries when all is false
sentries map[string]*signer.ReconnRemoteSigner // live signers, keyed by sentry address

stop chan struct{} // closed by Stop to ask the watch loop to exit
done chan struct{} // closed by the watch loop once it has exited
}

func NewSentryWatcher(
ctx context.Context,
a *appState,
logger cometlog.Logger,
all bool, // should we connect to sentries on all nodes, or just this node?
) error {
lb *privval.RemoteSignerLoadBalancer,
) (*SentryWatcher, error) {
config, err := rest.InClusterConfig()
if err != nil {
return fmt.Errorf("failed to get in cluster config: %w", err)
return nil, fmt.Errorf("failed to get in cluster config: %w", err)
}
// creates the clientset
clientset, err := kubernetes.NewForConfig(config)
if err != nil {
return fmt.Errorf("failed to create kube clientset: %w", err)
return nil, fmt.Errorf("failed to create kube clientset: %w", err)
}

thisNode := ""
var thisNode string
if !all {
// need to determine which node this pod is on so we can only connect to sentries on this node

nsbz, err := os.ReadFile(namespaceFile)
if err != nil {
return fmt.Errorf("failed to read namespace from service account: %w", err)
return nil, fmt.Errorf("failed to read namespace from service account: %w", err)
}
ns := string(nsbz)

thisPod, err := clientset.CoreV1().Pods(ns).Get(ctx, os.Getenv("HOSTNAME"), metav1.GetOptions{})
if err != nil {
return fmt.Errorf("failed to get this pod: %w", err)
return nil, fmt.Errorf("failed to get this pod: %w", err)
}

thisNode = thisPod.Spec.NodeName
}

t := time.NewTimer(30 * time.Second)
return &SentryWatcher{
all: all,
client: clientset,
done: make(chan struct{}),
lb: lb,
log: logger,
node: thisNode,
sentries: make(map[string]*signer.ReconnRemoteSigner),
stop: make(chan struct{}),
}, nil
}

go func() {
defer t.Stop()
for {
if err := reconcileSentries(ctx, a, thisNode, clientset, all); err != nil {
a.logger.Error("Failed to reconcile sentries with kube api", "error", err)
}
select {
case <-ctx.Done():
return
case <-t.C:
t.Reset(30 * time.Second)
}
// Watch will reconcile the sentries with the kube api at a reasonable interval.
// It must be called only once.
func (w *SentryWatcher) Watch(ctx context.Context) {
defer close(w.done)
const interval = 30 * time.Second
timer := time.NewTimer(interval)
defer timer.Stop()

for {
if err := w.reconcileSentries(ctx); err != nil {
w.log.Error("Failed to reconcile sentries with kube api", "error", err)
}
}()
select {
case <-w.stop:
return
case <-ctx.Done():
return
case <-timer.C:
timer.Reset(interval)
}
}
}

return nil
// Stop cleans up the sentries and stops the watcher. It must be called only once.
func (w *SentryWatcher) Stop() error {
// The dual channel synchronization ensures w.sentries is only read/mutated by one goroutine.
close(w.stop)
<-w.done
Comment on lines +110 to +111
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The synchronization here ensures only one goroutine is modifying w.sentries.

var err error
for _, sentry := range w.sentries {
err = errors.Join(err, sentry.Stop())
}
return err
}

func reconcileSentries(
func (w *SentryWatcher) reconcileSentries(
ctx context.Context,
a *appState,
thisNode string,
clientset *kubernetes.Clientset,
all bool, // should we connect to sentries on all nodes, or just this node?
) error {
configNodes := make([]string, 0)

services, err := clientset.CoreV1().Services("").List(ctx, metav1.ListOptions{
services, err := w.client.CoreV1().Services("").List(ctx, metav1.ListOptions{
LabelSelector: labelCosmosSentry,
})

Expand All @@ -97,7 +136,7 @@ func reconcileSentries(

set := labels.Set(s.Spec.Selector)

pods, err := clientset.CoreV1().Pods(s.Namespace).List(ctx, metav1.ListOptions{LabelSelector: set.AsSelector().String()})
pods, err := w.client.CoreV1().Pods(s.Namespace).List(ctx, metav1.ListOptions{LabelSelector: set.AsSelector().String()})
if err != nil {
return fmt.Errorf("failed to list pods in namespace %s for service %s: %w", s.Namespace, s.Name, err)
}
Expand All @@ -106,7 +145,7 @@ func reconcileSentries(
continue
}

if !all && pods.Items[0].Spec.NodeName != thisNode {
if !w.all && pods.Items[0].Spec.NodeName != w.node {
continue
}

Expand All @@ -118,21 +157,21 @@ func reconcileSentries(

for _, newConfigSentry := range configNodes {
foundNewConfigSentry := false
for existingSentry := range a.sentries {
for existingSentry := range w.sentries {
if existingSentry == newConfigSentry {
foundNewConfigSentry = true
break
}
}
if !foundNewConfigSentry {
a.logger.Info("Will add new sentry", "address", newConfigSentry)
w.log.Info("Will add new sentry", "address", newConfigSentry)
newSentries = append(newSentries, newConfigSentry)
}
}

removedSentries := make([]string, 0)

for existingSentry := range a.sentries {
for existingSentry := range w.sentries {
foundExistingSentry := false
for _, newConfigSentry := range configNodes {
if existingSentry == newConfigSentry {
Expand All @@ -141,26 +180,26 @@ func reconcileSentries(
}
}
if !foundExistingSentry {
a.logger.Info("Will remove existing sentry", "address", existingSentry)
w.log.Info("Will remove existing sentry", "address", existingSentry)
removedSentries = append(removedSentries, existingSentry)
}
}

for _, s := range removedSentries {
if err := a.sentries[s].Stop(); err != nil {
if err := w.sentries[s].Stop(); err != nil {
return fmt.Errorf("failed to stop remote signer: %w", err)
}
delete(a.sentries, s)
delete(w.sentries, s)
}

for _, newSentry := range newSentries {
dialer := net.Dialer{Timeout: 2 * time.Second}
s := signer.NewReconnRemoteSigner(newSentry, a.logger, a.loadBalancer, dialer)
s := signer.NewReconnRemoteSigner(newSentry, w.log, w.lb, dialer)

if err := s.Start(); err != nil {
return fmt.Errorf("failed to start new remote signer(s): %w", err)
}
a.sentries[newSentry] = s
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The 2nd part of the race is we are reading and mutating the sentries map here. This is in a separate goroutine. There was no guarantee this goroutine would exit while the main() goroutine needing to read the map.

w.sentries[newSentry] = s
}

return nil
Expand Down
Loading