Skip to content

Commit

Permalink
add --read-ips and --publish-ips arguments; fix #12
Browse files Browse the repository at this point in the history
  • Loading branch information
aler9 committed Jun 15, 2020
1 parent e644587 commit fbd9f74
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 35 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,10 @@ Flags:
--write-timeout=5s timeout of write operations
--publish-user="" optional username required to publish
--publish-pass="" optional password required to publish
--publish-ips="" comma-separated list of IPs or networks (x.x.x.x/24) that can publish
--read-user="" optional username required to read
--read-pass="" optional password required to read
--read-ips="" comma-separated list of IPs or networks (x.x.x.x/24) that can read
--pre-script="" optional script to run on client connect
--post-script="" optional script to run on client disconnect
```
Expand Down
62 changes: 53 additions & 9 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"fmt"
"log"
"net"
"os"
"regexp"
"strings"
Expand All @@ -13,6 +14,30 @@ import (

var Version string = "v0.0.0"

func parseIpCidrList(in string) ([]interface{}, error) {
if in == "" {
return nil, nil
}

var ret []interface{}
for _, t := range strings.Split(in, ",") {
_, ipnet, err := net.ParseCIDR(t)
if err == nil {
ret = append(ret, ipnet)
continue
}

ip := net.ParseIP(t)
if ip != nil {
ret = append(ret, ip)
continue
}

return nil, fmt.Errorf("unable to parse ip/network '%s'", t)
}
return ret, nil
}

type trackFlow int

const (
Expand Down Expand Up @@ -49,18 +74,22 @@ type args struct {
writeTimeout time.Duration
publishUser string
publishPass string
publishIps string
readUser string
readPass string
readIps string
preScript string
postScript string
}

type program struct {
args args
protocols map[streamProtocol]struct{}
tcpl *serverTcpListener
udplRtp *serverUdpListener
udplRtcp *serverUdpListener
args args
protocols map[streamProtocol]struct{}
publishIps []interface{}
readIps []interface{}
tcpl *serverTcpListener
udplRtp *serverUdpListener
udplRtcp *serverUdpListener
}

func newProgram(sargs []string) (*program, error) {
Expand All @@ -76,8 +105,10 @@ func newProgram(sargs []string) (*program, error) {
argWriteTimeout := kingpin.Flag("write-timeout", "timeout of write operations").Default("5s").Duration()
argPublishUser := kingpin.Flag("publish-user", "optional username required to publish").Default("").String()
argPublishPass := kingpin.Flag("publish-pass", "optional password required to publish").Default("").String()
argPublishIps := kingpin.Flag("publish-ips", "comma-separated list of IPs or networks (x.x.x.x/24) that can publish").Default("").String()
argReadUser := kingpin.Flag("read-user", "optional username required to read").Default("").String()
argReadPass := kingpin.Flag("read-pass", "optional password required to read").Default("").String()
argReadIps := kingpin.Flag("read-ips", "comma-separated list of IPs or networks (x.x.x.x/24) that can read").Default("").String()
argPreScript := kingpin.Flag("pre-script", "optional script to run on client connect").Default("").String()
argPostScript := kingpin.Flag("post-script", "optional script to run on client disconnect").Default("").String()

Expand All @@ -93,8 +124,10 @@ func newProgram(sargs []string) (*program, error) {
writeTimeout: *argWriteTimeout,
publishUser: *argPublishUser,
publishPass: *argPublishPass,
publishIps: *argPublishIps,
readUser: *argReadUser,
readPass: *argReadPass,
readIps: *argReadIps,
preScript: *argPreScript,
postScript: *argPostScript,
}
Expand All @@ -120,12 +153,14 @@ func newProgram(sargs []string) (*program, error) {
if len(protocols) == 0 {
return nil, fmt.Errorf("no protocols provided")
}

if (args.rtpPort % 2) != 0 {
return nil, fmt.Errorf("rtp port must be even")
}
if args.rtcpPort != (args.rtpPort + 1) {
return nil, fmt.Errorf("rtcp and rtp ports must be consecutive")
}

if args.publishUser != "" {
if !regexp.MustCompile("^[a-zA-Z0-9]+$").MatchString(args.publishUser) {
return nil, fmt.Errorf("publish username must be alphanumeric")
Expand All @@ -136,6 +171,11 @@ func newProgram(sargs []string) (*program, error) {
return nil, fmt.Errorf("publish password must be alphanumeric")
}
}
publishIps, err := parseIpCidrList(args.publishIps)
if err != nil {
return nil, err
}

if args.readUser != "" && args.readPass == "" || args.readUser == "" && args.readPass != "" {
return nil, fmt.Errorf("read username and password must be both filled")
}
Expand All @@ -152,16 +192,20 @@ func newProgram(sargs []string) (*program, error) {
if args.readUser != "" && args.readPass == "" || args.readUser == "" && args.readPass != "" {
return nil, fmt.Errorf("read username and password must be both filled")
}
readIps, err := parseIpCidrList(args.readIps)
if err != nil {
return nil, err
}

log.Printf("rtsp-simple-server %s", Version)

p := &program{
args: args,
protocols: protocols,
args: args,
protocols: protocols,
publishIps: publishIps,
readIps: readIps,
}

var err error

p.udplRtp, err = newServerUdpListener(p, args.rtpPort, _TRACK_FLOW_RTP)
if err != nil {
return nil, err
Expand Down
2 changes: 2 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ func TestPublishAuth(t *testing.T) {
p, err := newProgram([]string{
"--publish-user=testuser",
"--publish-pass=testpass",
"--publish-ips=172.17.0.0/16",
})
require.NoError(t, err)
defer p.close()
Expand Down Expand Up @@ -185,6 +186,7 @@ func TestReadAuth(t *testing.T) {
p, err := newProgram([]string{
"--read-user=testuser",
"--read-pass=testpass",
"--read-ips=172.17.0.0/16",
})
require.NoError(t, err)
defer p.close()
Expand Down
86 changes: 60 additions & 26 deletions server-client.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,36 +202,70 @@ func (c *serverClient) writeResError(req *gortsplib.Request, code gortsplib.Stat
var errAuthCritical = errors.New("auth critical")
var errAuthNotCritical = errors.New("auth not critical")

func (c *serverClient) validateAuth(req *gortsplib.Request, user string, pass string, auth **gortsplib.AuthServer) error {
if user == "" {
return nil
}
func (c *serverClient) validateAuth(req *gortsplib.Request, user string, pass string, auth **gortsplib.AuthServer, ips []interface{}) error {
err := func() error {
if ips == nil {
return nil
}

initialRequest := false
if *auth == nil {
initialRequest = true
*auth = gortsplib.NewAuthServer(user, pass, nil)
}
connIp := c.conn.NetConn().LocalAddr().(*net.TCPAddr).IP

err := (*auth).ValidateHeader(req.Header["Authorization"], req.Method, req.Url)
if err != nil {
if !initialRequest {
c.log("ERR: Unauthorized: %s", err)
for _, item := range ips {
switch titem := item.(type) {
case net.IP:
if titem.Equal(connIp) {
return nil
}

case *net.IPNet:
if titem.Contains(connIp) {
return nil
}
}
}

c.conn.WriteResponse(&gortsplib.Response{
StatusCode: gortsplib.StatusUnauthorized,
Header: gortsplib.Header{
"CSeq": []string{req.Header["CSeq"][0]},
"WWW-Authenticate": (*auth).GenerateHeader(),
},
})
c.log("ERR: ip '%s' not allowed", connIp)
return errAuthCritical
}()
if err != nil {
return err
}

err = func() error {
if user == "" {
return nil
}

if !initialRequest {
return errAuthCritical
initialRequest := false
if *auth == nil {
initialRequest = true
*auth = gortsplib.NewAuthServer(user, pass, nil)
}

return errAuthNotCritical
err := (*auth).ValidateHeader(req.Header["Authorization"], req.Method, req.Url)
if err != nil {
if !initialRequest {
c.log("ERR: unauthorized: %s", err)
}

c.conn.WriteResponse(&gortsplib.Response{
StatusCode: gortsplib.StatusUnauthorized,
Header: gortsplib.Header{
"CSeq": []string{req.Header["CSeq"][0]},
"WWW-Authenticate": (*auth).GenerateHeader(),
},
})

if !initialRequest {
return errAuthCritical
}

return errAuthNotCritical
}
return nil
}()
if err != nil {
return err
}

return nil
Expand Down Expand Up @@ -291,7 +325,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return false
}

err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth)
err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth, c.p.readIps)
if err != nil {
if err == errAuthCritical {
return false
Expand Down Expand Up @@ -333,7 +367,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
return false
}

err := c.validateAuth(req, c.p.args.publishUser, c.p.args.publishPass, &c.publishAuth)
err := c.validateAuth(req, c.p.args.publishUser, c.p.args.publishPass, &c.publishAuth, c.p.publishIps)
if err != nil {
if err == errAuthCritical {
return false
Expand Down Expand Up @@ -405,7 +439,7 @@ func (c *serverClient) handleRequest(req *gortsplib.Request) bool {
switch c.state {
// play
case _CLIENT_STATE_STARTING, _CLIENT_STATE_PRE_PLAY:
err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth)
err := c.validateAuth(req, c.p.args.readUser, c.p.args.readPass, &c.readAuth, c.p.readIps)
if err != nil {
if err == errAuthCritical {
return false
Expand Down

0 comments on commit fbd9f74

Please sign in to comment.