diff --git a/ppp/pptp/gre.go b/ppp/pptp/gre.go index 799f6fa..7f7cdaa 100644 --- a/ppp/pptp/gre.go +++ b/ppp/pptp/gre.go @@ -5,20 +5,23 @@ import ( "fmt" "io" "net" + "sync" "github.com/google/gopacket" "github.com/google/gopacket/layers" ) const ( - greProtocol = 47 + greProtocol = 47 + recvQueueSize = 4 ) var ( wrongLayers = errors.New("layers not as expected: want IP->GRE") wrongGREFields = errors.New("GRE fields wrong: want version=1, ethernet type PPP") - wrongSession = errors.New("packet for a different GRE session") + unknownSession = errors.New("packet for an unknown GRE session") outOfSequencePacket = errors.New("out of sequence packet received") + recvQueueOverflow = errors.New("session receive queue is full") ) var _ = (io.ReadWriteCloser)(&greSession{}) @@ -26,31 +29,21 @@ var _ = (io.ReadWriteCloser)(&greSession{}) // greSession is used to send and receive packets for a particular PPP-over-GRE // session. type greSession struct { - conn net.Conn + s *greServer + closed bool + recvQueue chan gopacket.Packet + addr net.IP sendCallID, recvCallID uint16 sentSeq, recvSeq, recvAcked uint32 } func (s *greSession) recvPacket(p []byte) (int, error) { - cnt, err := s.conn.Read(p) - if err != nil { - return 0, err + pkt, ok := <-s.recvQueue + if !ok { + return 0, io.EOF } - pkt := gopacket.NewPacket(p[:cnt], layers.LayerTypeIPv4, gopacket.NoCopy) ls := pkt.Layers() - if len(ls) < 2 || ls[0].LayerType() != layers.LayerTypeIPv4 || ls[1].LayerType() != layers.LayerTypeGRE { - return 0, wrongLayers - } greHeader := ls[1].(*layers.GRE) - if greHeader.Version != 1 || greHeader.Protocol != layers.EthernetTypePPP { - return 0, wrongGREFields - } - // In PPTP modified GRE, the bottom two octets of the key field are - // used to contain the call ID. - callID := uint16(greHeader.Key & 0xffff) - if !greHeader.KeyPresent || callID != s.recvCallID { - return 0, wrongSession - } // RFC 2637 mandates that "out of sequence packets between the PNS and // PAC MUST be silently discarded [or reordered]" because PPP cannot // handle out-of-order packets. @@ -79,7 +72,7 @@ func (s *greSession) Read(p []byte) (int, error) { if cnt > 0 { return cnt, nil } - case wrongLayers, wrongGREFields, wrongSession, outOfSequencePacket: + case outOfSequencePacket: // try again default: return 0, err @@ -110,7 +103,9 @@ func (s *greSession) sendPacket(frame []byte) (int, error) { greHeader, gopacket.Payload(frame), ) - return s.conn.Write(buf.Bytes()) + return s.s.conn.WriteToIP(buf.Bytes(), &net.IPAddr{ + IP: s.addr, + }) } func (s *greSession) Write(frame []byte) (int, error) { @@ -118,17 +113,114 @@ func (s *greSession) Write(frame []byte) (int, error) { } func (s *greSession) Close() error { - return s.conn.Close() + sk := s.sessionKey() + s.s.mu.Lock() + defer s.s.mu.Unlock() + if !s.closed { + delete(s.s.sessions, *sk) + close(s.recvQueue) + s.closed = true + } + return nil } -func startGRESession(remoteAddr net.IP, sendCallID, recvCallID uint16) (*greSession, error) { - conn, err := net.Dial(fmt.Sprintf("ip4:%d", greProtocol), remoteAddr.String()) +func (s *greSession) sessionKey() *sessionKey { + return &sessionKey{ + IP: s.addr.String(), + CallID: s.recvCallID, + } +} + +type sessionKey struct { + IP string + CallID uint16 +} + +type greServer struct { + conn *net.IPConn + sessions map[sessionKey]*greSession + mu sync.Mutex +} + +func startGREServer() (*greServer, error) { + conn, err := net.ListenIP(fmt.Sprintf("ip4:%d", greProtocol), nil) if err != nil { return nil, err } - return &greSession{ - conn: conn, + return &greServer{ + conn: conn, + sessions: make(map[sessionKey]*greSession), + }, nil +} + +func (s *greServer) startSession(remoteAddr net.IP, sendCallID, recvCallID uint16) (*greSession, error) { + session := &greSession{ + s: s, + addr: remoteAddr, + recvQueue: make(chan gopacket.Packet, recvQueueSize), sendCallID: sendCallID, recvCallID: recvCallID, - }, nil + } + sk := session.sessionKey() + s.mu.Lock() + s.sessions[*sk] = session + s.mu.Unlock() + return session, nil +} + +func (s *greServer) processPacket(pkt gopacket.Packet) error { + ls := pkt.Layers() + if len(ls) < 2 || ls[0].LayerType() != layers.LayerTypeIPv4 || ls[1].LayerType() != layers.LayerTypeGRE { + return wrongLayers + } + ipHeader := ls[0].(*layers.IPv4) + greHeader := ls[1].(*layers.GRE) + if greHeader.Version != 1 || greHeader.Protocol != layers.EthernetTypePPP { + return wrongGREFields + } + // In PPTP modified GRE, the bottom two octets of the key field are + // used to contain the call ID. + if !greHeader.KeyPresent { + return wrongGREFields + } + sk := &sessionKey{ + IP: ipHeader.SrcIP.String(), + CallID: uint16(greHeader.Key & 0xffff), + } + s.mu.Lock() + defer s.mu.Unlock() + session, ok := s.sessions[*sk] + if !ok || session.closed { + return unknownSession + } + // Try to place onto session's receive queue, but don't block. + select { + case session.recvQueue <- pkt: + return nil + default: + return recvQueueOverflow + } +} + +func (s *greServer) Run() error { + var recvBuf [1500]byte + for { + cnt, err := s.conn.Read(recvBuf[:]) + if err != nil { + return err + } + pkt := gopacket.NewPacket(recvBuf[:cnt], layers.LayerTypeIPv4, gopacket.Default) + // TODO: Log errors returned by processPacket? + s.processPacket(pkt) + } +} + +func (s *greServer) Close() error { + s.mu.Lock() + for _, session := range s.sessions { + close(session.recvQueue) + session.closed = true + } + s.mu.Unlock() + return s.conn.Close() } diff --git a/ppp/pptp/server.go b/ppp/pptp/server.go index 6c3dbbb..5d3bc55 100644 --- a/ppp/pptp/server.go +++ b/ppp/pptp/server.go @@ -115,7 +115,7 @@ func (c *Connection) startPPPSession(ctx context.Context, sendCallID uint16) { } addr := c.conn.RemoteAddr().(*net.TCPAddr) var err error - gre, err := startGRESession(addr.IP, sendCallID, c.callID) + gre, err := c.s.greServer.startSession(addr.IP, sendCallID, c.callID) if err != nil { // TODO: Send back error message? Log error? c.conn.Close() @@ -238,11 +238,13 @@ type Server struct { listener *net.TCPListener nextCallID uint16 n network.Network + greServer *greServer } // Run listens for and accepts new connections to the server. It blocks until // the server is shut down, so it should be invoked in a dedicated goroutine. func (s *Server) Run(ctx context.Context) { + go s.greServer.Run() for { conn, err := s.listener.Accept() if err != nil { @@ -257,19 +259,26 @@ func (s *Server) Run(ctx context.Context) { } func (s *Server) Close() error { + s.greServer.Close() return s.listener.Close() } func NewServer(n network.Network) (*Server, error) { + gs, err := startGREServer() + if err != nil { + return nil, err + } listener, err := net.ListenTCP("tcp", &net.TCPAddr{ Port: pptpPort, }) if err != nil { + gs.Close() return nil, err } return &Server{ listener: listener, nextCallID: 384, n: n, + greServer: gs, }, nil }