Skip to content

Commit

Permalink
transport: simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
egonelbre committed Oct 3, 2022
1 parent 87e597c commit f526337
Showing 1 changed file with 8 additions and 20 deletions.
28 changes: 8 additions & 20 deletions transport/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"io/ioutil"
"net"
"sync"

"github.com/gogo/protobuf/proto"

Expand All @@ -27,56 +26,45 @@ func ListenUDP(addr string) (*UDPListener, error) {

return &UDPListener{
addr: addr,
pool: sync.Pool{
New: func() interface{} {
return make([]byte, 10*1024)
},
},
conn: conn,
}, nil
}

// UDPListener handles reading packets from the underlying UDP connection.
type UDPListener struct {
addr string
pool sync.Pool
conn *net.UDPConn
}

// Next returns the next packet from UDP and it's associated source address. Should an error occur, then it is returned.
// A source address may be returned alongside an error for further reporting in the event of abuse/debugging.
func (u *UDPListener) Next() (packet *pb.Packet, source *net.UDPAddr, err error) {
buf := u.pool.Get().([]byte)
defer func() {
if buf != nil {
u.pool.Put(buf)
}
}()

n, source, err := u.conn.ReadFromUDP(buf)
var buf [10 * 1024]byte

n, source, err := u.conn.ReadFromUDP(buf[:])
if err != nil {
return nil, nil, err
}

// TODO: handle malformed packet more gracefully... return source address for further reporting
packet, err = parsePacket(n, buf)
packet, err = parsePacket(buf[:n])
if err != nil {
return nil, source, err
}

return packet, source, err
}

func (u *UDPListener) Close() error{
func (u *UDPListener) Close() error {
return u.conn.Close()
}

func parsePacket(n int, buf []byte) (*pb.Packet, error) {
if n < 4 || string(buf[:2]) != "EK" {
func parsePacket(buf []byte) (*pb.Packet, error) {
if len(buf) < 4 || string(buf[:2]) != "EK" {
return nil, errors.New("missing magic number")
}

zl, err := zlib.NewReader(bytes.NewReader(buf[2:n]))
zl, err := zlib.NewReader(bytes.NewReader(buf[2:]))
if err != nil {
return nil, err
}
Expand Down

0 comments on commit f526337

Please sign in to comment.