Skip to content

Commit

Permalink
move interleaved frames logic into rtsp.Conn
Browse files Browse the repository at this point in the history
  • Loading branch information
aler9 committed Dec 29, 2019
1 parent 8a773a1 commit b77ae3a
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 63 deletions.
21 changes: 6 additions & 15 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"encoding/binary"
"fmt"
"log"
"net"
Expand Down Expand Up @@ -63,45 +62,37 @@ func (p *program) run() {
<-infty
}

func (p *program) handleRtp(buf []byte) {
func (p *program) handleRtp(frame []byte) {
p.mutex.RLock()
defer p.mutex.RUnlock()

tcpHeader := [4]byte{0x24, 0x00, 0x00, 0x00}
binary.BigEndian.PutUint16(tcpHeader[2:], uint16(len(buf)))

for c := range p.clients {
if c.state == "PLAY" {
if c.rtpProto == "udp" {
p.rtpl.nconn.WriteTo(buf, &net.UDPAddr{
p.rtpl.nconn.WriteTo(frame, &net.UDPAddr{
IP: c.IP,
Port: c.rtpPort,
})
} else {
c.nconn.Write(tcpHeader[:])
c.nconn.Write(buf)
c.rconn.WriteInterleavedFrame(frame)
}
}
}
}

func (p *program) handleRtcp(buf []byte) {
func (p *program) handleRtcp(frame []byte) {
p.mutex.RLock()
defer p.mutex.RUnlock()

tcpHeader := [4]byte{0x24, 0x00, 0x00, 0x00}
binary.BigEndian.PutUint16(tcpHeader[2:], uint16(len(buf)))

for c := range p.clients {
if c.state == "PLAY" {
if c.rtpProto == "udp" {
p.rtcpl.nconn.WriteTo(buf, &net.UDPAddr{
p.rtcpl.nconn.WriteTo(frame, &net.UDPAddr{
IP: c.IP,
Port: c.rtcpPort,
})
} else {
c.nconn.Write(tcpHeader[:])
c.nconn.Write(buf)
c.rconn.WriteInterleavedFrame(frame)
}
}
}
Expand Down
71 changes: 66 additions & 5 deletions rtsp/conn.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,86 @@
package rtsp

import (
"encoding/binary"
"fmt"
"io"
"net"
)

type Conn struct {
net.Conn
c net.Conn
writeBuf []byte
}

func NewConn(c net.Conn) *Conn {
return &Conn{
c: c,
writeBuf: make([]byte, 2048),
}
}

func (c *Conn) Close() error {
return c.c.Close()
}

func (c *Conn) RemoteAddr() net.Addr {
return c.c.RemoteAddr()
}

func (c *Conn) ReadRequest() (*Request, error) {
return requestDecode(c)
return requestDecode(c.c)
}

func (c *Conn) WriteRequest(req *Request) error {
return requestEncode(c, req)
return requestEncode(c.c, req)
}

func (c *Conn) ReadResponse() (*Response, error) {
return responseDecode(c)
return responseDecode(c.c)
}

func (c *Conn) WriteResponse(res *Response) error {
return responseEncode(c, res)
return responseEncode(c.c, res)
}

func (c *Conn) ReadInterleavedFrame(frame []byte) (int, error) {
var header [4]byte
_, err := io.ReadFull(c.c, header[:])
if err != nil {
return 0, err
}

// connection terminated
if header[0] == 0x54 {
return 0, io.EOF
}

if header[0] != 0x24 {
return 0, fmt.Errorf("wrong magic byte (0x%.2x)", header[0])
}

framelen := binary.BigEndian.Uint16(header[2:])
if framelen > 2048 {
return 0, fmt.Errorf("frame length greater than 2048")
}

_, err = io.ReadFull(c.c, frame[:framelen])
if err != nil {
return 0, err
}

return int(framelen), nil
}

func (c *Conn) WriteInterleavedFrame(frame []byte) error {
c.writeBuf[0] = 0x24
c.writeBuf[1] = 0x00
binary.BigEndian.PutUint16(c.writeBuf[2:], uint16(len(frame)))
n := copy(c.writeBuf[4:], frame)

_, err := c.c.Write(c.writeBuf[:4+n])
if err != nil {
return err
}
return nil
}
65 changes: 22 additions & 43 deletions rtsp_client.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package main

import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"io"
Expand All @@ -23,7 +21,7 @@ var (

type rtspClient struct {
p *program
nconn net.Conn
rconn *rtsp.Conn
state string
IP net.IP
rtpProto string
Expand All @@ -34,7 +32,7 @@ type rtspClient struct {
func newRtspClient(p *program, nconn net.Conn) *rtspClient {
c := &rtspClient{
p: p,
nconn: nconn,
rconn: rtsp.NewConn(nconn),
state: "STARTING",
}

Expand All @@ -52,7 +50,7 @@ func (c *rtspClient) close() error {
}

delete(c.p.clients, c)
c.nconn.Close()
c.rconn.Close()

if c.p.streamAuthor == c {
c.p.streamAuthor = nil
Expand All @@ -69,7 +67,7 @@ func (c *rtspClient) close() error {
}

func (c *rtspClient) log(format string, args ...interface{}) {
format = "[RTSP client " + c.nconn.RemoteAddr().String() + "] " + format
format = "[RTSP client " + c.rconn.RemoteAddr().String() + "] " + format
log.Printf(format, args...)
}

Expand All @@ -81,15 +79,13 @@ func (c *rtspClient) run() {
c.close()
}()

ipstr, _, _ := net.SplitHostPort(c.nconn.RemoteAddr().String())
ipstr, _, _ := net.SplitHostPort(c.rconn.RemoteAddr().String())
c.IP = net.ParseIP(ipstr)

rconn := &rtsp.Conn{c.nconn}

c.log("connected")

for {
req, err := rconn.ReadRequest()
req, err := c.rconn.ReadRequest()
if err != nil {
if err != io.EOF {
c.log("ERR: %s", err)
Expand All @@ -104,7 +100,7 @@ func (c *rtspClient) run() {
switch err {
// normal response
case nil:
err = rconn.WriteResponse(res)
err = c.rconn.WriteResponse(res)
if err != nil {
c.log("ERR: %s", err)
return
Expand All @@ -119,7 +115,7 @@ func (c *rtspClient) run() {
// before the response
// then switch to RTP if TCP
case errPlay:
err = rconn.WriteResponse(res)
err = c.rconn.WriteResponse(res)
if err != nil {
c.log("ERR: %s", err)
return
Expand All @@ -134,18 +130,21 @@ func (c *rtspClient) run() {
// when rtp protocol is TCP, the RTSP connection becomes a RTP connection
// receive RTP feedback, do not parse it, wait until connection closes
if c.rtpProto == "tcp" {
buf := make([]byte, 1024)
buf := make([]byte, 2048)
for {
_, err := c.nconn.Read(buf)
_, err := c.rconn.ReadInterleavedFrame(buf)
if err != nil {
if err != io.EOF {
c.log("ERR: %s", err)
}
return
}
}
}

// RECORD: switch to RTP if TCP
case errRecord:
err = rconn.WriteResponse(res)
err = c.rconn.WriteResponse(res)
if err != nil {
c.log("ERR: %s", err)
return
Expand All @@ -160,37 +159,17 @@ func (c *rtspClient) run() {
// when rtp protocol is TCP, the RTSP connection becomes a RTP connection
// receive RTP data and parse it
if c.rtpProto == "tcp" {
packet := make([]byte, 2048)
bconn := bufio.NewReader(c.nconn)
buf := make([]byte, 2048)
for {
byts, err := bconn.Peek(4)
if err != nil {
return
}
bconn.Discard(4)

if byts[0] != 0x24 {
c.log("ERR: wrong magic byte")
return
}

if byts[1] != 0x00 {
c.log("ERR: wrong channel")
return
}

plen := binary.BigEndian.Uint16(byts[2:])
if plen > 2048 {
c.log("ERR: packet len > 2048")
return
}

_, err = io.ReadFull(bconn, packet[:plen])
n, err := c.rconn.ReadInterleavedFrame(buf)
if err != nil {
if err != io.EOF {
c.log("ERR: %s", err)
}
return
}

c.p.handleRtp(packet[:plen])
c.p.handleRtp(buf[:n])
}
}

Expand All @@ -199,15 +178,15 @@ func (c *rtspClient) run() {
c.log("ERR: %s", err)

if cseq, ok := req.Headers["cseq"]; ok {
rconn.WriteResponse(&rtsp.Response{
c.rconn.WriteResponse(&rtsp.Response{
StatusCode: 400,
Status: "Bad Request",
Headers: map[string]string{
"CSeq": cseq,
},
})
} else {
rconn.WriteResponse(&rtsp.Response{
c.rconn.WriteResponse(&rtsp.Response{
StatusCode: 400,
Status: "Bad Request",
})
Expand Down

0 comments on commit b77ae3a

Please sign in to comment.