
Commit

fix: 修复了m3u8流下载时会重复下载某些ts片段的问题 (Fix duplicate downloads of some TS segments when downloading m3u8 streams)
针对直播流下载问题进行修复 (Fixes an issue with live-stream downloading)
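
The fix works by remembering which TS segment URIs have already been written to disk: each refresh of the live playlist is checked against a downloaded_segments set, segments whose URI is already in the set are skipped, and the set is reset once it grows past MAX_SEGMENT_COUNT so a long-running stream does not accumulate entries without bound. Below is a minimal, self-contained sketch of that idea; poll_live_stream, fetch_segments and download_segment are hypothetical placeholders for illustration, not functions from this repository.

    import asyncio

    MAX_SEGMENT_COUNT = 1000  # cap on how many segment URIs are remembered


    async def poll_live_stream(fetch_segments, download_segment):
        # URIs of segments that have already been written to disk
        downloaded = set()
        wait = 2.0  # fallback poll interval before any segment has been seen
        while True:
            # a refreshed live playlist usually repeats recent segments
            for segment_uri, duration in await fetch_segments():
                if segment_uri in downloaded:
                    continue  # skip a segment the playlist repeats
                await download_segment(segment_uri)
                downloaded.add(segment_uri)
                # reset the cache once it passes the cap so the set stays bounded
                if len(downloaded) > MAX_SEGMENT_COUNT:
                    downloaded = set()
                wait = duration  # poll again roughly once per segment duration
            await asyncio.sleep(wait)

Resetting the whole set, rather than evicting the oldest entries, keeps the bookkeeping trivial; the worst case is that a few very recent segments are fetched again right after a reset.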
Johnserf-Seed committed May 1, 2024
1 parent be28368 commit f681fd9
Showing 1 changed file with 67 additions and 40 deletions.
f2/dl/base_downloader.py (107 changes: 67 additions & 40 deletions)
@@ -22,6 +22,10 @@
     get_segments_from_m3u8,
 )
 
+# 最大片段缓存数量,超过这个数量就会进行清理
+# (Maximum segment cache count, clear when it exceeds this count)
+MAX_SEGMENT_COUNT = 1000
+
 
 class BaseDownloader(BaseCrawler):
     """基础下载器 (Base Downloader Class)"""
@@ -255,8 +259,13 @@ async def download_m3u8_stream(
         """
         async with self.semaphore:
             full_path = self._ensure_path(full_path)
-            total_downloaded = 1024000
+            # 设置默认下载总量 (Set default total download)
+            total_downloaded = 10240000
+            # 默认块大小 (Default chunk size)
             default_chunks = 409600
+            # 记录已经下载的片段序号
+            # (Record the segment number that has been downloaded)
+            downloaded_segments = set()
 
             while not SignalManager.is_shutdown_signaled():
                 try:
@@ -279,50 +288,68 @@
                             if SignalManager.is_shutdown_signaled():
                                 break
 
-                            ts_url = segment.absolute_uri
-                            ts_content_length = await get_content_length(
-                                ts_url, self.headers
-                            )
-                            if ts_content_length == 0:
-                                ts_content_length = default_chunks
-                                logger.warning(
-                                    _(
-                                        "无法读取该TS文件字节长度,将使用默认400kb块大小处理数据"
-                                    )
-                                )
-                            ts_request = self.aclient.build_request(
-                                "GET", ts_url, headers=self.headers
-                            )
-                            ts_response = await self.aclient.send(
-                                ts_request, stream=True
-                            )
-
-                            try:
-                                async for chunk in ts_response.aiter_bytes(
-                                    get_chunk_size(ts_content_length)
-                                ):
-                                    if SignalManager.is_shutdown_signaled():
-                                        break
-
-                                    # 直播流分块下载,每次下载后更新进度条
-                                    # (Live stream block download, update progress bar after each download)
-                                    await file.write(chunk)
-                                    total_downloaded += len(chunk)
-                                    await self.progress.update(
-                                        task_id,
-                                        advance=len(chunk),
-                                        total=total_downloaded,
-                                    )
-                            except httpx.ReadTimeout as e:
-                                logger.warning(_("TS文件下载超时: {0}").format(e))
-                            except Exception as e:
-                                logger.error(_("TS文件下载失败: {0}").format(e))
-                                logger.error(traceback.format_exc())
-                            finally:
-                                await ts_response.aclose()
+                            # 检查是否已经下载过该片段 (Check if the segment has been downloaded)
+                            if segment.absolute_uri not in downloaded_segments:
+                                ts_url = segment.absolute_uri
+                                ts_content_length = await get_content_length(
+                                    ts_url, self.headers
+                                )
+                                if ts_content_length == 0:
+                                    ts_content_length = default_chunks
+                                    logger.warning(
+                                        _(
+                                            "无法读取该TS文件字节长度,将使用默认400kb块大小处理数据"
+                                        )
+                                    )
+                                ts_request = self.aclient.build_request(
+                                    "GET", ts_url, headers=self.headers
+                                )
+                                ts_response = await self.aclient.send(
+                                    ts_request, stream=True
+                                )
+
+                                try:
+                                    async for chunk in ts_response.aiter_bytes(
+                                        get_chunk_size(ts_content_length)
+                                    ):
+                                        if SignalManager.is_shutdown_signaled():
+                                            break
+
+                                        # 直播流分块下载,每次下载后更新进度条
+                                        # (Live stream block download, update progress bar after each download)
+                                        await file.write(chunk)
+                                        total_downloaded += len(chunk)
+                                        await self.progress.update(
+                                            task_id,
+                                            advance=len(chunk),
+                                            total=total_downloaded,
+                                        )
+
+                                    # 记录已经下载的片段序号
+                                    # (Record the segment number that has been downloaded)
+                                    downloaded_segments.add(segment.absolute_uri)
+
+                                except httpx.ReadTimeout as e:
+                                    logger.warning(_("TS文件下载超时: {0}").format(e))
+                                except Exception as e:
+                                    logger.error(_("TS文件下载失败: {0}").format(e))
+                                    logger.error(traceback.format_exc())
+                                finally:
+                                    await ts_response.aclose()
+                            else:
+                                logger.debug(
+                                    _("为你跳过已下载的片段,URI: {0}").format(
+                                        segment.absolute_uri
+                                    )
+                                )
+
+                            # 每下载一定数量的片段后,清理一次集合
+                            # (After downloading a certain number of segments, clean up the collection)
+                            if len(downloaded_segments) > MAX_SEGMENT_COUNT:
+                                downloaded_segments = set()
 
                     # 等待一段时间后再次请求更新 (Request update again after waiting for a while)
-                    await asyncio.sleep(5)
+                    await asyncio.sleep(segment.duration)
 
                 except httpx.HTTPStatusError as e:
                     if e.response.status_code == 404:
