Commit: Fix lints

XiaohanZhangCMU committed Oct 23, 2023
1 parent e11b76c commit 8362549
Showing 2 changed files with 120 additions and 66 deletions.
181 changes: 119 additions & 62 deletions scripts/ingestion/ingest_from_hf.py
@@ -1,36 +1,51 @@
 # Copyright 2023 MosaicML Streaming authors
 # SPDX-License-Identifier: Apache-2.0

 """Utility for HF dataset ingestion."""

 # Improvements over plain snapshot_download:
 # 1. Enable resume=True: retry when the network misbehaves
 # 2. Disable progress bars to prevent browser/terminal crashes
 # 3. Add a monitor that prints file stats periodically

 import logging
 import os
-import time
-from huggingface_hub import snapshot_download
-from pyspark.sql.functions import concat_ws
-from huggingface_hub.utils import are_progress_bars_disabled, disable_progress_bars
-import asyncio
 import threading
 import time
+from typing import Any, List, Optional

-from watchdog.observers import Observer
+from huggingface_hub import snapshot_download
+from huggingface_hub.utils import disable_progress_bars
 from watchdog.events import PatternMatchingEventHandler
+from watchdog.observers import Observer
+
+from streaming.base.util import retry

 logger = logging.getLogger(__name__)


 class FolderObserver:
-    def __init__(self, directory):
-        patterns = ["*"]
+    """A thin wrapper around watchdog's directory observer."""
+
+    def __init__(self, directory: str):
+        """Specify the download directory to monitor."""
+        patterns = ['*']
         ignore_patterns = None
         ignore_directories = False
         case_sensitive = True
+        self.average_file_size = 0

         self.file_count = 0
         self.file_size = 0

         if not os.path.exists(directory):
             os.makedirs(directory)

         self.directory = directory
         self.get_directory_info()

-        self.my_event_handler = PatternMatchingEventHandler(patterns, ignore_patterns, ignore_directories, case_sensitive)
+        self.my_event_handler = PatternMatchingEventHandler(patterns, ignore_patterns,
+                                                            ignore_directories, case_sensitive)
         self.my_event_handler.on_created = self.on_created
         self.my_event_handler.on_deleted = self.on_deleted
         self.my_event_handler.on_modified = self.on_modified
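
For context, FolderObserver wraps the standard watchdog wiring. A minimal standalone sketch of that pattern (the directory path is hypothetical):

from watchdog.events import PatternMatchingEventHandler
from watchdog.observers import Observer

# Route created-file events to a callback; '*' matches every file.
handler = PatternMatchingEventHandler(patterns=['*'], ignore_patterns=None,
                                      ignore_directories=False, case_sensitive=True)
handler.on_created = lambda event: print(f'created: {event.src_path}')

observer = Observer()
observer.schedule(handler, '/tmp/some_download_dir', recursive=True)  # hypothetical path
observer.start()
# ... a download writes files into the directory here ...
observer.stop()
observer.join()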
@@ -51,49 +66,65 @@ def join(self):
         return self.observer.join()

     def get_directory_info(self):
-        file_count = 0
-        total_file_size = 0
+        self.file_count = file_count = 0
+        self.file_size = total_file_size = 0
         for root, _, files in os.walk(self.directory):
             for file in files:
                 file_count += 1
                 file_path = os.path.join(root, file)
                 total_file_size += os.path.getsize(file_path)
         self.file_count, self.file_size = file_count, total_file_size

-    def on_created(self, event):
+    def on_created(self, event: Any):
         self.file_count += 1

-    def on_deleted(self, event):
+    def on_deleted(self, event: Any):
+        print(type(event))
         self.file_count -= 1

-    def on_modified(self, event):
+    def on_modified(self, event: Any):
+        print(type(event))
         pass

-    def on_moved(self, event):
+    def on_moved(self, event: Any):
+        print(type(event))
         pass


-def monitor_directory_changes(interval=5):
-    def decorator(func):
-        def wrapper(repo_id, local_dir, max_workers, token, allow_patterns, *args, **kwargs):
+def monitor_directory_changes(interval: int = 5):
+    """Monitor a dataset download: track the file count N and the accumulated file size.
+
+    Approximate the dataset size by N * average file size.
+    """
+
+    def decorator(func: Any):
+
+        def wrapper(repo_id: str, local_dir: str, max_workers: int, token: str,
+                    allow_patterns: Optional[List[str]], *args: Any, **kwargs: Any):
             event = threading.Event()
             start_time = time.time()  # Capture the start time
             observer = FolderObserver(local_dir)

-            def beautify(kb):
-                mb = kb //(1024)
-                gb = mb //(1024)
-                return str(mb)+'MB' if mb >= 1 else str(kb) + 'KB'
+            def beautify(kb: int):
+                mb = kb // (1024)
+                gb = mb // (1024)
+                if gb >= 1:
+                    return str(gb) + 'GB'
+                elif mb >= 1:
+                    return str(mb) + 'MB'
+                else:
+                    return str(kb) + 'KB'

             def monitor_directory():
                 observer.start()
                 while not event.is_set():
                     try:
                         elapsed_time = int(time.time() - observer.tik)
-                        if observer.file_size > 1e9: # too large to keep an accurate count of the file size
+                        if observer.file_size > 1e9:  # too large to keep an accurate count of the file size
                             if observer.average_file_size == 0:
                                 observer.average_file_size = observer.file_size // observer.file_count
-                                print(f"approximately: average file size = {beautify(observer.average_file_size//1024)}")
+                                logger.warning(
+                                    f'approximately: average file size = {beautify(observer.average_file_size//1024)}'
+                                )
                             kb = observer.average_file_size * observer.file_count // 1024
                         else:
                             observer.get_directory_info()
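
To make beautify's integer-division thresholds concrete, a few worked examples (inputs hypothetical):

# beautify(512)       -> '512KB'  (mb = 0, gb = 0)
# beautify(2048)      -> '2MB'    (mb = 2, gb = 0)
# beautify(3_500_000) -> '3GB'    (mb = 3417, gb = 3)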
@@ -103,30 +134,39 @@ def monitor_directory():
                         sz = beautify(kb)
                         cnt = observer.file_count

-                        if elapsed_time % 10 == 0 :
-                            print(f"Downloaded {cnt} files, Total approx file size = {sz}, Time Elapsed: {elapsed_time} seconds.")
+                        if elapsed_time % 10 == 0:
+                            logger.warning(
+                                f'Downloaded {cnt} files, Total approx file size = {sz}, Time Elapsed: {elapsed_time} seconds.'
+                            )

                         if elapsed_time > 0 and elapsed_time % 120 == 0:
-                            observer.get_directory_info() # Get the actual stats by walking through the directory
+                            observer.get_directory_info(
+                            )  # Get the actual stats by walking through the directory
                             observer.average_file_size = observer.file_size // observer.file_count
-                            print(f"update average file size to {beautify(observer.average_file_size//1024)}")
+                            logger.warning(
+                                f'update average file size to {beautify(observer.average_file_size//1024)}'
+                            )

                         time.sleep(1)
                     except Exception as exc:
                         # raise RuntimeError("Something bad happened") from exc
-                        print(str(exc))
+                        logger.warning(str(exc))
                         time.sleep(1)
                         continue

             monitor_thread = threading.Thread(target=monitor_directory)
             monitor_thread.start()

             try:
-                result = func(repo_id, local_dir, max_workers, token, allow_patterns, *args, **kwargs)
+                result = func(repo_id, local_dir, max_workers, token, allow_patterns, *args,
+                              **kwargs)
                 return result
             finally:
-                observer.get_directory_info() # Get the actual stats by walking through the directory
-                print(f"Done! Downloaded {observer.file_count} files, Total file size = {beautify(observer.file_size//1024)}, Time Elapsed: {int(time.time() - observer.tik)} seconds.")
+                observer.get_directory_info(
+                )  # Get the actual stats by walking through the directory
+                logger.warning(
+                    f'Done! Downloaded {observer.file_count} files, Total file size = {beautify(observer.file_size//1024)}, Time Elapsed: {int(time.time() - observer.tik)} seconds.'
+                )
                 observer.stop()
                 observer.join()
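
The wrapper above follows a common monitor-thread lifecycle: spawn a polling thread, run the wrapped work, and shut the monitor down in finally. A minimal sketch of that pattern (body hypothetical; the real wrapper stops a FolderObserver instead):

import threading
import time

stop = threading.Event()

def monitor():
    # Poll once a second until asked to stop; the real monitor logs stats here.
    while not stop.is_set():
        time.sleep(1)

t = threading.Thread(target=monitor)
t.start()
try:
    pass  # the wrapped download runs here
finally:
    stop.set()  # signal the monitor loop to exit
    t.join()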

@@ -137,49 +177,66 @@ def monitor_directory():

     return decorator

-def retry(max_retries=3, retry_delay=5):
-    def decorator(func):
-        def wrapper(*args, **kwargs):
-            for _ in range(max_retries):
-                try:
-                    return func(*args, **kwargs)
-                except Exception as e:
-                    print(f"An exception occurred: {str(e)}")
-                    print(f"Retrying in {retry_delay} seconds...")
-                    time.sleep(retry_delay)
-            raise Exception(f"Function {func.__name__} failed after {max_retries} retries.")
-        return wrapper
-    return decorator

 @monitor_directory_changes()
-@retry(max_retries=10, retry_delay=10)
-def hf_snapshot(repo_id, local_dir, max_workers, token, allow_patterns):
-    print(f"Now start to download {repo_id} to {local_dir}, with allow_patterns = {allow_patterns}")
-    output = snapshot_download(repo_id, repo_type="dataset", local_dir=local_dir, local_dir_use_symlinks=False, max_workers=max_workers, resume_download=True, token=token, allow_patterns=allow_patterns)
+@retry([Exception, RuntimeError], num_attempts=10, initial_backoff=10)
+def hf_snapshot(repo_id: str, local_dir: str, max_workers: int, token: str,
+                allow_patterns: Optional[List[str]]):
+    """Call HF snapshot_download, which internally uses hf_hub_download."""
+    print(
+        f'Now start to download {repo_id} to {local_dir}, with allow_patterns = {allow_patterns}')
+    output = snapshot_download(repo_id,
+                               repo_type='dataset',
+                               local_dir=local_dir,
+                               local_dir_use_symlinks=False,
+                               max_workers=max_workers,
+                               resume_download=True,
+                               token=token,
+                               allow_patterns=allow_patterns)
     return output

-def download_hf_dataset(local_cache_directory, prefix='', submixes =[], max_workers=32, token="", allow_patterns=None):
+
+def download_hf_dataset(local_cache_directory: str,
+                        prefix: str,
+                        submixes: List[str],
+                        token: str,
+                        max_workers: int = 32,
+                        allow_patterns: Optional[List[str]] = None) -> None:
+    """Disable progress bars and call hf_snapshot.
+
+    Args:
+        local_cache_directory (str): Local output directory the dataset will be written to.
+        prefix (str): HF namespace, e.g., allenai.
+        submixes (List): A list of repos within the HF namespace, e.g., c4.
+        token (str): HF access token.
+        max_workers (int): Number of workers used to parallelize downloading.
+        allow_patterns (List): Only files matching the patterns will be downloaded. E.g.,
+            'en/*' along with allenai/c4 means to download only the allenai/c4/en folder.
+    """
     disable_progress_bars()
     for submix in submixes:
         repo_id = os.path.join(prefix, submix)
         local_dir = os.path.join(local_cache_directory, submix)

-        output = hf_snapshot(
+        _ = hf_snapshot(
             repo_id,
             local_dir,
             max_workers,
             token,
             allow_patterns=allow_patterns,
         )

-if __name__ == "__main__":
+
+if __name__ == '__main__':
     #download_hf_dataset(local_cache_directory="/tmp/xiaohan/cifar10_1233", prefix="", submixes=["cifar10"], max_workers=32)
-    download_hf_dataset(
-        local_cache_directory="/tmp/xiaohan/c4_1316",
-        prefix = "allenai/",
-        submixes = [
-            "c4",
-        ],
-        allow_patterns=["en/*"],
-        max_workers=1,
-        token = "MY_HUGGINGFACE_ACCESS_TOKEN") # 32 seems to be a sweet point, beyond 32 downloading is not smooth
+    download_hf_dataset(local_cache_directory='/tmp/xiaohan/c4_1316',
+                        prefix='allenai/',
+                        submixes=[
+                            'c4',
+                        ],
+                        allow_patterns=['en/*'],
+                        max_workers=1,
+                        token='MY_HUGGINGFACE_ACCESS_TOKEN'
+                        )  # 32 seems to be a sweet point; beyond 32, downloading is not smooth
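
Note that the hand-rolled retry helper was dropped in favor of streaming.base.util.retry. Judging from the call site above, usage looks like this (a sketch; exact semantics live in the streaming library):

from streaming.base.util import retry

@retry([Exception, RuntimeError], num_attempts=10, initial_backoff=10)
def flaky_download() -> None:
    ...  # retried on the listed exceptions, with backoff starting at 10 seconds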
5 changes: 1 addition & 4 deletions setup.py
@@ -101,10 +101,7 @@
     'databricks-sdk==0.8.0',
 ]

-extra_deps['ingestion'] = [
-    'huggingface_hub[cli,torch]>=0.16.0,<0.17.0',
-    'watchdog>=3,<4'
-]
+extra_deps['ingestion'] = ['huggingface_hub[cli,torch]>=0.16.0,<0.17.0', 'watchdog>=3,<4']

 extra_deps['all'] = sorted({dep for deps in extra_deps.values() for dep in deps})
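
With the ingestion extras consolidated onto one line, installing them would presumably be pip install 'mosaicml-streaming[ingestion]' (distribution name assumed from the repo).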

