Skip to content

Commit

Permalink
Add test for csv_to_wfdb().
Browse files Browse the repository at this point in the history
  • Loading branch information
tompollard committed Jul 9, 2024
1 parent 17b9349 commit 626307a
Showing 1 changed file with 68 additions and 3 deletions.
71 changes: 68 additions & 3 deletions tests/io/test_convert.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import os
import shutil
import unittest

import numpy as np

from wfdb.io.record import rdrecord
from wfdb.io.convert.edf import read_edf
from wfdb.io.convert.csv import csv_to_wfdb


class TestConvert:
class TestEdfToWfdb:
"""
Tests for the io.convert.edf module.
"""
def test_edf_uniform(self):
"""
EDF format conversion to MIT for uniform sample rates.
"""
# Uniform sample rates
record_MIT = rdrecord("sample-data/n16").__dict__
Expand Down Expand Up @@ -60,7 +67,6 @@ def test_edf_uniform(self):
def test_edf_non_uniform(self):
"""
EDF format conversion to MIT for non-uniform sample rates.
"""
# Non-uniform sample rates
record_MIT = rdrecord("sample-data/wave_4").__dict__
Expand Down Expand Up @@ -108,3 +114,62 @@ def test_edf_non_uniform(self):

target_results = len(fields) * [True]
assert np.array_equal(test_results, target_results)


class TestCsvToWfdb(unittest.TestCase):
"""
Tests for the io.convert.csv module.
"""
def setUp(self):
"""
Create a temporary directory containing data for testing.
Load 100.dat file for comparison to 100.csv file.
"""
self.test_dir = 'test_output'
os.makedirs(self.test_dir, exist_ok=True)

self.record_100_csv = 'sample-data/100.csv'
self.record_100_dat = rdrecord('sample-data/100', physical=True)

def tearDown(self):
"""
Remove the temporary directory after the test.
"""
if os.path.exists(self.test_dir):
shutil.rmtree(self.test_dir)

def test_write_dir(self):
"""
Call the function with the write_dir argument.
"""
csv_to_wfdb(
file_name=self.record_100_csv,
fs=360,
units='mV',
write_dir=self.test_dir
)

# Check if the output files are created in the specified directory
base_name = os.path.splitext(os.path.basename(self.record_100_csv))[0]
expected_dat_file = os.path.join(self.test_dir, f'{base_name}.dat')
expected_hea_file = os.path.join(self.test_dir, f'{base_name}.hea')

self.assertTrue(os.path.exists(expected_dat_file))
self.assertTrue(os.path.exists(expected_hea_file))

# Check that newly written file matches the 100.dat file
record_write = rdrecord(os.path.join(self.test_dir, base_name))

self.assertEqual(record_write.fs, 360)
self.assertEqual(record_write.fs, self.record_100_dat.fs)
self.assertEqual(record_write.units, ['mV', 'mV'])
self.assertEqual(record_write.units, self.record_100_dat.units)
self.assertEqual(record_write.sig_name, ['MLII', 'V5'])
self.assertEqual(record_write.sig_name, self.record_100_dat.sig_name)
self.assertEqual(record_write.p_signal.size, 1300000)
self.assertEqual(record_write.p_signal.size, self.record_100_dat.p_signal.size)


if __name__ == '__main__':
unittest.main()

0 comments on commit 626307a

Please sign in to comment.