Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve WireGuard handshake success rate #3092

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 25 additions & 21 deletions client/internal/peer/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/peer/guard"
Expand All @@ -36,10 +35,17 @@ const (
connPriorityICEP2P ConnPriority = 2
)

type WgInterface interface {
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(publicKey string) error
GetProxy() wgproxy.Proxy
GetStats(peerKey string) (configurer.WGStats, error)
}

type WgConfig struct {
WgListenPort int
RemoteKey string
WgInterface iface.IWGIface
WgInterface WgInterface
AllowedIps string
PreSharedKey *wgtypes.Key
}
Expand Down Expand Up @@ -107,6 +113,8 @@ type Conn struct {

guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup

endpointUpdater *endpointUpdater
}

// NewConn creates a new not opened Conn to the remote peer.
Expand All @@ -133,6 +141,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(),
semaphore: semaphore,
endpointUpdater: &endpointUpdater{
log: connLog,
wgConfig: config.WgConfig,
initiator: isWireGuardInitiator(config),
},
}

rFns := WorkerRelayCallbacks{
Expand Down Expand Up @@ -240,7 +253,7 @@ func (conn *Conn) Close() {
conn.wgProxyICE = nil
}

if err := conn.removeWgPeer(); err != nil {
if err := conn.endpointUpdater.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err)
}

Expand Down Expand Up @@ -364,7 +377,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
wgProxy.Work()
}

if err = conn.configureWGEndpoint(ep); err != nil {
if err = conn.endpointUpdater.configureWGEndpoint(ep); err != nil {
conn.handleConfigurationFailure(err, wgProxy)
return
}
Expand Down Expand Up @@ -397,7 +410,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.log.Debugf("ICE disconnected, set Relay to active connection")
conn.wgProxyRelay.Work()

if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
if err := conn.endpointUpdater.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err)
}
conn.workerRelay.EnableWgWatcher(conn.ctx)
Expand Down Expand Up @@ -456,7 +469,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
}

wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
if err := conn.endpointUpdater.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close relay connection: %v", err)
}
Expand Down Expand Up @@ -486,7 +499,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {

if conn.currentConnPriority == connPriorityRelay {
conn.log.Debugf("clean up WireGuard config")
if err := conn.removeWgPeer(); err != nil {
if err := conn.endpointUpdater.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err)
}
}
Expand Down Expand Up @@ -525,16 +538,6 @@ func (conn *Conn) listenGuardEvent(ctx context.Context) {
}
}

func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error {
return conn.config.WgConfig.WgInterface.UpdatePeer(
conn.config.WgConfig.RemoteKey,
conn.config.WgConfig.AllowedIps,
defaultWgKeepAlive,
addr,
conn.config.WgConfig.PreSharedKey,
)
}

func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
peerState := State{
PubKey: conn.config.Key,
Expand Down Expand Up @@ -714,10 +717,6 @@ func (conn *Conn) iceP2PIsActive() bool {
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected
}

func (conn *Conn) removeWgPeer() error {
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
}

func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
if wgProxy != nil {
Expand Down Expand Up @@ -756,6 +755,11 @@ func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}

// isWireGuardInitiator returns true if the local peer is the initiator of the WireGuard connection
func isWireGuardInitiator(config ConnConfig) bool {
return isController(config)
}

func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
return remoteRosenpassPubKey != nil
}
Expand Down
87 changes: 87 additions & 0 deletions client/internal/peer/endpoint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package peer

import (
"context"
"net"
"sync"
"time"

"github.com/sirupsen/logrus"
)

// fallbackDelay could be const but because of testing it is a var
var fallbackDelay = 5 * time.Second

type endpointUpdater struct {
log *logrus.Entry
wgConfig WgConfig
initiator bool

cancelFunc func()
configUpdateMutex sync.Mutex
}

// configureWGEndpoint sets up the WireGuard endpoint configuration.
// The initiator immediately configures the endpoint, while the non-initiator
// waits for a fallback period before configuring to avoid handshake congestion.
func (e *endpointUpdater) configureWGEndpoint(addr *net.UDPAddr) error {
if e.initiator {
return e.updateWireGuardPeer(addr)
}

// prevent to run new update while cancel the previous update
e.configUpdateMutex.Lock()
if e.cancelFunc != nil {
e.cancelFunc()
}
e.configUpdateMutex.Unlock()

var ctx context.Context
ctx, e.cancelFunc = context.WithCancel(context.Background())
go e.scheduleDelayedUpdate(ctx, addr)

return e.updateWireGuardPeer(nil)
}

func (e *endpointUpdater) removeWgPeer() error {
e.configUpdateMutex.Lock()
defer e.configUpdateMutex.Unlock()

if e.cancelFunc != nil {
e.cancelFunc()
}

return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
}

