Skip to content

Commit

Permalink
Merge pull request #462 from libp2p/fix/observe-context-in-message-se…
Browse files Browse the repository at this point in the history
…nder

fix: obey the context when sending messages to peers
  • Loading branch information
Stebalien authored Mar 5, 2020
2 parents a92f79b + 0b02938 commit dbb3d2c
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 6 deletions.
28 changes: 28 additions & 0 deletions ctx_mutex.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package dht

import (
"context"
)

type ctxMutex chan struct{}

func newCtxMutex() ctxMutex {
return make(ctxMutex, 1)
}

func (m ctxMutex) Lock(ctx context.Context) error {
select {
case m <- struct{}{}:
return nil
case <-ctx.Done():
return ctx.Err()
}
}

func (m ctxMutex) Unlock() {
select {
case <-m:
default:
panic("not locked")
}
}
19 changes: 14 additions & 5 deletions dht_net.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa
dht.smlk.Unlock()
return ms, nil
}
ms = &messageSender{p: p, dht: dht}
ms = &messageSender{p: p, dht: dht, lk: newCtxMutex()}
dht.strmap[p] = ms
dht.smlk.Unlock()

Expand Down Expand Up @@ -274,7 +274,7 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa
type messageSender struct {
s network.Stream
r msgio.ReadCloser
lk sync.Mutex
lk ctxMutex
p peer.ID
dht *IpfsDHT

Expand All @@ -294,8 +294,11 @@ func (ms *messageSender) invalidate() {
}

func (ms *messageSender) prepOrInvalidate(ctx context.Context) error {
ms.lk.Lock()
if err := ms.lk.Lock(ctx); err != nil {
return err
}
defer ms.lk.Unlock()

if err := ms.prep(ctx); err != nil {
ms.invalidate()
return err
Expand Down Expand Up @@ -328,8 +331,11 @@ func (ms *messageSender) prep(ctx context.Context) error {
const streamReuseTries = 3

func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) error {
ms.lk.Lock()
if err := ms.lk.Lock(ctx); err != nil {
return err
}
defer ms.lk.Unlock()

retry := false
for {
if err := ms.prep(ctx); err != nil {
Expand Down Expand Up @@ -363,8 +369,11 @@ func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) erro
}

func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) {
ms.lk.Lock()
if err := ms.lk.Lock(ctx); err != nil {
return nil, err
}
defer ms.lk.Unlock()

retry := false
for {
if err := ms.prep(ctx); err != nil {
Expand Down
43 changes: 43 additions & 0 deletions ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,49 @@ import (
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
)

func TestHang(t *testing.T) {
ctx := context.Background()
mn, err := mocknet.FullMeshConnected(ctx, 2)
if err != nil {
t.Fatal(err)
}
hosts := mn.Hosts()

os := []opts.Option{opts.DisableAutoRefresh()}
d, err := New(ctx, hosts[0], os...)
if err != nil {
t.Fatal(err)
}
// Hang on every request.
hosts[1].SetStreamHandler(d.protocols[0], func(s network.Stream) {
defer s.Reset()
<-ctx.Done()
})
d.Update(ctx, hosts[1].ID())

ctx1, cancel1 := context.WithTimeout(ctx, 1*time.Second)
defer cancel1()

peers, err := d.GetClosestPeers(ctx1, testCaseCids[0].KeyString())
if err != nil {
t.Fatal(err)
}

time.Sleep(100 * time.Millisecond)
ctx2, cancel2 := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel2()
_ = d.Provide(ctx2, testCaseCids[0], true)
if ctx2.Err() != context.DeadlineExceeded {
t.Errorf("expected to fail with deadline exceeded, got: %s", ctx2.Err())
}
select {
case <-peers:
t.Error("GetClosestPeers should not have returned yet")
default:
}

}

func TestGetFailures(t *testing.T) {
if testing.Short() {
t.SkipNow()
Expand Down
4 changes: 3 additions & 1 deletion notif.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package dht

import (
"context"

"github.com/libp2p/go-libp2p-core/helpers"
"github.com/libp2p/go-libp2p-core/network"

Expand Down Expand Up @@ -130,7 +132,7 @@ func (nn *netNotifiee) Disconnected(n network.Network, v network.Conn) {

// Do this asynchronously as ms.lk can block for a while.
go func() {
ms.lk.Lock()
ms.lk.Lock(context.Background())
defer ms.lk.Unlock()
ms.invalidate()
}()
Expand Down

0 comments on commit dbb3d2c

Please sign in to comment.