diff --git a/pkg/networkservice/chains/nsmgr/reselect_test.go b/pkg/networkservice/chains/nsmgr/reselect_test.go index 0f0cffb09..3c3491f15 100644 --- a/pkg/networkservice/chains/nsmgr/reselect_test.go +++ b/pkg/networkservice/chains/nsmgr/reselect_test.go @@ -133,7 +133,7 @@ func testReselectWithNsmgrRestart(t *testing.T, nodeNum int, restartLocal, resta require.Equal(t, 1, counterClient.UniqueCloses()) // Forwarder(s) should get a Close, even though NSMgr(s) restarted and didn't pass the Close for i := 0; i < nodeNum; i++ { - require.Equal(t, 1, counterFwd[i].Closes()) + require.Greater(t, counterFwd[i].Closes(), 0) } // Old NSE died, new NSE should not get a Close call require.Equal(t, 0, counterNse.Closes()) @@ -144,7 +144,7 @@ func testReselectWithNsmgrRestart(t *testing.T, nodeNum int, restartLocal, resta require.NoError(t, err) require.Equal(t, 0, counterNse.Closes()) for i := 0; i < nodeNum; i++ { - require.Equal(t, 1, counterFwd[i].Closes()) + require.Greater(t, counterFwd[i].Closes(), 0) } clientCloses := counterClient.Closes() @@ -155,7 +155,7 @@ func testReselectWithNsmgrRestart(t *testing.T, nodeNum int, restartLocal, resta require.Equal(t, 1, counterNse.Closes()) for i := 0; i < nodeNum; i++ { require.Equal(t, 1, counterFwd[i].UniqueCloses(), i) - require.Equal(t, 2, counterFwd[i].Closes(), i) + require.Greater(t, counterFwd[i].Closes(), 0, i) } } @@ -251,7 +251,7 @@ func testReselectWithLocalForwarderRestart(t *testing.T, nodeNum int) { require.Equal(t, 0, counterFwd[0].Closes()) if nodeNum > 1 { // remote forwarder should get Close - require.Equal(t, 1, counterFwd[1].Closes()) + require.Greater(t, counterFwd[1].Closes(), 0) } require.Equal(t, 0, counterNse.Closes()) @@ -262,7 +262,7 @@ func testReselectWithLocalForwarderRestart(t *testing.T, nodeNum int) { require.Equal(t, 0, counterNse.Closes()) require.Equal(t, 0, counterFwd[0].Closes()) if nodeNum > 1 { - require.Equal(t, 1, counterFwd[1].Closes()) + require.Greater(t, counterFwd[1].Closes(), 0) } clientCloses := counterClient.Closes() @@ -273,7 +273,7 @@ func testReselectWithLocalForwarderRestart(t *testing.T, nodeNum int) { require.Equal(t, 1, counterNse.Closes()) require.Equal(t, 1, counterFwd[0].Closes()) if nodeNum > 1 { - require.Equal(t, 2, counterFwd[1].Closes()) + require.Greater(t, counterFwd[1].Closes(), 1) } } diff --git a/pkg/networkservice/common/begin/server.go b/pkg/networkservice/common/begin/server.go index 812c685a7..a333ca308 100644 --- a/pkg/networkservice/common/begin/server.go +++ b/pkg/networkservice/common/begin/server.go @@ -65,13 +65,14 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo withEventFactoryCtx := withEventFactory(ctx, eventFactoryServer) - if eventFactoryServer.state == established && - request.GetConnection().GetState() == networkservice.State_RESELECT_REQUESTED && + if request.GetConnection().GetState() == networkservice.State_RESELECT_REQUESTED && eventFactoryServer.request != nil && eventFactoryServer.request.Connection != nil { log.FromContext(ctx).Info("Closing connection due to RESELECT_REQUESTED state") _, closeErr := next.Server(withEventFactoryCtx).Close(withEventFactoryCtx, eventFactoryServer.request.Connection) if closeErr != nil { log.FromContext(ctx).Errorf("Can't close old connection: %v", closeErr) + } else { + request.GetConnection().State = networkservice.State_UP } eventFactoryServer.state = closed }