diff --git a/beacon/light/api/light_api.go b/beacon/light/api/light_api.go index 903db5734455..6f60fc0cc655 100755 --- a/beacon/light/api/light_api.go +++ b/beacon/light/api/light_api.go @@ -494,9 +494,6 @@ func (api *BeaconLightApi) StartHeadListener(listener HeadEventListener) func() for { select { - case <-ctx.Done(): - stream.Close() - case event, ok := <-stream.Events: if !ok { log.Trace("Event stream closed") diff --git a/beacon/light/request/server.go b/beacon/light/request/server.go index 9f3b09b81e80..a06dec99ae75 100644 --- a/beacon/light/request/server.go +++ b/beacon/light/request/server.go @@ -186,10 +186,14 @@ func (s *serverWithTimeout) eventCallback(event Event) { // call will just do nothing timer.Stop() delete(s.timeouts, id) - s.childEventCb(event) + if s.childEventCb != nil { + s.childEventCb(event) + } } default: - s.childEventCb(event) + if s.childEventCb != nil { + s.childEventCb(event) + } } } @@ -211,25 +215,27 @@ func (s *serverWithTimeout) startTimeout(reqData RequestResponse) { delete(s.timeouts, id) childEventCb := s.childEventCb s.lock.Unlock() - childEventCb(Event{Type: EvFail, Data: reqData}) + if childEventCb != nil { + childEventCb(Event{Type: EvFail, Data: reqData}) + } }) childEventCb := s.childEventCb s.lock.Unlock() - childEventCb(Event{Type: EvTimeout, Data: reqData}) + if childEventCb != nil { + childEventCb(Event{Type: EvTimeout, Data: reqData}) + } }) } // unsubscribe stops all goroutines associated with the server. func (s *serverWithTimeout) unsubscribe() { s.lock.Lock() - defer s.lock.Unlock() - for _, timer := range s.timeouts { if timer != nil { timer.Stop() } } - s.childEventCb = nil + s.lock.Unlock() s.parent.Unsubscribe() } @@ -328,10 +334,10 @@ func (s *serverWithLimits) eventCallback(event Event) { } childEventCb := s.childEventCb s.lock.Unlock() - if passEvent { + if passEvent && childEventCb != nil { childEventCb(event) } - if sendCanRequestAgain { + if sendCanRequestAgain && childEventCb != nil { childEventCb(Event{Type: EvCanRequestAgain}) } } @@ -347,13 +353,12 @@ func (s *serverWithLimits) sendRequest(request Request) (reqId ID) { // unsubscribe stops all goroutines associated with the server. func (s *serverWithLimits) unsubscribe() { s.lock.Lock() - defer s.lock.Unlock() - if s.delayTimer != nil { s.delayTimer.Stop() s.delayTimer = nil } s.childEventCb = nil + s.lock.Unlock() s.serverWithTimeout.unsubscribe() } @@ -383,7 +388,7 @@ func (s *serverWithLimits) canRequestNow() bool { } childEventCb := s.childEventCb s.lock.Unlock() - if sendCanRequestAgain { + if sendCanRequestAgain && childEventCb != nil { childEventCb(Event{Type: EvCanRequestAgain}) } return canRequest @@ -415,7 +420,7 @@ func (s *serverWithLimits) delay(delay time.Duration) { } childEventCb := s.childEventCb s.lock.Unlock() - if sendCanRequestAgain { + if sendCanRequestAgain && childEventCb != nil { childEventCb(Event{Type: EvCanRequestAgain}) } }) diff --git a/beacon/light/request/server_test.go b/beacon/light/request/server_test.go index 38629cb8c464..fef5d062ea2c 100644 --- a/beacon/light/request/server_test.go +++ b/beacon/light/request/server_test.go @@ -51,6 +51,7 @@ func TestServerEvents(t *testing.T) { expEvent(EvFail) rs.eventCb(Event{Type: EvResponse, Data: RequestResponse{ID: 1, Request: testRequest, Response: testResponse}}) expEvent(nil) + srv.unsubscribe() } func TestServerParallel(t *testing.T) { @@ -129,9 +130,7 @@ func TestServerEventRateLimit(t *testing.T) { srv := NewServer(rs, clock) var eventCount int srv.subscribe(func(event Event) { - if !event.IsRequestEvent() { - eventCount++ - } + eventCount++ }) expEvents := func(send, expAllowed int) { eventCount = 0 @@ -147,6 +146,30 @@ func TestServerEventRateLimit(t *testing.T) { expEvents(5, 1) clock.Run(maxServerEventRate * maxServerEventBuffer * 2) expEvents(maxServerEventBuffer+5, maxServerEventBuffer) + srv.unsubscribe() +} + +func TestServerUnsubscribe(t *testing.T) { + rs := &testRequestServer{} + clock := &mclock.Simulated{} + srv := NewServer(rs, clock) + var eventCount int + srv.subscribe(func(event Event) { + eventCount++ + }) + eventCb := rs.eventCb + eventCb(Event{Type: testEventType}) + if eventCount != 1 { + t.Errorf("Server event callback not called before unsubscribe") + } + srv.unsubscribe() + if rs.eventCb != nil { + t.Errorf("Server event callback not removed after unsubscribe") + } + eventCb(Event{Type: testEventType}) + if eventCount != 1 { + t.Errorf("Server event callback called after unsubscribe") + } } type testRequestServer struct { @@ -156,4 +179,4 @@ type testRequestServer struct { func (rs *testRequestServer) Name() string { return "" } func (rs *testRequestServer) Subscribe(eventCb func(Event)) { rs.eventCb = eventCb } func (rs *testRequestServer) SendRequest(ID, Request) {} -func (rs *testRequestServer) Unsubscribe() {} +func (rs *testRequestServer) Unsubscribe() { rs.eventCb = nil }