Skip to content

Commit

Permalink
Patching stuck goroutines causing deadline errors (#381)
Browse files Browse the repository at this point in the history
* feat: added timeout to dns + singleflight for caching initial bulk resolutions

* fixing singleflight

* .

* atomic

* removing log + finalize

* fix routine leak

* fix race

---------

Co-authored-by: Ice3man <[email protected]>
  • Loading branch information
Mzack9999 and Ice3man543 authored Dec 8, 2024
1 parent 93a6f7b commit 3c45460
Show file tree
Hide file tree
Showing 9 changed files with 203 additions and 68 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ jobs:
run: |
go run -race example/simple/main.go
go run -race example/impersonate/main.go
go run -race example/concurrent/concurrent.go
88 changes: 88 additions & 0 deletions example/concurrent/concurrent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package main

// This example exercises the dialer under concurrent load and verifies
// that no single connection attempt takes longer than 3 seconds.

import (
"context"
"errors"
"sync"
"time"

"github.com/projectdiscovery/fastdialer/fastdialer"
)

// main runs the concurrent dial benchmark against scanme.sh with 1000
// iterations and panics if the benchmark reports any error.
func main() {
	if err := BenchmarkDial("scanme.sh", 1000); err != nil {
		panic(err)
	}
}

// connResult is the outcome of a single dial attempt, sent by a worker
// over the results channel.
type connResult struct {
	// target is the task string the worker dialed (":443" is appended by the worker).
	target string
	// elapsed is the wall-clock time measured around the Dial call.
	elapsed time.Duration
	// err is the error returned by Dial, or nil when the dial succeeded.
	err error
}

// BenchmarkDial dials target iterations times through a pool of 10
// concurrent workers that all share one fastdialer instance.
//
// It returns the first dial error observed, or an error when a single dial
// took longer than 3 seconds, and nil when every attempt succeeded in time.
//
// On the first failure the shared context is cancelled so that no further
// tasks are queued and pending dials get a chance to stop early. The function
// then keeps draining results until every worker has exited, so no goroutine
// is still using the dialer when it is closed on return.
func BenchmarkDial(target string, iterations int) error {
	options := fastdialer.DefaultOptions
	fd, err := fastdialer.NewDialer(options)
	if err != nil {
		return errors.Join(err, errors.New("failed to create dialer"))
	}
	// Runs after the results loop below, i.e. once all workers are done.
	defer fd.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// tasks is unbuffered so the feeder can stop as soon as ctx is cancelled;
	// results is sized for every iteration so workers never block on send.
	tasks := make(chan string)
	results := make(chan connResult, iterations)

	var wg sync.WaitGroup
	for w := 0; w < 10; w++ {
		wg.Add(1)
		go worker(ctx, fd, tasks, results, &wg)
	}

	// Feeder: hands out one task per iteration until done or cancelled.
	go func() {
		defer close(tasks)
		for i := 0; i < iterations; i++ {
			if ctx.Err() != nil {
				return
			}
			select {
			case tasks <- target:
			case <-ctx.Done():
				return
			}
		}
	}()

	// Closer: lets the range loop below terminate once all workers exit.
	go func() {
		wg.Wait()
		close(results)
	}()

	var firstErr error
	for result := range results {
		if firstErr != nil {
			// Already failed: keep draining so the workers can finish.
			continue
		}
		switch {
		case result.err != nil:
			firstErr = result.err
		case result.elapsed.Seconds() > 3:
			firstErr = errors.New("connection took too long")
		}
		if firstErr != nil {
			cancel()
		}
	}

	return firstErr
}

// worker consumes hosts from tasks until the channel is closed. For each
// host it dials "<host>:443" over tcp through fd, measures how long the
// Dial call took, closes the connection when one was returned without an
// error, and reports the outcome on results. wg.Done is called on exit.
func worker(ctx context.Context, fd *fastdialer.Dialer, tasks <-chan string, results chan<- connResult, wg *sync.WaitGroup) {
	defer wg.Done()

	for host := range tasks {
		began := time.Now()
		c, dialErr := fd.Dial(ctx, "tcp", host+":443")
		took := time.Since(began)

		// Only the timing matters here; release the socket right away.
		if dialErr == nil && c != nil {
			c.Close()
		}

		results <- connResult{target: host, elapsed: took, err: dialErr}
	}
}
34 changes: 22 additions & 12 deletions fastdialer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import (
"net"
"strings"
"sync/atomic"
"time"

"golang.org/x/sync/singleflight"

"github.com/Mzack9999/gcache"
gounit "github.com/docker/go-units"
Expand Down Expand Up @@ -64,6 +67,8 @@ type Dialer struct {
networkpolicy *networkpolicy.NetworkPolicy
dialCache gcache.Cache[string, *utils.DialWrap]
dialTimeoutErrors gcache.Cache[string, *atomic.Uint32]

resolutionsGroup *singleflight.Group
}

// NewDialer instance
Expand Down Expand Up @@ -136,7 +141,11 @@ func NewDialer(options Options) (*Dialer, error) {
options.Logger.Printf("could not load hosts file: %s\n", err)
}
}
dnsclient, err := retryabledns.New(resolvers, options.MaxRetries)
dnsclient, err := retryabledns.NewWithOptions(retryabledns.Options{
BaseResolvers: resolvers,
MaxRetries: options.MaxRetries,
Timeout: 1 * time.Second,
})
if err != nil {
return nil, err
}
Expand All @@ -152,17 +161,18 @@ func NewDialer(options Options) (*Dialer, error) {
}

d := &Dialer{
dnsclient: dnsclient,
mDnsCache: dnsCache,
hmDnsCache: hmDnsCache,
hostsFileData: hostsFileData,
dialerHistory: dialerHistory,
dialerTLSData: dialerTLSData,
dialer: dialer,
proxyDialer: options.ProxyDialer,
options: &options,
networkpolicy: np,
dialCache: gcache.New[string, *utils.DialWrap](MaxDialCacheSize).Build(),
dnsclient: dnsclient,
mDnsCache: dnsCache,
hmDnsCache: hmDnsCache,
hostsFileData: hostsFileData,
dialerHistory: dialerHistory,
dialerTLSData: dialerTLSData,
dialer: dialer,
proxyDialer: options.ProxyDialer,
options: &options,
networkpolicy: np,
dialCache: gcache.New[string, *utils.DialWrap](MaxDialCacheSize).Build(),
resolutionsGroup: &singleflight.Group{},
}

if options.MaxTemporaryErrors > 0 && options.MaxTemporaryToPermanentDuration > 0 {
Expand Down
14 changes: 7 additions & 7 deletions fastdialer/dialer_private.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/projectdiscovery/fastdialer/fastdialer/ja3/impersonate"
"github.com/projectdiscovery/fastdialer/fastdialer/utils"
retryabledns "github.com/projectdiscovery/retryabledns"
ctxutil "github.com/projectdiscovery/utils/context"
cryptoutil "github.com/projectdiscovery/utils/crypto"
"github.com/projectdiscovery/utils/errkit"
Expand Down Expand Up @@ -110,14 +111,14 @@ func (d *Dialer) dial(ctx context.Context, opts *dialOptions) (conn net.Conn, er
if fixedIP != "" {
IPS = append(IPS, fixedIP)
} else {
data, err := d.GetDNSData(hostname)
if err != nil {
// otherwise attempt to retrieve it
data, err = d.dnsclient.Resolve(hostname)
}
if data == nil {
cacheData, err, _ := d.resolutionsGroup.Do(hostname, func() (interface{}, error) {
return d.GetDNSData(hostname)
})

if cacheData == nil {
return nil, ResolveHostError
}
data := cacheData.(*retryabledns.DNSData)
if err != nil || len(data.A)+len(data.AAAA) == 0 {
return nil, NoAddressFoundError
}
Expand Down Expand Up @@ -161,7 +162,6 @@ func (d *Dialer) dial(ctx context.Context, opts *dialOptions) (conn net.Conn, er
// 2. it is a domain and not ip
// 3. it has at least 1 valid ip
// 4. proxy dialer is not set

dw, err = utils.NewDialWrap(d.dialer, IPS, opts.network, opts.address, opts.port)
if err != nil {
return nil, errkit.Wrap(err, "could not create dialwrap")
Expand Down
Empty file added fastdialer/perf_test
Empty file.
124 changes: 80 additions & 44 deletions fastdialer/utils/dialwrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,15 @@ type DialWrap struct {
network string
address string
port string
// below fields implement a singleflight like pattern
// where first connection is established and subsequent calls receive
// a shared result
wg sync.WaitGroup
mu sync.Mutex
completedFirstFlight *atomic.Bool
dups uint8
err error // error returned by first flight

// all connections block until the first connection is established;
// subsequent calls act on the result of that first connection
busyFirstConnection *atomic.Bool
completedFirstConnection *atomic.Bool
firstConnectionDuration time.Duration
mu sync.RWMutex
// error returned by first connection
err error
}

// NewDialWrap creates a new dial wrap instance and returns it.
Expand All @@ -88,31 +89,22 @@ func NewDialWrap(dialer *net.Dialer, ips []string, network, address, port string
return nil, ErrNoIPs
}
return &DialWrap{
dialer: dialer,
ipv4: ipv4,
ipv6: ipv6,
ips: valid,
completedFirstFlight: &atomic.Bool{},
network: network,
address: address,
port: port,
dialer: dialer,
ipv4: ipv4,
ipv6: ipv6,
ips: valid,
completedFirstConnection: &atomic.Bool{},
busyFirstConnection: &atomic.Bool{},
network: network,
address: address,
port: port,
}, nil
}

// DialContext is the main entry point for dialing
func (d *DialWrap) DialContext(ctx context.Context, _ string, _ string) (net.Conn, error) {
if d.completedFirstFlight.Load() {
// if first flight completed and it failed due to other reasons
// and not due to context cancellation
if d.err != nil && !errkit.Is(d.err, ErrInflightCancel) && !errkit.Is(d.err, context.Canceled) {
return nil, d.err
}
return d.dial(ctx)
}
select {
case <-ctx.Done():
return nil, errkit.Append(ErrInflightCancel, ctx.Err())
case res, ok := <-d.firstFlight(ctx):
case res, ok := <-d.doFirstConnection(ctx):
if !ok {
// closed channel so depending on the error
// either dial new or return the error
Expand All @@ -133,27 +125,35 @@ func (d *DialWrap) DialContext(ctx context.Context, _ string, _ string) (net.Con
return nil, d.err
}
return nil, res.error
case <-d.hasCompletedFirstConnection(ctx):
// if first connection completed and it failed due to other reasons
// and not due to context cancellation
if d.err != nil && !errkit.Is(d.err, ErrInflightCancel) && !errkit.Is(d.err, context.Canceled) {
return nil, d.err
}
return d.dial(ctx)
case <-ctx.Done():
return nil, errkit.Append(ErrInflightCancel, ctx.Err())
}
}

// firstFlight is a singleflight pattern implementation
func (d *DialWrap) firstFlight(ctx context.Context) chan *dialResult {
func (d *DialWrap) doFirstConnection(ctx context.Context) chan *dialResult {
if d.busyFirstConnection.Load() {
return nil
}
d.busyFirstConnection.Store(true)
now := time.Now()
defer func() {
d.SetFirstConnectionDuration(time.Since(now))
}()

size := len(d.ipv4) + len(d.ipv6)
ch := make(chan *dialResult, size)
d.mu.Lock()
if d.dups > 0 {
d.mu.Unlock()
d.wg.Wait()
return ch
}
d.dups++
d.wg.Add(1)
d.mu.Unlock()
defer d.wg.Done()

// dial parallel
conns, err := d.dialAllParallel(ctx)
defer func() {
d.completedFirstFlight.Store(true)
d.completedFirstConnection.Store(true)
close(ch)
}()
if err != nil {
Expand All @@ -167,6 +167,27 @@ func (d *DialWrap) firstFlight(ctx context.Context) chan *dialResult {
return ch
}

// hasCompletedFirstConnection returns a channel that receives one value once
// completedFirstConnection becomes true. If ctx is cancelled first, the
// channel is closed without a value being sent. In both cases the channel is
// closed when the watcher goroutine exits, so a receive on it never blocks
// past completion or cancellation.
//
// The flag is checked once immediately and then polled on a short ticker
// rather than in a hot loop, so a waiting dial does not pin a CPU core while
// the first connection is still in progress.
func (d *DialWrap) hasCompletedFirstConnection(ctx context.Context) chan struct{} {
	// Upper bound on the extra latency between the flag flipping and the
	// waiter being notified.
	const pollInterval = time.Millisecond

	ch := make(chan struct{}, 1)

	go func() {
		defer close(ch)

		// Fast path: the first connection may already be done.
		if d.completedFirstConnection.Load() {
			ch <- struct{}{}
			return
		}

		ticker := time.NewTicker(pollInterval)
		defer ticker.Stop()

		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				if d.completedFirstConnection.Load() {
					ch <- struct{}{}
					return
				}
			}
		}
	}()

	return ch
}

// dialAllParallel connects to all the given addresses in parallel, returning
// the first successful connection, or the first error.
func (d *DialWrap) dialAllParallel(ctx context.Context) ([]*dialResult, error) {
Expand Down Expand Up @@ -261,11 +282,12 @@ func (d *DialWrap) dial(ctx context.Context) (net.Conn, error) {
//
// Or zero, if none of Timeout, Deadline, or context's deadline is set.
func (d *DialWrap) deadline(ctx context.Context, now time.Time) (earliest time.Time) {
if d.dialer.Timeout != 0 { // including negative, for historical reasons
earliest = now.Add(d.dialer.Timeout)
// including negative, for historical reasons
if d.dialer.Timeout != 0 {
earliest = now.Add(d.dialer.Timeout + d.FirstConnectionTook())
}
if d, ok := ctx.Deadline(); ok {
earliest = minNonzeroTime(earliest, d)
if de, ok := ctx.Deadline(); ok {
earliest = minNonzeroTime(earliest, de.Add(d.FirstConnectionTook()))
}
return earliest
}
Expand Down Expand Up @@ -408,3 +430,17 @@ func minNonzeroTime(a, b time.Time) time.Time {
}
return b
}

// FirstConnectionTook returns the duration last stored through
// SetFirstConnectionDuration (zero if none was stored). The value is read
// under the read lock of d.mu.
func (d *DialWrap) FirstConnectionTook() time.Duration {
	d.mu.RLock()
	took := d.firstConnectionDuration
	d.mu.RUnlock()

	return took
}

// SetFirstConnectionDuration stores dur as the first-connection duration,
// replacing any previous value. The write happens under the write lock of
// d.mu.
func (d *DialWrap) SetFirstConnectionDuration(dur time.Duration) {
	d.mu.Lock()
	d.firstConnectionDuration = dur
	d.mu.Unlock()
}
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,17 @@ require (
github.com/dimchansky/utfbom v1.1.1
github.com/docker/go-units v0.5.0
github.com/pkg/errors v0.9.1
github.com/projectdiscovery/goleak v0.0.0-20240729222606-a7d18edc33f8
github.com/projectdiscovery/hmap v0.0.69
github.com/projectdiscovery/networkpolicy v0.0.9
github.com/projectdiscovery/retryabledns v1.0.87
github.com/projectdiscovery/utils v0.3.0
github.com/refraction-networking/utls v1.6.7
github.com/stretchr/testify v1.9.0
github.com/tarunKoyalwar/goleak v0.0.0-20240429141123-0efa90dbdcf9
github.com/zmap/zcrypto v0.0.0-20230422215203-9a665e1e9968
golang.org/x/exp v0.0.0-20221205204356-47842c84f3db
golang.org/x/net v0.29.0
golang.org/x/sync v0.8.0
)

require (
Expand Down Expand Up @@ -55,7 +56,6 @@ require (
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/crypto v0.27.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/sync v0.8.0 // indirect
golang.org/x/sys v0.25.0 // indirect
golang.org/x/text v0.18.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
Expand Down
Loading

0 comments on commit 3c45460

Please sign in to comment.