diff --git a/share/tunnel/tunnel_out_ssh_udp.go b/share/tunnel/tunnel_out_ssh_udp.go index d3c4c62f..9f184b22 100644 --- a/share/tunnel/tunnel_out_ssh_udp.go +++ b/share/tunnel/tunnel_out_ssh_udp.go @@ -28,6 +28,7 @@ func (t *Tunnel) handleUDP(l *cio.Logger, rwc io.ReadWriteCloser, hostPort strin }, udpConns: conns, maxMTU: settings.EnvInt("UDP_MAX_SIZE", 9012), + maxConns: settings.EnvInt("UDP_MAX_CONNS", 100), } h.Debugf("UDP max size: %d bytes", h.maxMTU) for { @@ -43,7 +44,8 @@ type udpHandler struct { hostPort string *udpChannel *udpConns - maxMTU int + maxMTU int + maxConns int } func (h *udpHandler) handleWrite(p *udpPacket) error { @@ -62,12 +64,14 @@ func (h *udpHandler) handleWrite(p *udpPacket) error { //TODO++ dont use go-routines, switch to pollable // array of listeners where all listeners are // sweeped periodically, removing the idle ones - const maxConns = 100 + if !exists { - if h.udpConns.len() <= maxConns { + if h.udpConns.len() <= h.maxConns { go h.handleRead(p, conn) } else { - h.Debugf("exceeded max udp connections (%d)", maxConns) + //write only + h.udpConns.setCleanUpTimer(conn.id) + h.Debugf("exceeded max udp connections (%d)", h.maxConns) } } _, err = conn.Write(p.Payload) @@ -79,7 +83,14 @@ func (h *udpHandler) handleWrite(p *udpPacket) error { func (h *udpHandler) handleRead(p *udpPacket, conn *udpConn) { //ensure connection is cleaned up - defer h.udpConns.remove(conn.id) + defer func() { + h.udpConns.remove(conn.id) + conn.Close() + }() + if h.udpConns.len() > h.maxConns { + h.Debugf("exceeded max udp connections (%d)", h.maxConns) + return + } buff := make([]byte, h.maxMTU) for { //response must arrive within 15 seconds @@ -149,7 +160,35 @@ func (cs *udpConns) closeAll() { cs.Unlock() } +func (cs *udpConns) setCleanUpTimer(id string) { + cs.Lock() + defer cs.Unlock() + conn, ok := cs.m[id] + if ok { + conn.writeTimer = time.AfterFunc(settings.EnvDuration("UDP_DEADLINE", 15*time.Second), func() { + cs.remove(conn.id) + conn.Close() + }) + } +} + type udpConn struct { id string net.Conn + writeTimer *time.Timer +} + +func (w *udpConn) Write(b []byte) (int, error) { + if w.writeTimer != nil { + w.writeTimer.Stop() + w.writeTimer.Reset(settings.EnvDuration("UDP_DEADLINE", 15*time.Second)) + } + return w.Conn.Write(b) +} + +func (w *udpConn) Close() error { + if w.writeTimer != nil { + w.writeTimer.Stop() + } + return w.Conn.Close() }