diff --git a/tests/test_band_pass_filter.py b/tests/test_band_pass_filter.py index 03dfe59b..bc617a07 100644 --- a/tests/test_band_pass_filter.py +++ b/tests/test_band_pass_filter.py @@ -1,12 +1,10 @@ -import unittest - import numpy as np import torch from torch_audiomentations import BandPassFilter -class TestBandPassFilter(unittest.TestCase): +class TestBandPassFilter: def test_band_pass_filter(self): samples = np.array( [ @@ -23,5 +21,5 @@ def test_band_pass_filter(self): processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate ).samples.numpy() - self.assertEqual(processed_samples.shape, samples.shape) - self.assertEqual(processed_samples.dtype, np.float32) + assert processed_samples.shape == samples.shape + assert processed_samples.dtype == np.float32 diff --git a/tests/test_band_stop_filter.py b/tests/test_band_stop_filter.py index 5eb32a99..63ca2e99 100644 --- a/tests/test_band_stop_filter.py +++ b/tests/test_band_stop_filter.py @@ -1,12 +1,10 @@ -import unittest - import numpy as np import torch from torch_audiomentations import BandStopFilter -class TestBandStopFilter(unittest.TestCase): +class TestBandStopFilter: def test_band_reject_filter(self): samples = np.array( [ @@ -22,5 +20,5 @@ def test_band_reject_filter(self): processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate ).samples.numpy() - self.assertEqual(processed_samples.shape, samples.shape) - self.assertEqual(processed_samples.dtype, np.float32) + assert processed_samples.shape == samples.shape + assert processed_samples.dtype == np.float32 diff --git a/tests/test_convolution.py b/tests/test_convolution.py index 8cbb810b..b83ad393 100644 --- a/tests/test_convolution.py +++ b/tests/test_convolution.py @@ -1,5 +1,3 @@ -import unittest - import librosa import torch from numpy.testing import assert_almost_equal @@ -9,7 +7,7 @@ from torch_audiomentations.utils.convolution import convolve as torch_convolve -class TestConvolution(unittest.TestCase): +class TestConvolution: def test_convolve(self): sample_rate = 16000 diff --git a/tests/test_file_utils.py b/tests/test_file_utils.py index c8c18f2c..bf373e34 100644 --- a/tests/test_file_utils.py +++ b/tests/test_file_utils.py @@ -1,7 +1,6 @@ import os import shutil import tempfile -import unittest import uuid from pathlib import Path diff --git a/tests/test_low_pass_filter.py b/tests/test_low_pass_filter.py index 9bb5559e..a6a6c855 100644 --- a/tests/test_low_pass_filter.py +++ b/tests/test_low_pass_filter.py @@ -1,12 +1,10 @@ -import unittest - import numpy as np import torch from torch_audiomentations import LowPassFilter -class TestLowPassFilter(unittest.TestCase): +class TestLowPassFilter: def test_low_pass_filter(self): samples = np.array( [ @@ -22,5 +20,5 @@ def test_low_pass_filter(self): processed_samples = augment( samples=torch.from_numpy(samples), sample_rate=sample_rate ).samples.numpy() - self.assertEqual(processed_samples.shape, samples.shape) - self.assertEqual(processed_samples.dtype, np.float32) + assert processed_samples.shape == samples.shape + assert processed_samples.dtype == np.float32