Skip to content

Commit

Permalink
fix panic when serializing big bitfields
Browse files Browse the repository at this point in the history
  • Loading branch information
cenkalti committed Mar 29, 2019
1 parent cd5c167 commit d410952
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 12 deletions.
22 changes: 22 additions & 0 deletions internal/peerconn/peerwriter/piece_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package peerwriter

import (
"bytes"
"testing"
)

func BenchmarkRead(b *testing.B) {
buf := make([]byte, 10)
buf2 := make([]byte, 25)
r := bytes.NewReader(buf)
p := Piece{
Piece: r,
Begin: 2,
Length: 5,
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
p.Read(buf2)
}
}
20 changes: 12 additions & 8 deletions internal/peerprotocol/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,26 @@ type ExtensionMessage struct {
func (m ExtensionMessage) ID() MessageID { return Extension }

func (m ExtensionMessage) Read([]byte) (int, error) {
panic("read must not be called")
panic("Read must not be called, use WriteTo")
}

func (m ExtensionMessage) WriteTo(w io.Writer) (int64, error) {
_, err := w.Write([]byte{m.ExtendedMessageID})
func (m ExtensionMessage) WriteTo(w io.Writer) (n int64, err error) {
nn, err := w.Write([]byte{m.ExtendedMessageID})
n += int64(nn)
if err != nil {
return 0, err
return
}
err = bencode.NewEncoder(w).Encode(m.Payload)
wc := newWriterCounter(w)
err = bencode.NewEncoder(wc).Encode(m.Payload)
n += wc.Count()
if err != nil {
return 0, err
return
}
if mm, ok := m.Payload.(ExtensionMetadataMessage); ok {
_, err = w.Write(mm.Data)
nn, err = w.Write(mm.Data)
n += int64(nn)
}
return 0, err
return
}

func (m *ExtensionMessage) UnmarshalBinary(data []byte) error {
Expand Down
13 changes: 9 additions & 4 deletions internal/peerprotocol/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
)

type Message interface {
Read([]byte) (int, error)
io.Reader
ID() MessageID
}

Expand Down Expand Up @@ -48,13 +48,18 @@ func (m PieceMessage) Read(b []byte) (int, error) {

type BitfieldMessage struct {
Data []byte
pos int
}

func (m BitfieldMessage) ID() MessageID { return Bitfield }

func (m BitfieldMessage) Read(b []byte) (int, error) {
copy(b[0:len(m.Data)], m.Data)
return len(m.Data), io.EOF
func (m BitfieldMessage) Read(b []byte) (n int, err error) {
n = copy(b, m.Data[m.pos:])
m.pos += n
if m.pos == len(m.Data) {
err = io.EOF
}
return
}

type emptyMessage struct{}
Expand Down
25 changes: 25 additions & 0 deletions internal/peerprotocol/writecounter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package peerprotocol

import "io"

type writerCounter struct {
io.Writer
count int64
writer io.Writer
}

func newWriterCounter(w io.Writer) *writerCounter {
return &writerCounter{
writer: w,
}
}

func (w *writerCounter) Write(buf []byte) (int, error) {
n, err := w.writer.Write(buf)
w.count += int64(n)
return n, err
}

func (w *writerCounter) Count() int64 {
return w.count
}

0 comments on commit d410952

Please sign in to comment.