Skip to content

Commit

Permalink
兼容普通用户不支持并发分片下载导致分片下载失败的情况
Browse files Browse the repository at this point in the history
  • Loading branch information
jsyzchen committed Apr 23, 2022
1 parent f5ee494 commit b156e6b
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 18 deletions.
1 change: 1 addition & 0 deletions conf/conf.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ const (
BaiduOpenApiDomain = "https://openapi.baidu.com"
OpenApiDomain = "https://pan.baidu.com"
PcsDataDomain = "https://d.pcs.baidu.com"
PcsApiDomain = "https://pcs.baidu.com"
)

// 测试参数
Expand Down
12 changes: 11 additions & 1 deletion examples/file_download.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"fmt"
"github.com/jsyzchen/pan/conf"
"github.com/jsyzchen/pan/file"
)

Expand All @@ -23,8 +24,17 @@ func main() {
fsID = 759719327699432
fileDownloader = file.NewDownloaderWithFsID(accessToken, fsID, localFilePath)
if err := fileDownloader.Download(); err != nil {
fmt.Println("2.fileDownloader.Download failed, err:", err)
fmt.Println("2.fileDownloader.DownloadWithFsID failed, err:", err)
return
}
fmt.Println("2.fileDownloader.Download success")

// 方式3:通过文件路径下载,非开放平台公开接口,生产环境谨慎使用
fileDownloader = file.NewDownloaderWithPath(conf.TestData.AccessToken, conf.TestData.Path, conf.TestData.LocalFilePath)
err := fileDownloader.Download()
if err != nil {
fmt.Println("3.fileDownloader.DownloaderWithPath failed, err:", err)
return
}
fmt.Println("3.fileDownloader.DownloaderWithPath success")
}
42 changes: 32 additions & 10 deletions file/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package file

import (
"errors"
"github.com/jsyzchen/pan/account"
"github.com/jsyzchen/pan/conf"
"github.com/jsyzchen/pan/utils/file"
"log"
"net/url"
)

