From b77ae3a9478b6c7fb49ae35263a5d8b723802af7 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sun, 29 Dec 2019 11:48:09 +0100 Subject: [PATCH] move interleaved frames logic into rtsp.Conn --- main.go | 21 +++++---------- rtsp/conn.go | 71 ++++++++++++++++++++++++++++++++++++++++++++++---- rtsp_client.go | 65 ++++++++++++++++----------------------------- 3 files changed, 94 insertions(+), 63 deletions(-) diff --git a/main.go b/main.go index 76b910d2db6..465d661a083 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "encoding/binary" "fmt" "log" "net" @@ -63,45 +62,37 @@ func (p *program) run() { <-infty } -func (p *program) handleRtp(buf []byte) { +func (p *program) handleRtp(frame []byte) { p.mutex.RLock() defer p.mutex.RUnlock() - tcpHeader := [4]byte{0x24, 0x00, 0x00, 0x00} - binary.BigEndian.PutUint16(tcpHeader[2:], uint16(len(buf))) - for c := range p.clients { if c.state == "PLAY" { if c.rtpProto == "udp" { - p.rtpl.nconn.WriteTo(buf, &net.UDPAddr{ + p.rtpl.nconn.WriteTo(frame, &net.UDPAddr{ IP: c.IP, Port: c.rtpPort, }) } else { - c.nconn.Write(tcpHeader[:]) - c.nconn.Write(buf) + c.rconn.WriteInterleavedFrame(frame) } } } } -func (p *program) handleRtcp(buf []byte) { +func (p *program) handleRtcp(frame []byte) { p.mutex.RLock() defer p.mutex.RUnlock() - tcpHeader := [4]byte{0x24, 0x00, 0x00, 0x00} - binary.BigEndian.PutUint16(tcpHeader[2:], uint16(len(buf))) - for c := range p.clients { if c.state == "PLAY" { if c.rtpProto == "udp" { - p.rtcpl.nconn.WriteTo(buf, &net.UDPAddr{ + p.rtcpl.nconn.WriteTo(frame, &net.UDPAddr{ IP: c.IP, Port: c.rtcpPort, }) } else { - c.nconn.Write(tcpHeader[:]) - c.nconn.Write(buf) + c.rconn.WriteInterleavedFrame(frame) } } } diff --git a/rtsp/conn.go b/rtsp/conn.go index 31fecc7642b..a268cfa12de 100644 --- a/rtsp/conn.go +++ b/rtsp/conn.go @@ -1,25 +1,86 @@ package rtsp import ( + "encoding/binary" + "fmt" + "io" "net" ) type Conn struct { - net.Conn + c net.Conn + writeBuf []byte +} + +func NewConn(c net.Conn) *Conn { + return &Conn{ + c: c, + writeBuf: make([]byte, 2048), + } +} + +func (c *Conn) Close() error { + return c.c.Close() +} + +func (c *Conn) RemoteAddr() net.Addr { + return c.c.RemoteAddr() } func (c *Conn) ReadRequest() (*Request, error) { - return requestDecode(c) + return requestDecode(c.c) } func (c *Conn) WriteRequest(req *Request) error { - return requestEncode(c, req) + return requestEncode(c.c, req) } func (c *Conn) ReadResponse() (*Response, error) { - return responseDecode(c) + return responseDecode(c.c) } func (c *Conn) WriteResponse(res *Response) error { - return responseEncode(c, res) + return responseEncode(c.c, res) +} + +func (c *Conn) ReadInterleavedFrame(frame []byte) (int, error) { + var header [4]byte + _, err := io.ReadFull(c.c, header[:]) + if err != nil { + return 0, err + } + + // connection terminated + if header[0] == 0x54 { + return 0, io.EOF + } + + if header[0] != 0x24 { + return 0, fmt.Errorf("wrong magic byte (0x%.2x)", header[0]) + } + + framelen := binary.BigEndian.Uint16(header[2:]) + if framelen > 2048 { + return 0, fmt.Errorf("frame length greater than 2048") + } + + _, err = io.ReadFull(c.c, frame[:framelen]) + if err != nil { + return 0, err + } + + return int(framelen), nil +} + +func (c *Conn) WriteInterleavedFrame(frame []byte) error { + c.writeBuf[0] = 0x24 + c.writeBuf[1] = 0x00 + binary.BigEndian.PutUint16(c.writeBuf[2:], uint16(len(frame))) + n := copy(c.writeBuf[4:], frame) + + _, err := c.c.Write(c.writeBuf[:4+n]) + if err != nil { + return err + } + return nil } diff --git a/rtsp_client.go b/rtsp_client.go index 872639b7e41..3b54c32d7bd 100644 --- a/rtsp_client.go +++ b/rtsp_client.go @@ -1,8 +1,6 @@ package main import ( - "bufio" - "encoding/binary" "errors" "fmt" "io" @@ -23,7 +21,7 @@ var ( type rtspClient struct { p *program - nconn net.Conn + rconn *rtsp.Conn state string IP net.IP rtpProto string @@ -34,7 +32,7 @@ type rtspClient struct { func newRtspClient(p *program, nconn net.Conn) *rtspClient { c := &rtspClient{ p: p, - nconn: nconn, + rconn: rtsp.NewConn(nconn), state: "STARTING", } @@ -52,7 +50,7 @@ func (c *rtspClient) close() error { } delete(c.p.clients, c) - c.nconn.Close() + c.rconn.Close() if c.p.streamAuthor == c { c.p.streamAuthor = nil @@ -69,7 +67,7 @@ func (c *rtspClient) close() error { } func (c *rtspClient) log(format string, args ...interface{}) { - format = "[RTSP client " + c.nconn.RemoteAddr().String() + "] " + format + format = "[RTSP client " + c.rconn.RemoteAddr().String() + "] " + format log.Printf(format, args...) } @@ -81,15 +79,13 @@ func (c *rtspClient) run() { c.close() }() - ipstr, _, _ := net.SplitHostPort(c.nconn.RemoteAddr().String()) + ipstr, _, _ := net.SplitHostPort(c.rconn.RemoteAddr().String()) c.IP = net.ParseIP(ipstr) - rconn := &rtsp.Conn{c.nconn} - c.log("connected") for { - req, err := rconn.ReadRequest() + req, err := c.rconn.ReadRequest() if err != nil { if err != io.EOF { c.log("ERR: %s", err) @@ -104,7 +100,7 @@ func (c *rtspClient) run() { switch err { // normal response case nil: - err = rconn.WriteResponse(res) + err = c.rconn.WriteResponse(res) if err != nil { c.log("ERR: %s", err) return @@ -119,7 +115,7 @@ func (c *rtspClient) run() { // before the response // then switch to RTP if TCP case errPlay: - err = rconn.WriteResponse(res) + err = c.rconn.WriteResponse(res) if err != nil { c.log("ERR: %s", err) return @@ -134,10 +130,13 @@ func (c *rtspClient) run() { // when rtp protocol is TCP, the RTSP connection becomes a RTP connection // receive RTP feedback, do not parse it, wait until connection closes if c.rtpProto == "tcp" { - buf := make([]byte, 1024) + buf := make([]byte, 2048) for { - _, err := c.nconn.Read(buf) + _, err := c.rconn.ReadInterleavedFrame(buf) if err != nil { + if err != io.EOF { + c.log("ERR: %s", err) + } return } } @@ -145,7 +144,7 @@ func (c *rtspClient) run() { // RECORD: switch to RTP if TCP case errRecord: - err = rconn.WriteResponse(res) + err = c.rconn.WriteResponse(res) if err != nil { c.log("ERR: %s", err) return @@ -160,37 +159,17 @@ func (c *rtspClient) run() { // when rtp protocol is TCP, the RTSP connection becomes a RTP connection // receive RTP data and parse it if c.rtpProto == "tcp" { - packet := make([]byte, 2048) - bconn := bufio.NewReader(c.nconn) + buf := make([]byte, 2048) for { - byts, err := bconn.Peek(4) - if err != nil { - return - } - bconn.Discard(4) - - if byts[0] != 0x24 { - c.log("ERR: wrong magic byte") - return - } - - if byts[1] != 0x00 { - c.log("ERR: wrong channel") - return - } - - plen := binary.BigEndian.Uint16(byts[2:]) - if plen > 2048 { - c.log("ERR: packet len > 2048") - return - } - - _, err = io.ReadFull(bconn, packet[:plen]) + n, err := c.rconn.ReadInterleavedFrame(buf) if err != nil { + if err != io.EOF { + c.log("ERR: %s", err) + } return } - c.p.handleRtp(packet[:plen]) + c.p.handleRtp(buf[:n]) } } @@ -199,7 +178,7 @@ func (c *rtspClient) run() { c.log("ERR: %s", err) if cseq, ok := req.Headers["cseq"]; ok { - rconn.WriteResponse(&rtsp.Response{ + c.rconn.WriteResponse(&rtsp.Response{ StatusCode: 400, Status: "Bad Request", Headers: map[string]string{ @@ -207,7 +186,7 @@ func (c *rtspClient) run() { }, }) } else { - rconn.WriteResponse(&rtsp.Response{ + c.rconn.WriteResponse(&rtsp.Response{ StatusCode: 400, Status: "Bad Request", })