diff --git a/relay-server/server/relayServer.go b/relay-server/server/relayServer.go index 97e14cc..6c46030 100644 --- a/relay-server/server/relayServer.go +++ b/relay-server/server/relayServer.go @@ -372,10 +372,6 @@ func NewClient(server string) *LogClient { kg.Warnf("Failed to call WatchLogs (%s)\n err=%s", server, err.Error()) return nil } - // == // - - // set wait group - lc.WgServer, lc.Context = errgroup.WithContext(context.Background()) return lc } @@ -402,30 +398,35 @@ func (lc *LogClient) DoHealthCheck() bool { } // WatchMessages Function -func (lc *LogClient) WatchMessages(ctx context.Context) error { +func (lc *LogClient) WatchMessages(wg *sync.WaitGroup, stop chan struct{}, errCh chan error) { + + defer wg.Done() var err error for lc.Running { var res *pb.Message - if res, err = lc.MsgStream.Recv(); err != nil { - return fmt.Errorf("failed to receive a message (%s) %s", lc.Server, err.Error()) - - } select { - case MsgBufferChannel <- res: - case <-ctx.Done(): - // The context is over, stop processing results - return nil + case <-stop: + return default: - //not able to add it to Log buffer + if res, err = lc.MsgStream.Recv(); err != nil { + errCh <- fmt.Errorf("failed to receive a message (%s) %s", lc.Server, err.Error()) + return + } + + select { + case MsgBufferChannel <- res: + case <-stop: + return + default: + // Not able to add it to Message buffer + } } } kg.Print("Stopped watching messages from " + lc.Server) - - return nil } // AddMsgFromBuffChan Adds Msg from MsgBufferChannel into MsgStructs @@ -461,30 +462,35 @@ func (rs *RelayServer) AddMsgFromBuffChan() { } // WatchAlerts Function -func (lc *LogClient) WatchAlerts(ctx context.Context) error { +func (lc *LogClient) WatchAlerts(wg *sync.WaitGroup, stop chan struct{}, errCh chan error) { + + defer wg.Done() var err error for lc.Running { var res *pb.Alert - if res, err = lc.AlertStream.Recv(); err != nil { - return fmt.Errorf("failed to receive a alert (%s) %s", lc.Server, err.Error()) - } - select { - case AlertBufferChannel <- res: - case <-ctx.Done(): - // The context is over, stop processing results - return nil + case <-stop: + return default: - //not able to add it to Log buffer + if res, err = lc.AlertStream.Recv(); err != nil { + errCh <- fmt.Errorf("failed to receive an alert (%s) %s", lc.Server, err.Error()) + return + } + + select { + case AlertBufferChannel <- res: + case <-stop: + return + default: + // Not able to add it to Alert buffer + } } } kg.Print("Stopped watching alerts from " + lc.Server) - - return nil } // AddAlertFromBuffChan Adds ALert from AlertBufferChannel into AlertStructs @@ -520,30 +526,34 @@ func (rs *RelayServer) AddAlertFromBuffChan() { } // WatchLogs Function -func (lc *LogClient) WatchLogs(ctx context.Context) error { +func (lc *LogClient) WatchLogs(wg *sync.WaitGroup, stop chan struct{}, errCh chan error) { + defer wg.Done() var err error for lc.Running { var res *pb.Log - if res, err = lc.LogStream.Recv(); err != nil { - return fmt.Errorf("failed to receive a log (%s) %s", lc.Server, err.Error()) - } - select { - case LogBufferChannel <- res: - case <-ctx.Done(): - // The context is over, stop processing results - return nil + case <-stop: + return default: - //not able to add it to Log buffer + if res, err = lc.LogStream.Recv(); err != nil { + errCh <- fmt.Errorf("failed to receive a log (%s) %s", lc.Server, err.Error()) + return + } + + select { + case LogBufferChannel <- res: + case <-stop: + return + default: + // Not able to add it to Log buffer + } } } kg.Print("Stopped watching logs from " + lc.Server) - - return nil } // AddLogFromBuffChan Adds Log from LogBufferChannel into LogStructs @@ -744,26 +754,45 @@ func connectToKubeArmor(nodeIP, port string) error { } kg.Printf("Checked the liveness of KubeArmor's gRPC service (%s)", server) - // watch messages - client.WgServer.Go(func() error { - return client.WatchMessages(client.Context) - }) + var wg sync.WaitGroup + stop := make(chan struct{}) + errCh := make(chan error, 1) + + // Start watching messages + wg.Add(1) + go func() { + client.WatchMessages(&wg, stop, errCh) + }() kg.Print("Started to watch messages from " + server) - // watch alerts - client.WgServer.Go(func() error { - return client.WatchAlerts(client.Context) - }) + // Start watching alerts + wg.Add(1) + go func() { + client.WatchAlerts(&wg, stop, errCh) + }() kg.Print("Started to watch alerts from " + server) - // watch logs - client.WgServer.Go(func() error { - return client.WatchLogs(client.Context) - }) + // Start watching logs + wg.Add(1) + go func() { + client.WatchLogs(&wg, stop, errCh) + }() kg.Print("Started to watch logs from " + server) - if err := client.WgServer.Wait(); err != nil { + // Wait for an error or all goroutines to finish + select { + case err := <-errCh: + close(stop) // Stop other goroutines kg.Warn(err.Error()) + case <-func() chan struct{} { + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + return done + }(): + // All goroutines finished without error } if err := client.DestroyClient(); err != nil {