type Downloader struct {
Expand All @@ -15,7 +18,11 @@ type Downloader struct {
TotalPart int
}

func NewDownloader(accessToken string, downloadLink string, localFilePath string, ) *Downloader {
const (
PcsFileDownloadUri = "/rest/2.0/pcs/file?method=download"
)

func NewDownloader(accessToken string, downloadLink string, localFilePath string) *Downloader {
return &Downloader{
AccessToken: accessToken,
LocalFilePath: localFilePath,
Expand All @@ -31,13 +38,14 @@ func NewDownloaderWithFsID(accessToken string, fsID uint64, localFilePath string
}
}

//func NewDownloaderWithPath(accessToken string, path string, localFilePath string) *Downloader {
// return &Downloader{
// AccessToken: accessToken,
// Path: path,
// LocalFilePath: localFilePath,
// }
//}
// 非开放平台公开接口,生产环境谨慎使用
func NewDownloaderWithPath(accessToken string, path string, localFilePath string) *Downloader {
return &Downloader{
AccessToken: accessToken,
Path: path,
LocalFilePath: localFilePath,
}
}

// 执行下载
func (d *Downloader) Download() error {
Expand All @@ -61,8 +69,12 @@ func (d *Downloader) Download() error {
return errors.New("file don't exist")
}
downloadLink = metas.List[0].DLink
} else if d.Path != "" {

} else if d.Path != "" { // TODO 如何通过文件路径获取下载地址
v := url.Values{}
v.Add("path", d.Path)
v.Add("access_token", d.AccessToken)
body := v.Encode()
downloadLink = conf.PcsApiDomain + PcsFileDownloadUri + "&" + body
} else {
return errors.New("param error")
}
Expand All @@ -73,6 +85,16 @@ func (d *Downloader) Download() error {

downloadLink += "&access_token=" + d.AccessToken
downloader := file.NewFileDownloader(downloadLink, d.LocalFilePath)

accountClient := account.NewAccountClient(d.AccessToken)
if userInfo, err := accountClient.UserInfo(); err == nil {
log.Println("VipType:", userInfo.VipType)
if userInfo.VipType == 2 { //当前用户是超级会员
downloader.SetPartSize(52428800) //设置每分片下载文件大小,50M
downloader.SetCoroutineNum(10) //分片下载并发数,普通用户不支持并发分片下载
}
}

if err := downloader.Download(); err != nil {
log.Println("download failed, err:", err)
return err
Expand Down
12 changes: 12 additions & 0 deletions file/download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,15 @@ func TestDownload(t *testing.T) {
}
}

func TestDownloaderWithPath(t *testing.T) {
fileDownloader := NewDownloaderWithPath(conf.TestData.AccessToken, conf.TestData.Path, conf.TestData.LocalFilePath)
err := fileDownloader.Download()
if err != nil {
t.Fail()
} else {
t.Logf("TestDownload Success")
}
}



53 changes: 46 additions & 7 deletions utils/file/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/http"
"os"
"path"
"path/filepath"
"strconv"
"sync"
"time"
Expand All @@ -23,6 +24,7 @@ type Downloader struct {
TotalPart int //下载线程
DoneFilePart []Part
PartSize int
PartCoroutineNum int //分片下载协程数
}

//filePart 文件分片
Expand All @@ -40,15 +42,21 @@ func NewFileDownloader(downloadLink, filePath string) *Downloader {
FileSize: 0,
Link: downloadLink,
FilePath: filePath,
PartSize: 10485760,// 10M
PartCoroutineNum: 1,
}
}

func (d *Downloader) SetTotalPart(totalPart int) {
d.TotalPart = totalPart
}

func (d *Downloader) SetPartSize(PartSize int) {
d.PartSize = PartSize
func (d *Downloader) SetPartSize(partSize int) {
d.PartSize = partSize
}

func (d *Downloader) SetCoroutineNum(partCoroutineNum int) {
d.PartCoroutineNum = partCoroutineNum
}

//Run 开始下载任务
Expand All @@ -68,6 +76,8 @@ func (d *Downloader) Download() error {
d.PartSize = 10485760 // 10M
}

log.Println("fileTotalSize:", fileTotalSize)

if isSupportRange == false || fileTotalSize <= d.PartSize {//不支持Range下载或者文件比较小,直接下载文件
err := d.downloadWhole()
return err
Expand All @@ -89,6 +99,8 @@ func (d *Downloader) Download() error {
jobs := make([]Part, d.TotalPart)
eachSize := fileTotalSize / d.TotalPart

log.Println("eachSize:", eachSize)

for i := range jobs {
jobs[i].Index = i
if i == 0 {
Expand All @@ -109,7 +121,11 @@ func (d *Downloader) Download() error {

var wg sync.WaitGroup
isFailed := false
sem := make(chan int, 10) //限制并发数,以防大文件下载导致占用服务器大量网络宽带和磁盘io
partCoroutineNum := d.PartCoroutineNum
if len(jobs) < partCoroutineNum {
partCoroutineNum = len(jobs)
}
sem := make(chan int, partCoroutineNum) //限制并发数,以防大文件下载导致占用服务器大量网络宽带和磁盘io
for _, job := range jobs {
wg.Add(1)
sem <- 1 //当通道已满的时候将被阻塞
Expand Down Expand Up @@ -173,19 +189,22 @@ func (d *Downloader) downloadPart(c Part) error {
if err != nil {
return err
}
if resp.StatusCode > 299 {
return errors.New(fmt.Sprintf("服务器错误状态码: %v", resp.StatusCode))
}
defer resp.Body.Close()
bs, err := ioutil.ReadAll(resp.Body)
if resp.StatusCode > 299 {
log.Println(fmt.Sprintf("服务器错误,状态码: %v, msg:%s", resp.StatusCode, string(bs)))
return errors.New(fmt.Sprintf("服务器错误,状态码: %v, msg:%s", resp.StatusCode, string(bs)))
}

if err != nil {
if err != io.EOF && err != io.ErrUnexpectedEOF {//unexpected EOF 处理
log.Println("ioutil.ReadAll error :", err)
return err
}
}

if len(bs) != (c.To - c.From + 1) {
return errors.New("下载文件分片长度错误")
return errors.New(fmt.Sprintf("下载文件分片长度错误, len bs:%d", len(bs)))
}
//c.Data = bs

Expand All @@ -194,11 +213,15 @@ func (d *Downloader) downloadPart(c Part) error {
fileNamePrefix := fileName[0:len(path.Base(d.FilePath)) - len(path.Ext(d.FilePath))]
nowTime := time.Now().UnixNano() / 1e6
partFilePath := path.Join(os.TempDir(), fileNamePrefix + "_" + strconv.Itoa(c.Index) + "_" + strconv.FormatInt(nowTime, 10))

log.Printf("partFilePath[%d]:%s", c.Index, partFilePath)

f, err := os.Create(partFilePath)
if err != nil {
log.Println("open file error :", err)
return err
}

// 关闭文件
defer f.Close()
// 字节方式写入
Expand All @@ -207,6 +230,7 @@ func (d *Downloader) downloadPart(c Part) error {
log.Println(err)
return err
}

c.FilePath = partFilePath

d.DoneFilePart[c.Index] = c
Expand All @@ -218,6 +242,21 @@ func (d *Downloader) downloadPart(c Part) error {
//mergeFileParts 合并下载的文件
func (d *Downloader) mergeFileParts() error {
log.Println("开始合并文件")

//存储文件夹不存在的话先创建文件夹
fileDir := filepath.Dir(d.FilePath)
_, err := os.Stat(fileDir)
if err != nil {
if os.IsNotExist(err){
//递归创建文件夹
err := os.MkdirAll(fileDir, os.ModePerm)
if err != nil{
log.Println("MkdirAll failed:", err)
return err
}
}
}

mergedFile, err := os.Create(d.FilePath)
if err != nil {
return err
Expand Down

0 comments on commit b156e6b

Please sign in to comment.