Skip to content

Commit

Permalink
Take a few steps away from unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
iver56 committed Apr 1, 2022
1 parent 7bc37e5 commit 6e7c3b3
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 19 deletions.
8 changes: 3 additions & 5 deletions tests/test_band_pass_filter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import unittest

import numpy as np
import torch

from torch_audiomentations import BandPassFilter


class TestBandPassFilter(unittest.TestCase):
class TestBandPassFilter:
def test_band_pass_filter(self):
samples = np.array(
[
Expand All @@ -23,5 +21,5 @@ def test_band_pass_filter(self):
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).samples.numpy()
self.assertEqual(processed_samples.shape, samples.shape)
self.assertEqual(processed_samples.dtype, np.float32)
assert processed_samples.shape == samples.shape
assert processed_samples.dtype == np.float32
8 changes: 3 additions & 5 deletions tests/test_band_stop_filter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import unittest

import numpy as np
import torch

from torch_audiomentations import BandStopFilter


class TestBandStopFilter(unittest.TestCase):
class TestBandStopFilter:
def test_band_reject_filter(self):
samples = np.array(
[
Expand All @@ -22,5 +20,5 @@ def test_band_reject_filter(self):
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).samples.numpy()
self.assertEqual(processed_samples.shape, samples.shape)
self.assertEqual(processed_samples.dtype, np.float32)
assert processed_samples.shape == samples.shape
assert processed_samples.dtype == np.float32
4 changes: 1 addition & 3 deletions tests/test_convolution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import unittest

import librosa
import torch
from numpy.testing import assert_almost_equal
Expand All @@ -9,7 +7,7 @@
from torch_audiomentations.utils.convolution import convolve as torch_convolve


class TestConvolution(unittest.TestCase):
class TestConvolution:
def test_convolve(self):
sample_rate = 16000

Expand Down
1 change: 0 additions & 1 deletion tests/test_file_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import shutil
import tempfile
import unittest
import uuid
from pathlib import Path

Expand Down
8 changes: 3 additions & 5 deletions tests/test_low_pass_filter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import unittest

import numpy as np
import torch

from torch_audiomentations import LowPassFilter


class TestLowPassFilter(unittest.TestCase):
class TestLowPassFilter:
def test_low_pass_filter(self):
samples = np.array(
[
Expand All @@ -22,5 +20,5 @@ def test_low_pass_filter(self):
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
).samples.numpy()
self.assertEqual(processed_samples.shape, samples.shape)
self.assertEqual(processed_samples.dtype, np.float32)
assert processed_samples.shape == samples.shape
assert processed_samples.dtype == np.float32

0 comments on commit 6e7c3b3

Please sign in to comment.