Feature/unsup multichan waveform dataset #532
base: master
Conversation
Thanks! I think it's a good start, will take a bit more effort to handle multi-channel properly here, but we'll get there.
# TODO: how to ensure that each track is synced across batches? i.e. dim=1 is the track index
# and should correspond to the same mic across batches

cuts = maybe_pad(cuts)
This should not be needed as you're manually zero-padding later.
# and should correspond to the same mic across batches

cuts = maybe_pad(cuts)
cuts = remove_pad_tracks(cuts)
I think there is a pitfall here: what if a MixedCut looks like
|-------cut1-------||---padding---||----cut2----|
or any variation of the situation where padding sits in between two cuts? I don't think Lhotse would handle these situations well with your current code. Maybe you should try only removing the padding at the end (and at the beginning, but for that one you have to be careful about modifying the offsets of the remaining tracks). Rather than manually removing PaddingCuts, I suggest using .truncate() with carefully computed offset and duration arguments; that method handles a lot of pitfalls and edge cases.
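A minimal sketch of that approach, assuming the MixedCut.tracks / PaddingCut / Cut.truncate() behaviour described above; the helper name strip_edge_padding is made up for illustration:

```python
from lhotse.cut import MixedCut, PaddingCut

def strip_edge_padding(cut: MixedCut):
    # Find the time span actually covered by real (non-padding) tracks.
    real = [t for t in cut.tracks if not isinstance(t.cut, PaddingCut)]
    start = min(t.offset for t in real)
    end = max(t.offset + t.cut.duration for t in real)
    # truncate() recomputes the remaining tracks' offsets and handles
    # edge cases (e.g. padding overlapping real audio) for us.
    return cut.truncate(offset=start, duration=end - start)
```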
for idx, cut in enumerate(cuts):
    ntrack = len(cut.tracks)
    nsamp = cut.num_samples
    audio[idx, 0:ntrack, 0:nsamp] = torch.from_numpy(cut.load_audio(mixed=False))
Note that if you did cut.mix(musan_cut) here, it will also add an extra track; as is, the code would not work with additive noise data augmentation.
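For illustration only (the snr value and variable names are made up), mixing in noise grows the track list, so the per-track loading above would return one extra channel:

```python
# e.g. a 4-track MixedCut becomes a 5-track MixedCut after mixing in noise
noisy_cut = cut.mix(musan_cut, snr=10)
print(len(cut.tracks), len(noisy_cut.tracks))    # 4 5
print(noisy_cut.load_audio(mixed=False).shape)   # (5, num_samples), not (4, ...)
```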
cuts = remove_pad_tracks(cuts)

# NOTE: what to do when the # of tracks is not the same across cuts, right now
# this is zero-padding but that seems bad ...
I think you won't escape zero-padding of examples with fewer channels if you need to collate the data. However, I suggest you modify this function to return a 3-tuple (audio, audio_lens, channel_indexes), where audio is the collated data with shape (B, C, T), audio_lens holds the length of each multi-channel example with shape (B,), and channel_indexes is a list of lists saying which C-dim indexes hold meaningful channels for each example (it could also be a channel_lens tensor of shape (B,), assuming the first c channels are always the meaningful ones, if that's possible to guarantee).
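A sketch of that suggested return signature, using the channel_lens variant; the function name and shapes are assumptions for illustration, not existing Lhotse API:

```python
from typing import List, Tuple
import torch

def collate_multi_channel(audios: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # each element of `audios` has shape (C_i, T_i)
    B = len(audios)
    C = max(a.shape[0] for a in audios)
    T = max(a.shape[1] for a in audios)
    audio = torch.zeros(B, C, T)
    audio_lens = torch.zeros(B, dtype=torch.int64)
    channel_lens = torch.zeros(B, dtype=torch.int64)
    for i, a in enumerate(audios):
        audio[i, : a.shape[0], : a.shape[1]] = a  # zero-pad both channels and time
        channel_lens[i] = a.shape[0]
        audio_lens[i] = a.shape[1]
    return audio, audio_lens, channel_lens
```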
But in the end your models will have to somehow deal with the non-meaningful channels anyway. As long as you're working on same-number-of-channels data, no need to overthink this.
assert all(isinstance(cut, MixedCut) for cut in cuts)

# TODO: how to ensure that each track is synced across batches? i.e. dim=1 is the track index
# and should correspond to the same mic across batches
You can ensure the tracks are sorted by some property; I imagine this is something very corpus specific and should be done by the user, not by the library.
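For example, a user-side helper could reorder tracks before collation; this sketch assumes each track's underlying cut carries a meaningful channel attribute and that padding tracks have already been removed:

```python
from lhotse.utils import fastcopy

def sort_tracks_by_channel(cut):
    # Re-create the MixedCut with tracks ordered by the mono cut's channel,
    # so dim=1 of the collated batch always maps to the same microphone.
    return fastcopy(cut, tracks=sorted(cut.tracks, key=lambda t: t.cut.channel))
```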
@@ -74,6 +74,41 @@ def _validate(self, cuts: CutSet) -> None:
        assert all(cut.has_recording for cut in cuts)


class UnsupervisedMultiChanWaveformDataset(UnsupervisedDataset):
Suggested change:
-class UnsupervisedMultiChanWaveformDataset(UnsupervisedDataset):
+class MultiChannelWaveformDataset(UnsupervisedDataset):
somehow reads better to me
"audio_lens": audio_lens, | ||
} | ||
else: | ||
return {"cuts": cuts, "audio": [c.load_audio(mixed=False) for c in cuts]} |
This line would again have the extra padding channels problem. This suggests that maybe the solution should not be (entirely) in the collate function, but inside load_audio, e.g. controlled by an extra argument?
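A purely hypothetical sketch of what such an argument could do, written here as a standalone helper rather than a change to Lhotse itself; it assumes all real tracks are already aligned and equally long:

```python
import numpy as np
from lhotse.cut import MixedCut, PaddingCut

def load_audio_without_padding_tracks(cut: MixedCut) -> np.ndarray:
    # Keep only the tracks backed by real recordings, skipping PaddingCuts,
    # then load each one separately -> (num_real_tracks, num_samples).
    real_tracks = [t for t in cut.tracks if not isinstance(t.cut, PaddingCut)]
    return np.vstack([t.cut.load_audio() for t in real_tracks])
```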
@@ -96,7 +96,7 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]
        Return a new batch, with the batch size automatically determined using the constraints
        of max_frames and max_cuts.
        """
-       validate_for_asr(cuts)
+       #validate_for_asr(cuts)
This should be uncommented
For anybody interested in this, here's some context from our earlier discussion with @kkarrancsu. I expect you to run into issues related to padding and MUSAN data augmentation with it. Basically, padding and augmentation create extra tracks in a MixedCut, and neither MixedCut nor collate_multi_channel_audio knows which tracks are the data and which tracks are the padding / noise. So, for 4-channel audio, you might end up with 6-channel output.
@kkarrancsu I have a different idea -- we could add a new attribute for this. We would need to extend the mixing logic (Lines 984 to 990 in b41e4f8) so that instead of simply vstacking the right channels, it vstacks only the "separate" channels, downmixes the remaining channels to mono, and adds them to each of the "separate" channels. The analogous operation would be needed in the corresponding places as well. Of course we'd need to add more unit tests to make sure this doesn't break anything and works as expected.
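A rough numpy illustration of that mixing behaviour; the function and variable names are made up, it assumes the per-track waveforms are already trimmed/padded to the same length, and it is not actual Lhotse code:

```python
import numpy as np

def mix_keeping_separate_channels(separate, others):
    # separate: list of "real" channel waveforms to keep apart, each of shape (T,)
    # others: padding / noise track waveforms to be downmixed, each of shape (T,)
    out = np.vstack(separate)                    # (C, T)
    if others:
        mono = np.vstack(others).sum(axis=0)     # downmix the extras to mono, (T,)
        out = out + mono                         # add the downmix to every channel
    return out
```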
It seems to me that the more "correct" way to do this would be, when adding noise to multi-channel audio, to add multiple channels of noise. I assume this would require some nontrivial simulation, possibly with multiple sources.
Good point... I am not sure implementing that on top of Lhotse is worth it, though; it might be better left to specialized tools. One such tool is e.g. https://github.com/asteroid-team/torch-audiomentations, but I just noticed that they are doing exactly the same simplified mono downmix I was thinking about. Another option is using https://github.com/LCAV/pyroomacoustics as a transform inside your PyTorch Dataset class, I think. In any case, we would still need to be able to handle the padding, so I think the solution I suggested above still applies.
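A very rough sketch of the pyroomacoustics option; the room geometry, mic layout and gain are invented for illustration, and the exact API usage should be checked against the pyroomacoustics docs:

```python
import numpy as np
import pyroomacoustics as pra

def add_simulated_noise(audio: np.ndarray, noise: np.ndarray, sr: int = 16000) -> np.ndarray:
    # audio: (C, T) multi-channel waveform; noise: (T_noise,) mono noise signal
    C, T = audio.shape
    room = pra.ShoeBox([5.0, 4.0, 3.0], fs=sr, max_order=10)
    mic_positions = np.array([[2.0 + 0.05 * c, 2.0, 1.5] for c in range(C)]).T  # (3, C)
    room.add_microphone_array(pra.MicrophoneArray(mic_positions, fs=sr))
    room.add_source([1.0, 1.0, 1.5], signal=noise)
    room.simulate()
    sim = room.mic_array.signals[:, :T]          # simulated noise at the C mics
    if sim.shape[1] < T:
        sim = np.pad(sim, ((0, 0), (0, T - sim.shape[1])))
    return audio + 0.1 * sim                     # arbitrary gain, illustration only
```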
OK, sure, it was just a thought.
Added a Dataset which supports multi-channel audio samples. Updated the collator to drop pad tracks.