// scheduleDelayedUpdate waits for the fallback period before updating the endpoint
func (e *endpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr) {
t := time.NewTimer(fallbackDelay)
defer t.Stop()

select {
case <-ctx.Done():
return
case <-t.C:
e.configUpdateMutex.Lock()
defer e.configUpdateMutex.Unlock()

if ctx.Err() != nil {
return
}

if err := e.updateWireGuardPeer(addr); err != nil {
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
}
}
}

func (e *endpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr) error {
return e.wgConfig.WgInterface.UpdatePeer(
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
endpoint,
e.wgConfig.PreSharedKey,
)
}
178 changes: 178 additions & 0 deletions client/internal/peer/endpoint_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package peer

import (
"net"
"testing"
"time"

log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/mock"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)

type MockWgInterface struct {
mock.Mock

lastSetAddr *net.UDPAddr
}

func (m *MockWgInterface) GetStats(peerKey string) (configurer.WGStats, error) {
panic("implement me")
}

func (m *MockWgInterface) GetProxy() wgproxy.Proxy {
panic("implement me")
}

func (m *MockWgInterface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
args := m.Called(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
m.lastSetAddr = endpoint
return args.Error(0)
}

func (m *MockWgInterface) RemovePeer(publicKey string) error {
args := m.Called(publicKey)
return args.Error(0)
}

func Test_endpointUpdater_initiator(t *testing.T) {
mockWgInterface := &MockWgInterface{}
e := &endpointUpdater{
log: log.WithField("peer", "my-peer-key"),
wgConfig: WgConfig{
WgListenPort: 51820,
RemoteKey: "secret-remote-key",
WgInterface: mockWgInterface,
AllowedIps: "172.16.254.1",
},
initiator: true,
}
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1234,
}

mockWgInterface.On(
"UpdatePeer",
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
addr,
(*wgtypes.Key)(nil),
).Return(nil)

if err := e.configureWGEndpoint(addr); err != nil {
t.Fatalf("updateWireGuardPeer() failed: %v", err)
}

mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, addr, (*wgtypes.Key)(nil))
}

func Test_endpointUpdater_nonInitiator(t *testing.T) {
fallbackDelay = 1 * time.Second
mockWgInterface := &MockWgInterface{}
e := &endpointUpdater{
log: log.WithField("peer", "my-peer-key"),
wgConfig: WgConfig{
WgListenPort: 51820,
RemoteKey: "secret-remote-key",
WgInterface: mockWgInterface,
AllowedIps: "172.16.254.1",
},
initiator: false,
}
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1234,
}

mockWgInterface.On(
"UpdatePeer",
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
(*net.UDPAddr)(nil),
(*wgtypes.Key)(nil),
).Return(nil)

mockWgInterface.On(
"UpdatePeer",
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
addr,
(*wgtypes.Key)(nil),
).Return(nil)

err := e.configureWGEndpoint(addr)
if err != nil {
t.Fatalf("updateWireGuardPeer() failed: %v", err)
}
mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, (*net.UDPAddr)(nil), (*wgtypes.Key)(nil))

time.Sleep(fallbackDelay + time.Second)

mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, addr, (*wgtypes.Key)(nil))
}

func Test_endpointUpdater_overRule(t *testing.T) {
fallbackDelay = 1 * time.Second
mockWgInterface := &MockWgInterface{}
e := &endpointUpdater{
log: log.WithField("peer", "my-peer-key"),
wgConfig: WgConfig{
WgListenPort: 51820,
RemoteKey: "secret-remote-key",
WgInterface: mockWgInterface,
AllowedIps: "172.16.254.1",
},
initiator: false,
}
addr1 := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1000,
}

addr2 := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1001,
}

mockWgInterface.On(
"UpdatePeer",
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
(*net.UDPAddr)(nil),
(*wgtypes.Key)(nil),
).Return(nil)

mockWgInterface.On(
"UpdatePeer",
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
addr2,
(*wgtypes.Key)(nil),
).Return(nil)

if err := e.configureWGEndpoint(addr1); err != nil {
t.Fatalf("updateWireGuardPeer() failed: %v", err)
}
mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, (*net.UDPAddr)(nil), (*wgtypes.Key)(nil))

if err := e.configureWGEndpoint(addr2); err != nil {
t.Fatalf("updateWireGuardPeer() failed: %v", err)
}

time.Sleep(fallbackDelay + time.Second)

mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, addr2, (*wgtypes.Key)(nil))

if mockWgInterface.lastSetAddr != addr2 {
t.Fatalf("lastSetAddr is not equal to addr2")
}
}
Loading