Fix to fixed batch size bucketing and audio loading network connection resets (#1387)

* Fix to fixed batch size bucketing and audio loading network connection resets

* Fix tests and add more 'paranoia' tests
pzelasko authored Aug 22, 2024
1 parent 66b95ba commit 170046f
Showing 7 changed files with 145 additions and 13 deletions.
2 changes: 2 additions & 0 deletions lhotse/audio/utils.py
@@ -125,6 +125,7 @@ def suppress_audio_loading_errors(enabled: bool = True):
         AudioLoadingError,
         DurationMismatchError,
         NonPositiveEnergyError,
+        ConnectionResetError,  # when reading from object stores / network sources
         enabled=enabled,
     ):
         yield
@@ -141,6 +142,7 @@ def suppress_video_loading_errors(enabled: bool = True):
         AudioLoadingError,
         DurationMismatchError,
         NonPositiveEnergyError,
+        ConnectionResetError,  # when reading from object stores / network sources
         enabled=enabled,
     ):
         yield
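
With ConnectionResetError added to the suppressed exception types, a transient network failure (e.g. when streaming audio from an object store) is now treated the same as a decoding error under fault-tolerant loading. A minimal sketch of how the context manager is used, assuming a hypothetical local file; when suppression is active, a failing load exits the with-block silently instead of raising:

```python
from lhotse import Recording
from lhotse.audio import suppress_audio_loading_errors

recording = Recording.from_file("clip.wav")  # hypothetical audio file

with suppress_audio_loading_errors(True):
    # If the remote source resets the connection mid-read, the error is
    # now swallowed here and the block exits early instead of crashing.
    audio = recording.load_audio()
```
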
4 changes: 2 additions & 2 deletions lhotse/dataset/sampling/base.py
@@ -2,7 +2,7 @@
 import os
 import warnings
 from abc import ABCMeta, abstractmethod
-from bisect import bisect_right
+from bisect import bisect_left
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from math import isclose
@@ -424,7 +424,7 @@ def select_bucket(
         ), f"select_bucket requires either example= or example_len= as the input (we received {example=} and {example_len=})."
         if example_len is None:
             example_len = self.measure_length(example)
-        return bisect_right(buckets, example_len)
+        return bisect_left(buckets, example_len)

     def copy(self) -> "SamplingConstraint":
         """Return a shallow copy of this constraint."""
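
The bisect_right → bisect_left switch makes each bucket's upper bound inclusive: an example whose length equals a bin edge now maps to that bin rather than the next one, so a max-length example no longer overflows its bucket's fixed batch size. A minimal sketch of the difference using the stdlib directly:

```python
from bisect import bisect_left, bisect_right

buckets = [2.0, 4.0]  # bucket upper bounds in seconds

bisect_right(buckets, 2.0)  # -> 1: old behavior, a 2.0s example spills into the next bucket
bisect_left(buckets, 2.0)   # -> 0: new behavior, the boundary stays in its own bucket
bisect_left(buckets, 5.0)   # -> 2: longer than the last bin edge, i.e. out of range
```

This matches the new `test_select_bucket_includes_upper_bound_in_bin` test below, where `example_len=2.0` selects bucket 0 and `example_len=5.0` falls past the last bucket.
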
19 changes: 17 additions & 2 deletions lhotse/dataset/sampling/stateless.py
@@ -1,7 +1,17 @@
 import logging
 import random
 from pathlib import Path
-from typing import Callable, Dict, Generator, Iterable, Optional, Sequence, Tuple, Union
+from typing import (
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)

 import torch
 from cytoolz import compose_left
@@ -89,6 +99,8 @@ class StatelessSampler(torch.utils.data.Sampler, Dillable):
     :param max_duration: Maximum total number of audio seconds in a mini-batch (dynamic batch size).
     :param max_cuts: Maximum number of examples in a mini-batch (static batch size).
     :param num_buckets: If set, enables bucketing (each mini-batch has examples of a similar duration).
+    :param duration_bins: A list of floats (seconds); when provided, we'll skip the initial
+        estimation of bucket duration bins (useful to speed up launching experiments).
     :param quadratic_duration: If set, adds a penalty term for longer duration cuts.
         Works well with models that have quadratic time complexity to keep GPU utilization similar
         when using bucketing. Suggested values are between 30 and 45.
@@ -102,6 +114,7 @@ def __init__(
         max_duration: Optional[Seconds] = None,
         max_cuts: Optional[int] = None,
         num_buckets: Optional[int] = None,
+        duration_bins: List[Seconds] = None,
         quadratic_duration: Optional[Seconds] = None,
     ) -> None:
         super().__init__(data_source=None)
@@ -146,6 +159,7 @@ def __init__(
         self.max_duration = max_duration
         self.max_cuts = max_cuts
         self.num_buckets = num_buckets
+        self.duration_bins = duration_bins
         self.quadratic_duration = quadratic_duration
         self.base_seed = base_seed
         assert any(
@@ -216,12 +230,13 @@ def _inner():
                 yield cut
                 n += 1

-        if self.num_buckets is not None and self.num_buckets > 1:
+        if self.num_buckets is not None or self.duration_bins is not None:
             inner_sampler = DynamicBucketingSampler(
                 _inner(),
                 max_duration=self.max_duration,
                 max_cuts=self.max_cuts,
                 num_buckets=self.num_buckets,
+                duration_bins=self.duration_bins,
                 shuffle=False,
                 drop_last=False,
                 quadratic_duration=self.quadratic_duration,
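
With the new duration_bins argument, StatelessSampler can skip the startup pass that estimates bucket bins and hand precomputed edges straight to the underlying DynamicBucketingSampler. A sketch under assumed inputs (the manifest paths, index path, and bin edges below are hypothetical):

```python
from lhotse.dataset.sampling.stateless import StatelessSampler

sampler = StatelessSampler(
    cuts_paths,                      # hypothetical: paths to cuts manifest files
    index_path=index_path,           # hypothetical: where the line index is stored
    duration_bins=[1.5, 3.7, 15.2],  # precomputed bin edges (seconds); skips estimation
    max_duration=300.0,              # dynamic batch size budget (seconds)
    base_seed=0,
)
for batch in sampler:
    ...  # each mini-batch holds cuts of similar duration
```
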
26 changes: 26 additions & 0 deletions test/audio/test_audio_reads.py
@@ -2,6 +2,7 @@
 from io import BytesIO
 from pathlib import Path
 from tempfile import NamedTemporaryFile, TemporaryDirectory
+from unittest.mock import Mock

 import numpy as np
 import pytest
@@ -10,6 +11,7 @@

 import lhotse
 from lhotse import AudioSource, Recording
+from lhotse.audio import suppress_audio_loading_errors
 from lhotse.audio.backend import (
     info,
     read_opus_ffmpeg,
@@ -260,3 +262,27 @@ def test_set_audio_backend():
     )
     audio2 = recording.load_audio()
     np.testing.assert_array_almost_equal(audio1, audio2)
+
+
+def test_fault_tolerant_audio_network_exception():
+    def _mock_load_audio(*args, **kwargs):
+        raise ConnectionResetError()
+
+    source = Mock()
+    source.load_audio = _mock_load_audio
+    source.has_video = False
+
+    recording = Recording(
+        id="irrelevant",
+        sources=[source],
+        sampling_rate=16000,
+        num_samples=16000,
+        duration=1.0,
+        channel_ids=[0],
+    )
+
+    with pytest.raises(ConnectionResetError):
+        recording.load_audio()  # does raise
+
+    with suppress_audio_loading_errors(True):
+        recording.load_audio()  # is silently caught
77 changes: 69 additions & 8 deletions test/dataset/sampling/test_dynamic_bucketing.py
@@ -53,7 +53,7 @@ def test_dynamic_bucketing_drop_last_false():
     rng = random.Random(0)

     sampler = DynamicBucketer(
-        cuts, duration_bins=[2], max_duration=5, rng=rng, world_size=1
+        cuts, duration_bins=[1.5], max_duration=5, rng=rng, world_size=1
     )
     batches = [b for b in sampler]
     sampled_cuts = [c for b in batches for c in b]
@@ -90,7 +90,7 @@ def test_dynamic_bucketing_drop_last_true():
     rng = random.Random(0)

     sampler = DynamicBucketer(
-        cuts, duration_bins=[2], max_duration=5, rng=rng, drop_last=True, world_size=1
+        cuts, duration_bins=[1.5], max_duration=5, rng=rng, drop_last=True, world_size=1
     )
     batches = [b for b in sampler]
     sampled_cuts = [c for b in batches for c in b]
@@ -125,7 +125,7 @@ def test_dynamic_bucketing_sampler(concurrent):
         c.duration = 2

     sampler = DynamicBucketingSampler(
-        cuts, max_duration=5, num_buckets=2, seed=0, concurrent=concurrent
+        cuts, max_duration=5, duration_bins=[1.5], seed=0, concurrent=concurrent
     )
     batches = [b for b in sampler]
     sampled_cuts = [c for b in batches for c in b]
@@ -231,7 +231,9 @@ def test_dynamic_bucketing_sampler_too_small_data_can_be_sampled():
         c.duration = 2

     # 10 cuts with 30s total are not enough to satisfy max_duration of 100 with 2 buckets
-    sampler = DynamicBucketingSampler(cuts, max_duration=100, num_buckets=2, seed=0)
+    sampler = DynamicBucketingSampler(
+        cuts, max_duration=100, duration_bins=[1.5], seed=0
+    )
     batches = [b for b in sampler]
     sampled_cuts = [c for b in batches for c in b]
@@ -249,6 +251,35 @@ def test_dynamic_bucketing_sampler_too_small_data_can_be_sampled():
         assert len(b) == 5


+def test_dynamic_bucketing_sampler_much_less_data_than_ddp_ranks():
+    world_size = 128
+    orig_cut = dummy_cut(0)
+    cuts = CutSet([orig_cut])
+    samplers = [
+        DynamicBucketingSampler(
+            cuts,
+            max_duration=2000.0,
+            duration_bins=[1.5, 3.7, 15.2, 27.9, 40.0],
+            drop_last=False,
+            concurrent=False,
+            world_size=world_size,
+            rank=i,
+        )
+        for i in range(world_size)
+    ]
+    # None of the ranks drops anything, all of them return the one cut we have.
+    for sampler in samplers:
+        (batch,) = [b for b in sampler]
+        assert len(batch) == 1
+        (sampled_cut,) = batch
+        assert (
+            sampled_cut.id[: len(orig_cut.id)] == orig_cut.id
+        )  # same stem, possibly added '_dupX' suffix
+        # otherwise the cuts are identical
+        sampled_cut.id = orig_cut.id
+        assert sampled_cut == orig_cut
+
+
 def test_dynamic_bucketing_sampler_too_small_data_drop_last_true_results_in_no_batches():
     cuts = DummyManifest(CutSet, begin_id=0, end_id=10)
     for i, c in enumerate(cuts):
@@ -337,7 +368,9 @@ def test_dynamic_bucketing_sampler_cut_pairs():
         else:
             c.duration = 2

-    sampler = DynamicBucketingSampler(cuts, cuts, max_duration=5, num_buckets=2, seed=0)
+    sampler = DynamicBucketingSampler(
+        cuts, cuts, max_duration=5, duration_bins=[1.5], seed=0
+    )
     batches = [b for b in sampler]
     sampled_cut_pairs = [cut_pair for b in batches for cut_pair in zip(*b)]
     source_cuts = [sc for sc, tc in sampled_cut_pairs]
@@ -473,7 +506,7 @@ def test_dynamic_bucketing_sampler_cut_triplets():
         c.duration = 2

     sampler = DynamicBucketingSampler(
-        cuts, cuts, cuts, max_duration=5, num_buckets=2, seed=0
+        cuts, cuts, cuts, max_duration=5, duration_bins=[1.5], seed=0
     )
     batches = [b for b in sampler]
     sampled_cut_triplets = [cut_triplet for b in batches for cut_triplet in zip(*b)]
@@ -542,7 +575,7 @@ def test_dynamic_bucketing_quadratic_duration():

     # quadratic_duration=30
     sampler = DynamicBucketingSampler(
-        cuts, max_duration=61, num_buckets=2, seed=0, quadratic_duration=30
+        cuts, max_duration=61, duration_bins=[10.0], seed=0, quadratic_duration=30
     )
     batches = [b for b in sampler]
     assert len(batches) == 6
@@ -556,7 +589,7 @@ def test_dynamic_bucketing_quadratic_duration():

     # quadratic_duration=None (disabled)
     sampler = DynamicBucketingSampler(
-        cuts, max_duration=61, num_buckets=2, seed=0, quadratic_duration=None
+        cuts, max_duration=61, duration_bins=[10.0], seed=0, quadratic_duration=None
     )
     batches = [b for b in sampler]
     assert len(batches) == 4
@@ -731,3 +764,31 @@ def test_dynamic_bucketing_sampler_fixed_batch_constraint():

     assert len(batches[7]) == 1
     assert sum(c.duration for c in batches[7]) == 1
+
+
+def test_select_bucket_includes_upper_bound_in_bin():
+    constraint = FixedBucketBatchSizeConstraint(
+        max_seq_len_buckets=[2.0, 4.0], batch_sizes=[2, 1]
+    )
+
+    # within bounds
+    assert (
+        constraint.select_bucket(constraint.max_seq_len_buckets, example_len=1.0) == 0
+    )
+    assert (
+        constraint.select_bucket(constraint.max_seq_len_buckets, example_len=2.0) == 0
+    )
+    assert (
+        constraint.select_bucket(constraint.max_seq_len_buckets, example_len=3.0) == 1
+    )
+    assert (
+        constraint.select_bucket(constraint.max_seq_len_buckets, example_len=4.0) == 1
+    )
+    constraint.add(dummy_cut(0, duration=4.0))  # can add max duration without exception
+
+    # out of bounds
+    assert (
+        constraint.select_bucket(constraint.max_seq_len_buckets, example_len=5.0) == 2
+    )
+    with pytest.raises(AssertionError):
+        constraint.add(dummy_cut(0, duration=5.0))
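
The tests above now pin duration_bins explicitly rather than letting num_buckets trigger estimation from the data, which makes bucket boundaries deterministic and avoids a startup pass. A sketch of the same pattern outside the test suite (the manifest path and bin edges are assumptions, not from this commit):

```python
from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

cuts = CutSet.from_file("cuts.jsonl.gz")  # hypothetical manifest

sampler = DynamicBucketingSampler(
    cuts,
    max_duration=300.0,               # per-batch audio budget in seconds
    duration_bins=[5.0, 10.0, 20.0],  # explicit bin edges; upper bounds are now inclusive
    shuffle=True,
    seed=0,
)
```
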
24 changes: 24 additions & 0 deletions test/dataset/sampling/test_sampling.py
@@ -1232,3 +1232,27 @@ def test_sampler_map():
     b = batches[1]
     assert len(b) == 1
     assert b[0].duration == 5.0
+
+
+def test_sampler_much_less_data_than_ddp_ranks():
+    world_size = 128
+    orig_cut = dummy_cut(0)
+    cuts = CutSet([orig_cut])
+
+    samplers = [
+        DynamicCutSampler(
+            cuts, max_cuts=256, drop_last=False, world_size=world_size, rank=i
+        )
+        for i in range(world_size)
+    ]
+    # None of the ranks drops anything, all of them return the one cut we have.
+    for sampler in samplers:
+        (batch,) = [b for b in sampler]
+        assert len(batch) == 1
+        (sampled_cut,) = batch
+        assert (
+            sampled_cut.id[: len(orig_cut.id)] == orig_cut.id
+        )  # same stem, possibly added '_dupX' suffix
+        # otherwise the cuts are identical
+        sampled_cut.id = orig_cut.id
+        assert sampled_cut == orig_cut
6 changes: 5 additions & 1 deletion test/dataset/sampling/test_stateless_sampler.py
@@ -189,7 +189,11 @@ def test_stateless_sampler_in_dataloader_with_iterable_dataset(
 def test_stateless_sampler_bucketing(cuts_files: Tuple[Path]):
     index_path = cuts_files[0].parent / "cuts.idx"
     sampler = StatelessSampler(
-        cuts_files, index_path=index_path, num_buckets=2, max_duration=4, base_seed=0
+        cuts_files,
+        index_path=index_path,
+        duration_bins=[1.5],
+        max_duration=4,
+        base_seed=0,
     )

     for idx, batch in enumerate(sampler):
