diff --git a/chunk.go b/chunk.go index f0e6708..4372372 100644 --- a/chunk.go +++ b/chunk.go @@ -14,10 +14,10 @@ type Chunk struct { *Progress // Chunk start pos. - Start int64 + Start uint64 // Chunk end. - End int64 + End uint64 // Path name where this chunk downloaded. Path string diff --git a/cmd/got/main.go b/cmd/got/main.go index df96434..5c5fa83 100644 --- a/cmd/got/main.go +++ b/cmd/got/main.go @@ -3,18 +3,19 @@ package main import ( "flag" "fmt" - "github.com/dustin/go-humanize" - "github.com/melbahja/got" "log" "time" + + "github.com/dustin/go-humanize" + "github.com/melbahja/got" ) var ( url string version string dest = flag.String("out", "", "Downloaded file destination.") - chunkSize = flag.Int("size", 0, "Maximum chunk size in bytes.") - concurrency = flag.Int("concurrency", 10, "Maximum chunks to download at the same time.") + chunkSize = flag.Uint64("size", 0, "Maximum chunk size in bytes.") + concurrency = flag.Uint("concurrency", 10, "Maximum chunks to download at the same time.") ) func init() { @@ -46,9 +47,9 @@ func main() { } d := got.Download{ - URL: url, - Dest: *dest, - ChunkSize: int64(*chunkSize), + URL: url, + Dest: *dest, + ChunkSize: *chunkSize, Interval: 100, Concurrency: *concurrency, } @@ -62,10 +63,10 @@ func main() { fmt.Printf( "\r\r\bTotal: %s | Chunk: %s | Concurrency: %d | Received: %s | Time: %s | Avg: %s/s | Speed: %s/s", - humanize.Bytes(uint64(p.TotalSize)), - humanize.Bytes(uint64(d.ChunkSize)), + humanize.Bytes(p.TotalSize), + humanize.Bytes(d.ChunkSize), d.Concurrency, - humanize.Bytes(uint64(p.Size)), + humanize.Bytes(p.Size), p.TotalCost().Round(time.Second), humanize.Bytes(p.AvgSpeed()), humanize.Bytes(p.Speed()), diff --git a/got.go b/got.go index 0b09f5a..e8abfcb 100644 --- a/got.go +++ b/got.go @@ -19,7 +19,7 @@ type ( Info struct { // File content length. - Length int64 + Length uint64 // Supports partial content? Rangeable bool @@ -32,7 +32,7 @@ type ( Download struct { // Download file info. - *Info + Info // URL to download. URL string @@ -41,25 +41,25 @@ type ( Dest string // Split file into chunks by ChunkSize in bytes. - ChunkSize int64 + ChunkSize uint64 // Set maximum chunk size. - MaxChunkSize int64 + MaxChunkSize uint64 // Set min chunk size. - MinChunkSize int64 + MinChunkSize uint64 // Max chunks to download at same time. - Concurrency int + Concurrency uint // Progress... Progress *Progress // Progress interval in ms. - Interval int + Interval uint64 // Download file chunks. - chunks []*Chunk + chunks []Chunk // Http client. client *http.Client @@ -75,7 +75,7 @@ func (d *Download) Init() error { var ( err error - i, startRange, endRange, chunksLen int64 + i, startRange, endRange, chunksLen uint64 ) // Set http client @@ -108,12 +108,12 @@ func (d *Download) Init() error { d.Progress = &Progress{ startedAt: time.Now(), Interval: d.Interval, - TotalSize: d.Info.Length, + TotalSize: d.Length, } } // Partial content not supported 😢! - if d.Info.Rangeable == false || d.Info.Length == 0 { + if d.Rangeable == false || d.Length == 0 { return nil } @@ -125,7 +125,7 @@ func (d *Download) Init() error { // Set default chunk size if d.ChunkSize == 0 { - d.ChunkSize = d.Info.Length / int64(d.Concurrency) + d.ChunkSize = d.Length / uint64(d.Concurrency) // if chunk size >= 102400000 bytes set default to (ChunkSize / 2) if d.ChunkSize >= 102400000 { @@ -137,8 +137,8 @@ func (d *Download) Init() error { d.MinChunkSize = 1000000 - if d.MinChunkSize > d.Info.Length { - d.MinChunkSize = d.Info.Length / 2 + if d.MinChunkSize > d.Length { + d.MinChunkSize = d.Length / 2 } } @@ -152,12 +152,14 @@ func (d *Download) Init() error { d.ChunkSize = d.MaxChunkSize } - } else if d.ChunkSize > d.Info.Length { + } else if d.ChunkSize > d.Length { - d.ChunkSize = d.Info.Length / 2 + d.ChunkSize = d.Length / 2 } - chunksLen = d.Info.Length / d.ChunkSize + chunksLen = d.Length / d.ChunkSize + + d.chunks = make([]Chunk, 0, chunksLen) // Set chunk ranges. for ; i < chunksLen; i++ { @@ -174,11 +176,11 @@ func (d *Download) Init() error { break } - if endRange > d.Info.Length || i == (chunksLen-1) { + if endRange > d.Length || i == (chunksLen-1) { endRange = 0 } - d.chunks = append(d.chunks, &Chunk{ + d.chunks = append(d.chunks, Chunk{ Start: startRange, End: endRange, Progress: d.Progress, @@ -199,7 +201,7 @@ func (d *Download) Start() (err error) { } defer os.RemoveAll(temp) - ctx, cancel := context.WithCancel(context.TODO()) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Run progress func. @@ -238,6 +240,8 @@ func (d *Download) Start() (err error) { // Wait for chunks... if err := eg.Wait(); err != nil { + // In case of an error, destination file should be removed + _ = os.Remove(d.Dest) return err } @@ -250,32 +254,32 @@ func (d *Download) Start() (err error) { } // GetInfo gets Info, it returns error if status code > 500 or 404. -func (d *Download) GetInfo() (*Info, error) { +func (d *Download) GetInfo() (Info, error) { req, err := NewRequest("HEAD", d.URL) if err != nil { - return nil, err + return Info{}, err } res, err := d.client.Do(req) if err != nil { - return nil, err + return Info{}, err } if res.StatusCode < 200 || res.StatusCode >= 400 { // On 4xx HEAD request (work around for #3). if res.StatusCode != 404 && res.StatusCode >= 400 && res.StatusCode < 500 { - return &Info{}, nil + return Info{}, nil } - return nil, fmt.Errorf("Response status code is not ok: %d", res.StatusCode) + return Info{}, fmt.Errorf("Response status code is not ok: %d", res.StatusCode) } - return &Info{ - Length: res.ContentLength, + return Info{ + Length: uint64(res.ContentLength), Rangeable: res.Header.Get("accept-ranges") == "bytes", Redirected: d.redirected, }, nil @@ -327,7 +331,7 @@ func (d *Download) dl(ctx context.Context, temp string) error { for i := 0; i < len(d.chunks); i++ { max <- 1 - i := i + current := i eg.Go(func() error { @@ -336,7 +340,7 @@ func (d *Download) dl(ctx context.Context, temp string) error { }() // Create chunk in temp dir. - chunk, err := os.Create(filepath.Join(temp, fmt.Sprintf("chunk-%d", i))) + chunk, err := os.Create(filepath.Join(temp, fmt.Sprintf("chunk-%d", current))) if err != nil { return err @@ -346,13 +350,13 @@ func (d *Download) dl(ctx context.Context, temp string) error { defer chunk.Close() // Download chunk. - err = d.chunks[i].Download(d.URL, d.client, chunk) + err = d.chunks[current].Download(d.URL, d.client, chunk) if err != nil { return err } - d.chunks[i].Path = chunk.Name() - close(d.chunks[i].Done) + d.chunks[current].Path = chunk.Name() + close(d.chunks[current].Done) return nil }) } diff --git a/got_test.go b/got_test.go index b90a1c7..64f13f1 100644 --- a/got_test.go +++ b/got_test.go @@ -2,6 +2,7 @@ package got_test import ( "fmt" + "io/ioutil" "net/http" "net/http/httptest" "os" @@ -9,8 +10,6 @@ import ( "testing" "time" - "io/ioutil" - "github.com/melbahja/got" ) @@ -79,7 +78,7 @@ func TestGot(t *testing.T) { t.Run("info", func(t *testing.T) { expect := got.Info{ - Length: stat.Size(), + Length: uint64(stat.Size()), Rangeable: true, Redirected: false, } @@ -93,7 +92,7 @@ func TestGot(t *testing.T) { t.Run("downloadChunksTest", func(t *testing.T) { // test info size and chunks. - downloadChunksTest(t, httpt.URL+"/file1", stat.Size()) + downloadChunksTest(t, httpt.URL+"/file1", uint64(stat.Size())) }) t.Run("downloadTest", func(t *testing.T) { @@ -147,7 +146,7 @@ func getInfoTest(t *testing.T, url string, expect got.Info) { return } - if expect != *info { + if expect != info { t.Error("invalid info") } @@ -168,7 +167,7 @@ func initTest(t *testing.T, url string) { } } -func downloadChunksTest(t *testing.T, url string, size int64) { +func downloadChunksTest(t *testing.T, url string, size uint64) { tmpFile := createTemp() defer clean(tmpFile) @@ -262,7 +261,7 @@ func downloadHeadNotSupported(t *testing.T, url string) { return } - if *info != (got.Info{}) { + if info != (got.Info{}) { t.Error("It should have a empty Info{}") } diff --git a/progress.go b/progress.go index 737f0ea..e72318a 100644 --- a/progress.go +++ b/progress.go @@ -12,10 +12,10 @@ type ( Progress struct { ProgressFunc - Size, TotalSize int64 - Interval int + Size, TotalSize uint64 + Interval uint64 - lastSize int64 + lastSize uint64 startedAt time.Time } @@ -38,7 +38,7 @@ func (p *Progress) Run(ctx context.Context, d *Download) { p.ProgressFunc(p, d) // Update last size - atomic.StoreInt64(&p.lastSize, atomic.LoadInt64(&p.Size)) + atomic.StoreUint64(&p.lastSize, atomic.LoadUint64(&p.Size)) time.Sleep(time.Duration(d.Interval) * time.Millisecond) } @@ -47,14 +47,14 @@ func (p *Progress) Run(ctx context.Context, d *Download) { // Speed returns download speed. func (p *Progress) Speed() uint64 { - return uint64((atomic.LoadInt64(&p.Size) - atomic.LoadInt64(&p.lastSize)) / int64(p.Interval) * 1000) + return (atomic.LoadUint64(&p.Size) - atomic.LoadUint64(&p.lastSize)) / p.Interval * 1000 } // AvgSpeed returns average download speed. func (p *Progress) AvgSpeed() uint64 { if totalMills := p.TotalCost().Milliseconds(); totalMills > 0 { - return uint64(atomic.LoadInt64(&p.Size) / totalMills * 1000) + return uint64(atomic.LoadUint64(&p.Size) / uint64(totalMills) * 1000) } return 0 @@ -68,6 +68,6 @@ func (p *Progress) TotalCost() time.Duration { // Write updates progress size. func (p *Progress) Write(b []byte) (int, error) { n := len(b) - atomic.AddInt64(&p.Size, int64(n)) + atomic.AddUint64(&p.Size, uint64(n)) return n, nil }