Skip to content

Commit

Permalink
fix(tqdm): import tqdm to support jupyter (#812)
Browse files Browse the repository at this point in the history
  • Loading branch information
shcheklein authored Jan 13, 2025
1 parent 4d6ab7b commit 08ff958
Show file tree
Hide file tree
Showing 13 changed files with 23 additions and 138 deletions.
6 changes: 3 additions & 3 deletions src/datachain/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from dvc_objects.fs.utils import remove
from fsspec.callbacks import Callback, TqdmCallback

from .progress import Tqdm

if TYPE_CHECKING:
from datachain.client import Client
from datachain.lib.file import File
Expand Down Expand Up @@ -86,9 +84,11 @@ async def download(
size = file.size
if size < 0:
size = await client.get_size(from_path, version_id=file.version)
from tqdm.auto import tqdm

cb = callback or TqdmCallback(
tqdm_kwargs={"desc": odb_fs.name(from_path), "bytes": True, "leave": False},
tqdm_cls=Tqdm,
tqdm_cls=tqdm,
size=size,
)
try:
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import sqlalchemy as sa
import yaml
from sqlalchemy import Column
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.cache import DataChainCache
from datachain.client import Client
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/client/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from urllib.parse import parse_qs, urlsplit, urlunsplit

from adlfs import AzureBlobFileSystem
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.lib.file import File

Expand Down
2 changes: 1 addition & 1 deletion src/datachain/client/fsspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from dvc_objects.fs.system import reflink
from fsspec.asyn import get_loop, sync
from fsspec.callbacks import DEFAULT_CALLBACK, Callback
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.cache import DataChainCache
from datachain.client.fileslice import FileWrapper
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/client/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from dateutil.parser import isoparse
from gcsfs import GCSFileSystem
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.lib.file import File

Expand Down
2 changes: 1 addition & 1 deletion src/datachain/client/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from botocore.exceptions import NoCredentialsError
from s3fs import S3FileSystem
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.lib.file import File

Expand Down
2 changes: 1 addition & 1 deletion src/datachain/data_storage/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from sqlalchemy.sql import func
from sqlalchemy.sql.expression import bindparam, cast
from sqlalchemy.sql.selectable import Select
from tqdm import tqdm
from tqdm.auto import tqdm

import datachain.sql.sqlite
from datachain.data_storage import AbstractDBMetastore, AbstractWarehouse
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/data_storage/warehouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from sqlalchemy import Table, case, select
from sqlalchemy.sql import func
from sqlalchemy.sql.expression import true
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.client import Client
from datachain.data_storage.schema import convert_rows_custom_column_types
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/lib/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pyarrow as pa
from fsspec.core import split_protocol
from pyarrow.dataset import CsvFileFormat, dataset
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.lib.data_model import dict_to_data_model
from datachain.lib.file import ArrowRow, File
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/lib/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from typing import TYPE_CHECKING, Any, Union

import PIL
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.lib.arrow import arrow_type_mapper
from datachain.lib.data_model import DataModel, DataType, dict_to_data_model
Expand Down
2 changes: 1 addition & 1 deletion src/datachain/listing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from sqlalchemy import Column
from sqlalchemy.sql import func
from tqdm import tqdm
from tqdm.auto import tqdm

from datachain.node import DirType, Node, NodeWithPath
from datachain.sql.functions import path as pathfunc
Expand Down
126 changes: 3 additions & 123 deletions src/datachain/progress.py
Original file line number Diff line number Diff line change
@@ -1,138 +1,16 @@
"""Manages progress bars."""

import logging
import sys
from threading import RLock
from typing import Any, ClassVar

from fsspec import Callback
from fsspec.callbacks import TqdmCallback
from tqdm import tqdm

from datachain.utils import env2bool
from tqdm.auto import tqdm

logger = logging.getLogger(__name__)
tqdm.set_lock(RLock())


class Tqdm(tqdm):
"""
maximum-compatibility tqdm-based progressbars
"""

BAR_FMT_DEFAULT = (
"{percentage:3.0f}% {desc}|{bar}|"
"{postfix[info]}{n_fmt}/{total_fmt}"
" [{elapsed}<{remaining}, {rate_fmt:>11}]"
)
# nested bars should have fixed bar widths to align nicely
BAR_FMT_DEFAULT_NESTED = (
"{percentage:3.0f}%|{bar:10}|{desc:{ncols_desc}.{ncols_desc}}"
"{postfix[info]}{n_fmt}/{total_fmt}"
" [{elapsed}<{remaining}, {rate_fmt:>11}]"
)
BAR_FMT_NOTOTAL = "{desc}{bar:b}|{postfix[info]}{n_fmt} [{elapsed}, {rate_fmt:>11}]"
BYTES_DEFAULTS: ClassVar[dict[str, Any]] = {
"unit": "B",
"unit_scale": True,
"unit_divisor": 1024,
"miniters": 1,
}

def __init__(
self,
iterable=None,
disable=None,
level=logging.ERROR,
desc=None,
leave=False,
bar_format=None,
bytes=False,
file=None,
total=None,
postfix=None,
**kwargs,
):
"""
bytes : shortcut for
`unit='B', unit_scale=True, unit_divisor=1024, miniters=1`
desc : persists after `close()`
level : effective logging level for determining `disable`;
used only if `disable` is unspecified
disable : If (default: None) or False,
will be determined by logging level.
May be overridden to `True` due to non-TTY status.
Skip override by specifying env var `DATACHAIN_IGNORE_ISATTY`.
kwargs : anything accepted by `tqdm.tqdm()`
"""
kwargs = kwargs.copy()
if bytes:
kwargs = self.BYTES_DEFAULTS | kwargs
else:
kwargs.setdefault("unit_scale", total > 999 if total else True)
if file is None:
file = sys.stderr
# auto-disable based on `logger.level`
if not disable:
disable = logger.getEffectiveLevel() > level
# auto-disable based on TTY
if (
not disable
and not env2bool("DATACHAIN_IGNORE_ISATTY")
and hasattr(file, "isatty")
):
disable = not file.isatty()
super().__init__(
iterable=iterable,
disable=disable,
leave=leave,
desc=desc,
bar_format="!",
lock_args=(False,),
total=total,
**kwargs,
)
self.postfix = postfix or {"info": ""}
if bar_format is None:
if self.__len__():
self.bar_format = (
self.BAR_FMT_DEFAULT_NESTED if self.pos else self.BAR_FMT_DEFAULT
)
else:
self.bar_format = self.BAR_FMT_NOTOTAL
else:
self.bar_format = bar_format
self.refresh()

def close(self):
self.postfix["info"] = ""
# remove ETA (either unknown or zero); remove completed bar
self.bar_format = self.bar_format.replace("<{remaining}", "").replace(
"|{bar:10}|", " "
)
super().close()

@property
def format_dict(self):
"""inject `ncols_desc` to fill the display width (`ncols`)"""
d = super().format_dict
ncols = d["ncols"] or 80
# assumes `bar_format` has max one of ("ncols_desc" & "ncols_info")

meter = self.format_meter( # type: ignore[call-arg]
ncols_desc=1, ncols_info=1, **d
)
ncols_left = ncols - len(meter) + 1
ncols_left = max(ncols_left, 0)
if ncols_left:
d["ncols_desc"] = d["ncols_info"] = ncols_left
else:
# work-around for zero-width description
d["ncols_desc"] = d["ncols_info"] = 1
d["prefix"] = ""
return d


class CombinedDownloadCallback(Callback):
def set_size(self, size):
# This is a no-op to prevent fsspec's .get_file() from setting the combined
Expand All @@ -148,6 +26,8 @@ def __init__(self, tqdm_kwargs=None, *args, **kwargs):
self.files_count = 0
tqdm_kwargs = tqdm_kwargs or {}
tqdm_kwargs.setdefault("postfix", {}).setdefault("files", self.files_count)
kwargs = kwargs or {}
kwargs["tqdm_cls"] = tqdm
super().__init__(tqdm_kwargs, *args, **kwargs)

def increment_file_count(self, n: int = 1) -> None:
Expand Down
9 changes: 7 additions & 2 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from sqlalchemy.sql.expression import label
from sqlalchemy.sql.schema import TableClause
from sqlalchemy.sql.selectable import Select
from tqdm.auto import tqdm

from datachain.asyn import ASYNC_WORKERS, AsyncMapper, OrderedMapper
from datachain.catalog.catalog import clone_catalog_with_cache
Expand Down Expand Up @@ -366,12 +367,16 @@ def get_download_callback(suffix: str = "", **kwargs) -> CombinedDownloadCallbac


def get_processed_callback() -> Callback:
return TqdmCallback({"desc": "Processed", "unit": " rows", "leave": False})
return TqdmCallback(
{"desc": "Processed", "unit": " rows", "leave": False}, tqdm_cls=tqdm
)


def get_generated_callback(is_generator: bool = False) -> Callback:
if is_generator:
return TqdmCallback({"desc": "Generated", "unit": " rows", "leave": False})
return TqdmCallback(
{"desc": "Generated", "unit": " rows", "leave": False}, tqdm_cls=tqdm
)
return DEFAULT_CALLBACK


Expand Down

0 comments on commit 08ff958

Please sign in to comment.