diff --git a/mutex_map.go b/mutex_map.go index 967e033..df77df9 100644 --- a/mutex_map.go +++ b/mutex_map.go @@ -26,6 +26,22 @@ func (m *MutexMap[K, V]) Get(key K) (V, bool) { return value, ok } +// GetOrSetDefault returns the value for the given key if it exists. If it does not exist, it creates a default +// with the provided function and sets it for that key +func (m *MutexMap[K, V]) GetOrSetDefault(key K, createDefault func() V) V { + m.Lock() + defer m.Unlock() + + value, ok := m.real[key] + + if !ok { + value = createDefault() + m.real[key] = value + } + + return value +} + // Has checks if a key exists in the map func (m *MutexMap[K, V]) Has(key K) bool { m.RLock() diff --git a/prudp_connection.go b/prudp_connection.go index ce66f87..aea1b58 100644 --- a/prudp_connection.go +++ b/prudp_connection.go @@ -29,11 +29,13 @@ type PRUDPConnection struct { Signature []byte // * Connection signature for packets coming from the client, as seen by the server ServerConnectionSignature []byte // * Connection signature for packets coming from the server, as seen by the client UnreliablePacketBaseKey []byte // * The base key used for encrypting unreliable DATA packets + rtt *RTT // * The round-trip transmission time of this connection slidingWindows *MutexMap[uint8, *SlidingWindow] // * Outbound reliable packet substreams packetDispatchQueues *MutexMap[uint8, *PacketDispatchQueue] // * Inbound reliable packet substreams incomingFragmentBuffers *MutexMap[uint8, []byte] // * Buffers which store the incoming payloads from fragmented DATA packets outgoingUnreliableSequenceIDCounter *Counter[uint16] outgoingPingSequenceIDCounter *Counter[uint16] + lastSentPingTime time.Time heartbeatTimer *time.Timer pingKickTimer *time.Timer StationURLs *types.List[*types.StationURL] @@ -62,12 +64,13 @@ func (pc *PRUDPConnection) SetPID(pid *types.PID) { // reset resets the connection state to all zero values func (pc *PRUDPConnection) reset() { + pc.ConnectionState = StateNotConnected pc.packetDispatchQueues.Clear(func(_ uint8, packetDispatchQueue *PacketDispatchQueue) { packetDispatchQueue.Purge() }) pc.slidingWindows.Clear(func(_ uint8, slidingWindow *SlidingWindow) { - slidingWindow.ResendScheduler.Stop() + slidingWindow.TimeoutManager.Stop() }) pc.Signature = make([]byte, 0) @@ -289,6 +292,7 @@ func NewPRUDPConnection(socket *SocketConnection) *PRUDPConnection { pc := &PRUDPConnection{ Socket: socket, ConnectionState: StateNotConnected, + rtt: NewRTT(), pid: types.NewPID(0), slidingWindows: NewMutexMap[uint8, *SlidingWindow](), packetDispatchQueues: NewMutexMap[uint8, *PacketDispatchQueue](), diff --git a/prudp_endpoint.go b/prudp_endpoint.go index 11269bf..c2713ac 100644 --- a/prudp_endpoint.go +++ b/prudp_endpoint.go @@ -19,20 +19,25 @@ import ( // and secure servers. However the functionality of rdv::PRUDPEndPoint and nn::nex::SecureEndPoint is seemingly // identical. Rather than duplicate the logic from PRUDPEndpoint, a IsSecureEndpoint flag has been added instead. type PRUDPEndPoint struct { - Server *PRUDPServer - StreamID uint8 - DefaultStreamSettings *StreamSettings - Connections *MutexMap[string, *PRUDPConnection] - packetHandlers map[uint16]func(packet PRUDPPacketInterface) - packetEventHandlers map[string][]func(packet PacketInterface) - connectionEndedEventHandlers []func(connection *PRUDPConnection) - errorEventHandlers []func(err *Error) - ConnectionIDCounter *Counter[uint32] - ServerAccount *Account - AccountDetailsByPID func(pid *types.PID) (*Account, *Error) - AccountDetailsByUsername func(username string) (*Account, *Error) - IsSecureEndPoint bool -} + Server *PRUDPServer + StreamID uint8 + DefaultStreamSettings *StreamSettings + Connections *MutexMap[string, *PRUDPConnection] + packetHandlers map[uint16]func(packet PRUDPPacketInterface) + packetEventHandlers map[string][]func(packet PacketInterface) + connectionEndedEventHandlers []func(connection *PRUDPConnection) + errorEventHandlers []func(err *Error) + ConnectionIDCounter *Counter[uint32] + ServerAccount *Account + AccountDetailsByPID func(pid *types.PID) (*Account, *Error) + AccountDetailsByUsername func(username string) (*Account, *Error) + IsSecureEndPoint bool + CalcRetransmissionTimeoutCallback CalcRetransmissionTimeoutCallback +} + +// CalcRetransmissionTimeoutCallback is an optional callback which can be used to override the RTO calculation +// for packets sent by this `PRUDPEndpoint` +type CalcRetransmissionTimeoutCallback func(rtt float64, sendCount uint32) time.Duration // RegisterServiceProtocol registers a NEX service with the endpoint func (pep *PRUDPEndPoint) RegisterServiceProtocol(protocol ServiceProtocol) { @@ -111,19 +116,19 @@ func (pep *PRUDPEndPoint) processPacket(packet PRUDPPacketInterface, socket *Soc streamType := packet.SourceVirtualPortStreamType() streamID := packet.SourceVirtualPortStreamID() discriminator := fmt.Sprintf("%s-%d-%d", socket.Address.String(), streamType, streamID) - connection, ok := pep.Connections.Get(discriminator) - - if !ok { - connection = NewPRUDPConnection(socket) + connection := pep.Connections.GetOrSetDefault(discriminator, func() *PRUDPConnection { + connection := NewPRUDPConnection(socket) connection.endpoint = pep connection.ID = pep.ConnectionIDCounter.Next() connection.DefaultPRUDPVersion = packet.Version() connection.StreamType = streamType connection.StreamID = streamID connection.StreamSettings = pep.DefaultStreamSettings.Copy() + return connection + }) - pep.Connections.Set(discriminator, connection) - } + connection.Lock() + defer connection.Unlock() packet.SetSender(connection) @@ -153,8 +158,14 @@ func (pep *PRUDPEndPoint) handleAcknowledgment(packet PRUDPPacketInterface) { return } - slidingWindow := connection.SlidingWindow(packet.SubstreamID()) - slidingWindow.ResendScheduler.AcknowledgePacket(packet.SequenceID()) + if packet.Type() == constants.PingPacket { + if packet.SequenceID() == connection.outgoingPingSequenceIDCounter.Value { + connection.rtt.Adjust(time.Since(connection.lastSentPingTime)) + } + } else { + slidingWindow := connection.SlidingWindow(packet.SubstreamID()) + slidingWindow.TimeoutManager.AcknowledgePacket(packet.SequenceID()) + } } func (pep *PRUDPEndPoint) handleMultiAcknowledgment(packet PRUDPPacketInterface) { @@ -191,7 +202,7 @@ func (pep *PRUDPEndPoint) handleMultiAcknowledgment(packet PRUDPPacketInterface) // * MutexMap.Each locks the mutex, can't remove while reading. // * Have to just loop again - slidingWindow.ResendScheduler.packets.Each(func(sequenceID uint16, pending *PendingPacket) bool { + slidingWindow.TimeoutManager.packets.Each(func(sequenceID uint16, pending PRUDPPacketInterface) bool { if sequenceID <= baseSequenceID && !slices.Contains(sequenceIDs, sequenceID) { sequenceIDs = append(sequenceIDs, sequenceID) } @@ -201,7 +212,7 @@ func (pep *PRUDPEndPoint) handleMultiAcknowledgment(packet PRUDPPacketInterface) // * Actually remove the packets from the pool for _, sequenceID := range sequenceIDs { - slidingWindow.ResendScheduler.AcknowledgePacket(sequenceID) + slidingWindow.TimeoutManager.AcknowledgePacket(sequenceID) } } @@ -397,7 +408,6 @@ func (pep *PRUDPEndPoint) handleData(packet PRUDPPacketInterface) { func (pep *PRUDPEndPoint) handleDisconnect(packet PRUDPPacketInterface) { // TODO - Should we check the state here, or just let the connection disconnect at any time? - // TODO - Should we bother to set the connections state here? It's being destroyed anyway if packet.HasFlag(constants.PacketFlagNeedsAck) { pep.acknowledgePacket(packet) @@ -407,6 +417,8 @@ func (pep *PRUDPEndPoint) handleDisconnect(packet PRUDPPacketInterface) { streamID := packet.SourceVirtualPortStreamID() discriminator := fmt.Sprintf("%s-%d-%d", packet.Sender().Address().String(), streamType, streamID) if connection, ok := pep.Connections.Get(discriminator); ok { + // * We make sure to update the connection state here because we could still be attempting to + // * resend packets. connection.cleanup() pep.Connections.Delete(discriminator) } @@ -539,8 +551,6 @@ func (pep *PRUDPEndPoint) handleReliable(packet PRUDPPacketInterface) { } connection := packet.Sender().(*PRUDPConnection) - connection.Lock() - defer connection.Unlock() substreamID := packet.SubstreamID() @@ -702,6 +712,39 @@ func (pep *PRUDPEndPoint) FindConnectionByPID(pid uint64) *PRUDPConnection { return connection } +// ComputeRetransmitTimeout computes the RTO (Retransmit timeout) for a given packet +func (pep *PRUDPEndPoint) ComputeRetransmitTimeout(packet PRUDPPacketInterface) time.Duration { + connection := packet.Sender().(*PRUDPConnection) + rtt := connection.rtt + + if callback := pep.CalcRetransmissionTimeoutCallback; callback != nil { + rttAverage := rtt.GetRTTSmoothedAvg() + rttDeviation := rtt.GetRTTSmoothedDev() + return callback(rttAverage+rttDeviation*4.0, packet.SendCount()) + } + + var retransmitTimeBase int64 + if packet.Type() == constants.SynPacket { + retransmitTimeBase = int64(pep.DefaultStreamSettings.SynInitialRTT) + } else { + retransmitTimeBase = int64(pep.DefaultStreamSettings.InitialRTT) + if rtt.Initialized() { + retransmitTimeBase = int64(rtt.Average()/time.Millisecond) / 8 + } + } + + retransmitTimeBaseMultiplier := packet.SendCount() + + var retransmitMultiplier float64 + if packet.SendCount() < pep.DefaultStreamSettings.ExtraRetransmitTimeoutTrigger { + retransmitMultiplier = float64(pep.DefaultStreamSettings.RetransmitTimeoutMultiplier) + } else { + retransmitMultiplier = float64(pep.DefaultStreamSettings.ExtraRetransmitTimeoutMultiplier) + } + + return time.Duration(float64(retransmitTimeBase*int64(retransmitTimeBaseMultiplier))*retransmitMultiplier) * time.Millisecond +} + // AccessKey returns the servers sandbox access key func (pep *PRUDPEndPoint) AccessKey() string { return pep.Server.AccessKey diff --git a/prudp_packet.go b/prudp_packet.go index 93d1d9f..e66f98f 100644 --- a/prudp_packet.go +++ b/prudp_packet.go @@ -2,6 +2,7 @@ package nex import ( "crypto/rc4" + "time" "github.com/PretendoNetwork/nex-go/v2/constants" ) @@ -24,6 +25,9 @@ type PRUDPPacket struct { fragmentID uint8 payload []byte message *RMCMessage + sendCount uint32 + sentAt time.Time + timeout *Timeout } // SetSender sets the Client who sent the packet @@ -184,6 +188,32 @@ func (p *PRUDPPacket) SetRMCMessage(message *RMCMessage) { p.message = message } +// SendCount returns the number of times this packet has been sent +func (p *PRUDPPacket) SendCount() uint32 { + return p.sendCount +} + +func (p *PRUDPPacket) incrementSendCount() { + p.sendCount++ +} + +// SentAt returns the latest time that this packet has been sent +func (p *PRUDPPacket) SentAt() time.Time { + return p.sentAt +} + +func (p *PRUDPPacket) setSentAt(time time.Time) { + p.sentAt = time +} + +func (p *PRUDPPacket) getTimeout() *Timeout { + return p.timeout +} + +func (p *PRUDPPacket) setTimeout(timeout *Timeout) { + p.timeout = timeout +} + func (p *PRUDPPacket) processUnreliableCrypto() []byte { // * Since unreliable DATA packets can come in out of // * order, each packet uses a dedicated RC4 stream diff --git a/prudp_packet_interface.go b/prudp_packet_interface.go index cf5ef8f..6bb33ed 100644 --- a/prudp_packet_interface.go +++ b/prudp_packet_interface.go @@ -2,6 +2,7 @@ package nex import ( "net" + "time" "github.com/PretendoNetwork/nex-go/v2/constants" ) @@ -36,6 +37,12 @@ type PRUDPPacketInterface interface { SetPayload(payload []byte) RMCMessage() *RMCMessage SetRMCMessage(message *RMCMessage) + SendCount() uint32 + incrementSendCount() + SentAt() time.Time + setSentAt(time time.Time) + getTimeout() *Timeout + setTimeout(timeout *Timeout) decode() error setSignature(signature []byte) calculateConnectionSignature(addr net.Addr) ([]byte, error) diff --git a/prudp_server.go b/prudp_server.go index 51d8b30..68abdef 100644 --- a/prudp_server.go +++ b/prudp_server.go @@ -251,6 +251,7 @@ func (ps *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { packetCopy.SetSequenceID(connection.outgoingUnreliableSequenceIDCounter.Next()) } else if packetCopy.Type() == constants.PingPacket { packetCopy.SetSequenceID(connection.outgoingPingSequenceIDCounter.Next()) + connection.lastSentPingTime = time.Now() } else { packetCopy.SetSequenceID(0) } @@ -288,9 +289,12 @@ func (ps *PRUDPServer) sendPacket(packet PRUDPPacketInterface) { packetCopy.setSignature(packetCopy.calculateSignature(connection.SessionKey, connection.ServerConnectionSignature)) } + packetCopy.incrementSendCount() + packetCopy.setSentAt(time.Now()) + if packetCopy.HasFlag(constants.PacketFlagReliable) && packetCopy.HasFlag(constants.PacketFlagNeedsAck) { slidingWindow := connection.SlidingWindow(packetCopy.SubstreamID()) - slidingWindow.ResendScheduler.AddPacket(packetCopy) + slidingWindow.TimeoutManager.SchedulePacketTimeout(packetCopy) } ps.sendRaw(packetCopy.Sender().(*PRUDPConnection).Socket, packetCopy.Bytes()) diff --git a/resend_scheduler.go b/resend_scheduler.go deleted file mode 100644 index 8a93477..0000000 --- a/resend_scheduler.go +++ /dev/null @@ -1,148 +0,0 @@ -package nex - -import ( - "time" -) - -// TODO - REMOVE THIS ENTIRELY AND REPLACE IT WITH AN IMPLEMENTATION OF rdv::Timeout AND rdv::TimeoutManager AND USE MORE STREAM SETTINGS! - -// PendingPacket represends a packet scheduled to be resent -type PendingPacket struct { - packet PRUDPPacketInterface - lastSendTime time.Time - resendCount uint32 - isAcknowledged bool - interval time.Duration - ticker *time.Ticker - rs *ResendScheduler -} - -func (pi *PendingPacket) startResendTimer() { - pi.lastSendTime = time.Now() - pi.ticker = time.NewTicker(pi.interval) - - for range pi.ticker.C { - finished := false - - if pi.isAcknowledged { - pi.ticker.Stop() - pi.rs.packets.Delete(pi.packet.SequenceID()) - finished = true - } else { - finished = pi.rs.resendPacket(pi) - } - - if finished { - return - } - } -} - -// ResendScheduler manages the resending of reliable PRUDP packets -type ResendScheduler struct { - packets *MutexMap[uint16, *PendingPacket] -} - -// Stop kills the resend scheduler and stops all pending packets -func (rs *ResendScheduler) Stop() { - stillPending := make([]uint16, rs.packets.Size()) - - rs.packets.Each(func(sequenceID uint16, packet *PendingPacket) bool { - if !packet.isAcknowledged { - stillPending = append(stillPending, sequenceID) - } - - return false - }) - - for _, sequenceID := range stillPending { - if pendingPacket, ok := rs.packets.Get(sequenceID); ok { - pendingPacket.isAcknowledged = true // * Prevent an edge case where the ticker is already being processed - - if pendingPacket.ticker != nil { - // * This should never happen, but popped up in CTGP-7 testing? - // * Did the GC clear this before we called it? - pendingPacket.ticker.Stop() - } - - rs.packets.Delete(sequenceID) - } - } -} - -// AddPacket adds a packet to the scheduler and begins it's timer -func (rs *ResendScheduler) AddPacket(packet PRUDPPacketInterface) { - connection := packet.Sender().(*PRUDPConnection) - slidingWindow := connection.SlidingWindow(packet.SubstreamID()) - - pendingPacket := &PendingPacket{ - packet: packet, - rs: rs, - // TODO: This may not be accurate, needs more research - interval: time.Duration(slidingWindow.streamSettings.KeepAliveTimeout) * time.Millisecond, - } - - rs.packets.Set(packet.SequenceID(), pendingPacket) - - go pendingPacket.startResendTimer() -} - -// AcknowledgePacket marks a pending packet as acknowledged. It will be ignored at the next resend attempt -func (rs *ResendScheduler) AcknowledgePacket(sequenceID uint16) { - if pendingPacket, ok := rs.packets.Get(sequenceID); ok { - pendingPacket.isAcknowledged = true - } -} - -func (rs *ResendScheduler) resendPacket(pendingPacket *PendingPacket) bool { - if pendingPacket.isAcknowledged { - // * Prevent a race condition where resendPacket may be called - // * at the same time a packet is acknowledged - return false - } - - packet := pendingPacket.packet - connection := packet.Sender().(*PRUDPConnection) - slidingWindow := connection.SlidingWindow(packet.SubstreamID()) - - if pendingPacket.resendCount >= slidingWindow.streamSettings.MaxPacketRetransmissions { - // * The maximum resend count has been reached, consider the connection dead. - pendingPacket.ticker.Stop() - rs.packets.Delete(packet.SequenceID()) - connection.cleanup() // * "removed" event is dispatched here - - connection.endpoint.deleteConnectionByID(connection.ID) - - return true - } - - // TODO: This may not be accurate, needs more research - if time.Since(pendingPacket.lastSendTime) >= time.Duration(slidingWindow.streamSettings.KeepAliveTimeout)*time.Millisecond { - // * Resend the packet to the connection - server := connection.endpoint.Server - data := packet.Bytes() - server.sendRaw(connection.Socket, data) - - pendingPacket.resendCount++ - - var retransmitTimeoutMultiplier float32 - if pendingPacket.resendCount < slidingWindow.streamSettings.ExtraRestransmitTimeoutTrigger { - retransmitTimeoutMultiplier = slidingWindow.streamSettings.RetransmitTimeoutMultiplier - } else { - retransmitTimeoutMultiplier = slidingWindow.streamSettings.ExtraRetransmitTimeoutMultiplier - } - pendingPacket.interval += time.Duration(uint32(float32(slidingWindow.streamSettings.KeepAliveTimeout)*retransmitTimeoutMultiplier)) * time.Millisecond - - pendingPacket.ticker.Reset(pendingPacket.interval) - pendingPacket.lastSendTime = time.Now() - } - - return false -} - -// NewResendScheduler creates a new ResendScheduler -func NewResendScheduler() *ResendScheduler { - return &ResendScheduler{ - packets: NewMutexMap[uint16, *PendingPacket](), - } -} diff --git a/rtt.go b/rtt.go new file mode 100644 index 0000000..1b48d69 --- /dev/null +++ b/rtt.go @@ -0,0 +1,66 @@ +package nex + +import ( + "math" + "sync" + "time" +) + +const ( + alpha float64 = 1.0 / 8.0 + beta float64 = 1.0 / 4.0 + k float64 = 4.0 +) + +// RTT is an implementation of rdv::RTT. +// Used to calculate the average round trip time of reliable packets +type RTT struct { + sync.Mutex + lastRTT float64 + average float64 + variance float64 + initialized bool +} + +// Adjust updates the average RTT with the new value +func (rtt *RTT) Adjust(next time.Duration) { + // * This calculation comes from the RFC6298 which defines RTT calculation for TCP packets + rtt.Lock() + if rtt.initialized { + rtt.variance = (1.0-beta)*rtt.variance + beta*math.Abs(rtt.variance-float64(next)) + rtt.average = (1.0-alpha)*rtt.average + alpha*float64(next) + } else { + rtt.lastRTT = float64(next) + rtt.variance = float64(next) / 2 + rtt.average = float64(next) + k*rtt.variance + rtt.initialized = true + } + rtt.Unlock() +} + +// GetRTTSmoothedAvg returns the smoothed average of this RTT, it is used in calls to the custom +// RTO calculation function set on `PRUDPEndpoint::SetCalcRetransmissionTimeoutCallback` +func (rtt *RTT) GetRTTSmoothedAvg() float64 { + return rtt.average / 16 +} + +// GetRTTSmoothedDev returns the smoothed standard deviation of this RTT, it is used in calls to the custom +// RTO calculation function set on `PRUDPEndpoint::SetCalcRetransmissionTimeoutCallback` +func (rtt *RTT) GetRTTSmoothedDev() float64 { + return rtt.variance / 8 +} + +// Initialized returns a bool indicating whether this RTT has been initialized +func (rtt *RTT) Initialized() bool { + return rtt.initialized +} + +// GetRTO returns the current average +func (rtt *RTT) Average() time.Duration { + return time.Duration(rtt.average) +} + +// NewRTT returns a new RTT based on the first value +func NewRTT() *RTT { + return &RTT{} +} diff --git a/sliding_window.go b/sliding_window.go index 3e43b1b..0995b81 100644 --- a/sliding_window.go +++ b/sliding_window.go @@ -8,7 +8,7 @@ package nex type SlidingWindow struct { sequenceIDCounter *Counter[uint16] streamSettings *StreamSettings - ResendScheduler *ResendScheduler + TimeoutManager *TimeoutManager } // SetCipherKey sets the reliable substreams RC4 cipher keys @@ -35,7 +35,7 @@ func (sw *SlidingWindow) Encrypt(data []byte) ([]byte, error) { func NewSlidingWindow() *SlidingWindow { sw := &SlidingWindow{ sequenceIDCounter: NewCounter[uint16](0), - ResendScheduler: NewResendScheduler(), + TimeoutManager: NewTimeoutManager(), } return sw diff --git a/stream_settings.go b/stream_settings.go index 374b365..17eed98 100644 --- a/stream_settings.go +++ b/stream_settings.go @@ -12,17 +12,18 @@ import ( // The original library has more settings which are not present here as their use is unknown. // Not all values are used at this time, and only exist to future-proof for a later time. type StreamSettings struct { - ExtraRestransmitTimeoutTrigger uint32 // * The number of times a packet can be retransmitted before ExtraRetransmitTimeoutMultiplier is used + ExtraRetransmitTimeoutTrigger uint32 // * The number of times a packet can be retransmitted before ExtraRetransmitTimeoutMultiplier is used MaxPacketRetransmissions uint32 // * The number of times a packet can be retransmitted before the timeout time is checked KeepAliveTimeout uint32 // * Presumably the time a packet can be alive for without acknowledgement? Milliseconds? ChecksumBase uint32 // * Unused. The base value for PRUDPv0 checksum calculations FaultDetectionEnabled bool // * Unused. Presumably used to detect PIA faults? - InitialRTT uint32 // * Unused. The connections initial RTT + InitialRTT uint32 // * The initial connection RTT used for all non-SYN packets + SynInitialRTT uint32 // * The initial connection RTT used for all SYN packets EncryptionAlgorithm encryption.Algorithm // * The encryption algorithm used for packet payloads ExtraRetransmitTimeoutMultiplier float32 // * Used as part of the RTO calculations when retransmitting a packet. Only used if ExtraRestransmitTimeoutTrigger has been reached WindowSize uint32 // * Unused. The max number of (reliable?) packets allowed in a SlidingWindow CompressionAlgorithm compression.Algorithm // * The compression algorithm used for packet payloads - RTTRetransmit uint32 // * Unused. Unknown use + RTTRetransmit uint32 // * This is the number of times that a retried packet will be included in RTT calculations if we receive an ACK packet for it RetransmitTimeoutMultiplier float32 // * Used as part of the RTO calculations when retransmitting a packet. Only used if ExtraRestransmitTimeoutTrigger has not been reached MaxSilenceTime uint32 // * Presumably the time a connection can go without any packets from the other side? Milliseconds? } @@ -31,7 +32,7 @@ type StreamSettings struct { func (ss *StreamSettings) Copy() *StreamSettings { copied := NewStreamSettings() - copied.ExtraRestransmitTimeoutTrigger = ss.ExtraRestransmitTimeoutTrigger + copied.ExtraRetransmitTimeoutTrigger = ss.ExtraRetransmitTimeoutTrigger copied.MaxPacketRetransmissions = ss.MaxPacketRetransmissions copied.KeepAliveTimeout = ss.KeepAliveTimeout copied.ChecksumBase = ss.ChecksumBase @@ -50,21 +51,22 @@ func (ss *StreamSettings) Copy() *StreamSettings { // NewStreamSettings returns a new instance of StreamSettings with default params func NewStreamSettings() *StreamSettings { - // * Default values based on WATCH_DOGS. Not all values are used currently, and only + // * Default values based on WATCH_DOGS other than where stated. Not all values are used currently, and only // * exist to mimic what is seen in that game. Many are planned for future use. return &StreamSettings{ - ExtraRestransmitTimeoutTrigger: 0x32, + ExtraRetransmitTimeoutTrigger: 0x32, MaxPacketRetransmissions: 0x14, KeepAliveTimeout: 1000, ChecksumBase: 0, FaultDetectionEnabled: true, - InitialRTT: 0xFA, + InitialRTT: 0x2EE, + SynInitialRTT: 0xFA, EncryptionAlgorithm: encryption.NewRC4Encryption(), ExtraRetransmitTimeoutMultiplier: 1.0, WindowSize: 8, CompressionAlgorithm: compression.NewDummyCompression(), - RTTRetransmit: 0x32, + RTTRetransmit: 2, // * This value is taken from Xenoblade Chronicles, WATCH_DOGS sets this to 0x32 but it is then ignored. Setting this to 2 matches the TCP spec by not using resent packets in RTT calculations. RetransmitTimeoutMultiplier: 1.25, - MaxSilenceTime: 5000, + MaxSilenceTime: 10000, // * This value is taken from Xenoblade Chronicles, WATCH_DOGS sets this to 5000. } } diff --git a/timeout.go b/timeout.go new file mode 100644 index 0000000..05edc80 --- /dev/null +++ b/timeout.go @@ -0,0 +1,29 @@ +package nex + +import ( + "context" + "time" +) + +// Timeout is an implementation of rdv::Timeout. +// Used to hold state related to resend timeouts on a packet +type Timeout struct { + timeout time.Duration + ctx context.Context + cancel context.CancelFunc +} + +// SetRTO sets the timeout field on this instance +func (t *Timeout) SetRTO(timeout time.Duration) { + t.timeout = timeout +} + +// GetRTO gets the timeout field of this instance +func (t *Timeout) RTO() time.Duration { + return t.timeout +} + +// NewTimeout creates a new Timeout +func NewTimeout() *Timeout { + return &Timeout{} +} diff --git a/timeout_manager.go b/timeout_manager.go new file mode 100644 index 0000000..8bd18fe --- /dev/null +++ b/timeout_manager.go @@ -0,0 +1,103 @@ +package nex + +import ( + "context" + "time" +) + +// TimeoutManager is an implementation of rdv::TimeoutManager and manages the resending of reliable PRUDP packets +type TimeoutManager struct { + ctx context.Context + cancel context.CancelFunc + packets *MutexMap[uint16, PRUDPPacketInterface] + streamSettings *StreamSettings +} + +// SchedulePacketTimeout adds a packet to the scheduler and begins it's timer +func (tm *TimeoutManager) SchedulePacketTimeout(packet PRUDPPacketInterface) { + endpoint := packet.Sender().Endpoint().(*PRUDPEndPoint) + + rto := endpoint.ComputeRetransmitTimeout(packet) + ctx, cancel := context.WithTimeout(tm.ctx, rto) + + timeout := NewTimeout() + timeout.SetRTO(rto) + timeout.ctx = ctx + timeout.cancel = cancel + packet.setTimeout(timeout) + + tm.packets.Set(packet.SequenceID(), packet) + go tm.start(packet) +} + +// AcknowledgePacket marks a pending packet as acknowledged. It will be ignored at the next resend attempt +func (tm *TimeoutManager) AcknowledgePacket(sequenceID uint16) { + if packet, ok := tm.packets.Get(sequenceID); ok { + // * Acknowledge the packet + tm.packets.Delete(sequenceID) + + // * Update the RTT on the connection if the packet hasn't been resent + if packet.SendCount() <= tm.streamSettings.RTTRetransmit { + rttm := time.Since(packet.SentAt()) + packet.Sender().(*PRUDPConnection).rtt.Adjust(rttm) + } + } +} + +func (tm *TimeoutManager) start(packet PRUDPPacketInterface) { + <-packet.getTimeout().ctx.Done() + + connection := packet.Sender().(*PRUDPConnection) + connection.Lock() + defer connection.Unlock() + + // * If the connection is closed stop trying to resend + if connection.ConnectionState != StateConnected { + return + } + + if tm.packets.Has(packet.SequenceID()) { + // * This is `<` instead of `<=` for accuracy with observed behavior, even though we're comparing send count vs _resend_ max + if packet.SendCount() < tm.streamSettings.MaxPacketRetransmissions { + endpoint := packet.Sender().Endpoint().(*PRUDPEndPoint) + + packet.incrementSendCount() + packet.setSentAt(time.Now()) + rto := endpoint.ComputeRetransmitTimeout(packet) + + ctx, cancel := context.WithTimeout(tm.ctx, rto) + timeout := packet.getTimeout() + timeout.timeout = rto + timeout.ctx = ctx + timeout.cancel = cancel + + // * Schedule the packet to be resent + go tm.start(packet) + + // * Resend the packet to the connection + server := connection.endpoint.Server + data := packet.Bytes() + server.sendRaw(connection.Socket, data) + } else { + // * Packet has been retried too many times, consider the connection dead + connection.cleanup() + } + } +} + +// Stop kills the resend scheduler and stops all pending packets +func (tm *TimeoutManager) Stop() { + tm.cancel() + tm.packets.Clear(func(key uint16, value PRUDPPacketInterface) {}) +} + +// NewTimeoutManager creates a new TimeoutManager +func NewTimeoutManager() *TimeoutManager { + ctx, cancel := context.WithCancel(context.Background()) + return &TimeoutManager{ + ctx: ctx, + cancel: cancel, + packets: NewMutexMap[uint16, PRUDPPacketInterface](), + streamSettings: NewStreamSettings(), + } +}