Skip to content

Commit

Permalink
Consolidate usage counting inside a listenAddress type.
Browse files Browse the repository at this point in the history
  • Loading branch information
sbruens committed Jul 22, 2024
1 parent 7a15e7d commit 499829e
Showing 1 changed file with 48 additions and 63 deletions.
111 changes: 48 additions & 63 deletions service/listeners.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,31 +88,50 @@ func (spc *sharedPacketConn) Close() error {
return spc.onCloseFunc()
}

type concreteListener struct {
type listenAddr struct {
ln *net.TCPListener
pc net.PacketConn
usage atomic.Int32
acceptCh chan acceptResponse
onCloseFunc func() // Called when the listener's last user closes it.
}

func (cl *concreteListener) Close() error {
if cl.usage.Add(-1) == 0 {
if cl.ln != nil {
err := cl.ln.Close()
if err != nil {
return err
func (cl *listenAddr) NewStreamListener() StreamListener {
cl.usage.Add(1)
sl := &sharedListener{
listener: *cl.ln,
closeCh: make(chan struct{}),
onCloseFunc: func() error {
if cl.usage.Add(-1) == 0 {
err := cl.ln.Close()
if err != nil {
return err
}
cl.onCloseFunc()
}
}
if cl.pc != nil {
err := cl.pc.Close()
if err != nil {
return err
return nil
},
}
sl.acceptCh = &atomic.Value{}
sl.acceptCh.Store(cl.acceptCh)
return sl
}

func (cl *listenAddr) NewPacketListener() net.PacketConn {
cl.usage.Add(1)
return &sharedPacketConn{
PacketConn: cl.pc,
onCloseFunc: func() error {
if cl.usage.Add(-1) == 0 {
err := cl.pc.Close()
if err != nil {
return err
}
cl.onCloseFunc()
}
}
cl.onCloseFunc()
return nil
},
}
return nil
}

// ListenerManager holds and manages the state of shared listeners.
Expand All @@ -122,14 +141,14 @@ type ListenerManager interface {
}

type listenerManager struct {
listeners map[string]*concreteListener
listeners map[string]*listenAddr
listenersMu sync.Mutex
}

// NewListenerManager creates a new [ListenerManger].
func NewListenerManager() ListenerManager {
return &listenerManager{
listeners: make(map[string]*concreteListener),
listeners: make(map[string]*listenAddr),
}
}

Expand All @@ -144,18 +163,8 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe
defer m.listenersMu.Unlock()

lnKey := listenerKey(network, addr)
if lnConcrete, ok := m.listeners[lnKey]; ok {
lnConcrete.usage.Add(1)
sl := &sharedListener{
listener: *lnConcrete.ln,
closeCh: make(chan struct{}),
onCloseFunc: func() error {
return lnConcrete.Close()
},
}
sl.acceptCh = &atomic.Value{}
sl.acceptCh.Store(lnConcrete.acceptCh)
return sl, nil
if listenAddress, ok := m.listeners[lnKey]; ok {
return listenAddress.NewStreamListener(), nil
}

tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
Expand All @@ -167,7 +176,7 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe
return nil, err
}

lnConcrete := &concreteListener{
listenAddress := &listenAddr{
ln: ln,
acceptCh: make(chan acceptResponse),
onCloseFunc: func() {
Expand All @@ -176,26 +185,15 @@ func (m *listenerManager) ListenStream(network string, addr string) (StreamListe
}
go func() {
for {
conn, err := lnConcrete.ln.AcceptTCP()
conn, err := listenAddress.ln.AcceptTCP()
if errors.Is(err, net.ErrClosed) {
return
}
lnConcrete.acceptCh <- acceptResponse{conn, err}
listenAddress.acceptCh <- acceptResponse{conn, err}
}
}()
lnConcrete.usage.Store(1)
m.listeners[lnKey] = lnConcrete

sl := &sharedListener{
listener: *lnConcrete.ln,
closeCh: make(chan struct{}),
onCloseFunc: func() error {
return lnConcrete.Close()
},
}
sl.acceptCh = &atomic.Value{}
sl.acceptCh.Store(lnConcrete.acceptCh)
return sl, nil
m.listeners[lnKey] = listenAddress
return listenAddress.NewStreamListener(), nil
}

// ListenPacket creates a new packet listener for a given network and address.
Expand All @@ -206,36 +204,23 @@ func (m *listenerManager) ListenPacket(network string, addr string) (net.PacketC
defer m.listenersMu.Unlock()

lnKey := listenerKey(network, addr)
if lnConcrete, ok := m.listeners[lnKey]; ok {
lnConcrete.usage.Add(1)
return &sharedPacketConn{
PacketConn: lnConcrete.pc,
onCloseFunc: func() error {
return lnConcrete.Close()
},
}, nil
if listenAddress, ok := m.listeners[lnKey]; ok {
return listenAddress.NewPacketListener(), nil
}

pc, err := net.ListenPacket(network, addr)
if err != nil {
return nil, err
}

lnConcrete := &concreteListener{
listenAddress := &listenAddr{
pc: pc,
onCloseFunc: func() {
m.delete(lnKey)
},
}
lnConcrete.usage.Store(1)
m.listeners[lnKey] = lnConcrete

return &sharedPacketConn{
PacketConn: pc,
onCloseFunc: func() error {
return lnConcrete.Close()
},
}, nil
m.listeners[lnKey] = listenAddress
return listenAddress.NewPacketListener(), nil
}

func (m *listenerManager) delete(key string) {
Expand Down

0 comments on commit 499829e

Please sign in to comment.