Merge pull request nf-core#145 from mathysgrapotte/transform_all_data
[feat] transform the entire data
suzannejin authored May 14, 2024
2 parents 505a3ee + 85bbc2a commit 195b6af
Showing 3 changed files with 21 additions and 49 deletions.
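
For orientation, a minimal sketch of how the updated CsvProcessing.transform entry point is driven. The dictionary keys (column_name, name, params) come from the csv.py hunk below; the constructor arguments, the CSV path, and the masker parameters are illustrative assumptions, not part of this diff.

    from bin.src.data.csv import CsvProcessing
    from bin.src.data.experiments import DnaToFloatExperiment

    # Hypothetical set-up: the CsvProcessing constructor arguments and the CSV path are assumed here.
    csv_processing = CsvProcessing(DnaToFloatExperiment(), 'path/to/test.csv')

    # transform() reads 'column_name', 'name' and 'params' from each entry (see csv.py below);
    # the masker parameters are illustrative.
    csv_processing.transform([
        {'column_name': 'hello:input:dna', 'name': 'UniformTextMasker', 'params': {'probability': 0.1}},
    ])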
44 changes: 11 additions & 33 deletions bin/src/data/csv.py
@@ -129,8 +129,7 @@ def add_split(self, config: dict, force=False) -> None:
             split_column[validation] = 1
             split_column[test] = 2
         self.data = self.data.with_columns(pl.Series('split:split:int', split_column))
-        self.update_categories()
-
+        self.update_categories()

     def transform(self, transformations: list) -> None:
         """
@@ -140,40 +139,19 @@ def transform(self, transformations: list) -> None:
             key = dictionary['column_name']
             data_type = key.split(':')[2]
             data_transformer = dictionary['name']
-            transfomer = self.experiment.get_data_transformer(data_type, data_transformer)
-
-            # If the transformer is only for training data, we need to separate the data
-            # and transform only the training data
-            if transfomer.training_data_only:
-                split_colname = self.get_keys_from_header(self.data.columns, category='split')
-                data_to_transform = self.data.filter(pl.col(split_colname) == 0)
-                untransformed_data = self.data.filter(pl.col(split_colname) != 0)
-            else:
-                data_to_transform = self.data
+            transformer = self.experiment.get_data_transformer(data_type, data_transformer)

-            # Transform the data
-            new_data = transfomer.transform_all(list(data_to_transform[key]), **dictionary['params'])
+            # transform the data
+            new_data = transformer.transform_all(list(self.data[key]), **dictionary['params'])

-            # Add the transformed data to the dataframe
-
-            # If the transformer modifies the column, we need to replace the column
-            if transfomer.add_row:
-                new_rows = data_to_transform.with_columns(pl.Series(key, new_data))
+            # if the transformation creates new rows (eg. data augmentation), then add the new rows to the original data
+            # otherwise just get the transformation of the data
+            if transformer.add_row:
+                new_rows = self.data.with_columns(pl.Series(key, new_data))
                 self.data = self.data.vstack(new_rows)
-            else:
-                transformed_data = data_to_transform.with_columns(pl.Series(key, new_data))
-                # make sure the column has the same type as the new data
-                # this is necessary because the transformer could change the type of the column (e.g. from int to float)
-                transformed_data_type = str(transformed_data[key].dtype)
-                untransformed_data = untransformed_data.with_columns(pl.col(key).cast(getattr(pl, transformed_data_type)))
-
-                # If the transformer is only for training data, we need to concatenate the transformed data with the untransformed data
-                if transfomer.training_data_only:
-                    self.data = transformed_data.vstack(untransformed_data)
-                else:
-                    self.data = transformed_data
-
-
+            else:
+                self.data = self.data.with_columns(pl.Series(key, new_data))
+
     def shuffle_labels(self) -> None:
         """
         Shuffles the labels in the data.
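
The heart of the change: transform_all now runs over the whole dataframe, and the only branching left is on add_row. A self-contained polars sketch of the two branches, with made-up column values (an illustration of the behaviour above, not code from the repository):

    import polars as pl

    # Toy frame standing in for self.data; the values are invented for this sketch.
    data = pl.DataFrame({'hello:input:dna': ['ACTG', 'GGCC'], 'split:split:int': [0, 1]})
    new_data = ['ACNG', 'GGNC']  # pretend output of transformer.transform_all(...)

    add_row = False  # noise generators set add_row = False, augmentation generators set it to True
    if add_row:
        # augmentation: keep the original rows and stack the transformed copies underneath
        new_rows = data.with_columns(pl.Series('hello:input:dna', new_data))
        data = data.vstack(new_rows)
    else:
        # noise: overwrite the column in place for every row, regardless of split
        data = data.with_columns(pl.Series('hello:input:dna', new_data))

    print(data)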
7 changes: 1 addition & 6 deletions bin/src/data/transform/data_transformation_generators.py
@@ -16,7 +16,6 @@ class AbstractDataTransformer(ABC):

     def __init__(self):
         self.add_row = None
-        self.training_data_only = False

     @abstractmethod
     def transform(self, data: Any, seed: float = None) -> Any:
@@ -43,9 +42,7 @@ class AbstractNoiseGenerator(AbstractDataTransformer):

     def __init__(self):
         super().__init__()
-        self.add_row = False
-        self.training_data_only = True
-
+        self.add_row = False

 class AbstractAugmentationGenerator(AbstractDataTransformer):
     """
@@ -56,8 +53,6 @@ class AbstractAugmentationGenerator(AbstractDataTransformer):
     def __init__(self):
         super().__init__()
         self.add_row = True
-        self.training_data_only = False
-

 class UniformTextMasker(AbstractNoiseGenerator):
     """
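
With training_data_only gone, a generator only has to declare add_row and provide transform/transform_all. The hypothetical, stand-alone generator below sketches that contract; it does not subclass the real abstract classes, and the transform_all signature is assumed from the call site in csv.py.

    from typing import Any

    class LowercaseNoiser:
        """Toy noise generator following the post-change contract (illustrative only)."""

        def __init__(self):
            self.add_row = False  # noise overwrites the column in place; augmentation would set True

        def transform(self, data: str, seed: float = None) -> str:
            return data.lower()

        def transform_all(self, data: list, **params: Any) -> list:
            # csv.py calls transform_all(list(self.data[key]), **dictionary['params'])
            return [self.transform(item, **params) for item in data]

    print(LowercaseNoiser().transform_all(['ACTG', 'GGCC']))  # ['actg', 'ggcc']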
19 changes: 9 additions & 10 deletions bin/tests/test_csv.py
@@ -8,7 +8,6 @@
 sys.path.append('./')
 from bin.src.data.csv import CsvProcessing, CsvLoader
 from bin.src.data.experiments import DnaToFloatExperiment,ProtDnaToFloatExperiment
-from bin.src.data.experiments import DnaToFloatExperiment, ProtDnaToFloatExperiment

 class AbstractTestCsvProcessing(unittest.TestCase):
     """
@@ -82,10 +81,10 @@ def test_split_and_noise(self):
         self._transform()
         self.data_length = self.data_length * 2
         self._test_len()
-        self._test_all_values_in_column('pet:meta:str', ['dog', 'cat', 'dog','cat'])
-        self._test_all_values_in_column('hola:label:float', [12.676405, 12.0, 12.676405, 12.0])
-        self._test_all_values_in_column('hello:input:dna', ['ACTGACTGATCGATNN', 'ACTGACTGATCGATGC', 'NNATCGATCAGTCAGT', 'GCATCGATCAGTCAGT'])
-        self._test_all_values_in_column('split:split:int', [0, 1, 0, 1])
+        self._test_all_values_in_column('pet:meta:str', ['cat', 'dog', 'cat','dog'])
+        self._test_all_values_in_column('hola:label:float', [12.676405, 12.540016, 12.676405, 12.540016])
+        self._test_all_values_in_column('hello:input:dna', ['ACTGACTGATCGATNN', 'ACTGACTGATCGATNN', 'NNATCGATCAGTCAGT', 'NNATCGATCAGTCAGT'])
+        self._test_all_values_in_column('split:split:int', [1, 0, 1, 0])

     def test_shuffle_labels(self):
         # initialize seed to 42 to make the test reproducible
@@ -119,11 +118,11 @@ def test_split_and_noise(self):
         self._transform()
         self.data_length = self.data_length * 2
         self._test_len()
-        self._test_all_values_in_column('pet:meta:str', ['dog', 'cat', 'dog','cat'])
-        self._test_all_values_in_column('hola:label:float', [12.676405, 12.0, 12.676405, 12.0])
-        self._test_all_values_in_column('hello:input:dna', ['ACTGACTGATCGATNN', 'ACTGACTGATCGATGC', 'NNATCGATCAGTCAGT', 'GCATCGATCAGTCAGT'])
-        self._test_all_values_in_column('split:split:int', [0, 1, 0,1])
-        self._test_all_values_in_column('bonjour:input:prot', ['GPRTTIKAKQLETLX', 'GPRTTIKAKQLETLK', 'GPRTTIKAKQLETLX', 'GPRTTIKAKQLETLK'])
+        self._test_all_values_in_column('pet:meta:str', ['cat', 'dog', 'cat','dog'])
+        self._test_all_values_in_column('hola:label:float', [12.676405, 12.540016, 12.676405, 12.540016])
+        self._test_all_values_in_column('hello:input:dna', ['ACTGACTGATCGATNN', 'ACTGACTGATCGATNN', 'NNATCGATCAGTCAGT', 'NNATCGATCAGTCAGT'])
+        self._test_all_values_in_column('split:split:int', [1,0,1,0])
+        self._test_all_values_in_column('bonjour:input:prot', ['GPRTTIKAKQLETLX', 'GPRTTIKAKQLETLX', 'GPRTTIKAKQLETLX', 'GPRTTIKAKQLETLX'])

 class AbstractTestCsvLoader(unittest.TestCase):
     """
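
The updated expectations follow directly from the change: the masker now hits every row (so both original DNA entries end in the 'NN' mask and every protein entry ends in 'X'), while the augmentation step still doubles the row count. The second pair of DNA values is consistent with a reverse-complement augmentation; a stand-alone snippet showing where 'NNATCGATCAGTCAGT' comes from (this function is written for this note, it is not repository code):

    COMPLEMENT = str.maketrans('ACGTN', 'TGCAN')

    def reverse_complement(seq: str) -> str:
        return seq.translate(COMPLEMENT)[::-1]

    masked = 'ACTGACTGATCGATNN'          # both rows after masking, since noise now applies to every split
    print(reverse_complement(masked))    # NNATCGATCAGTCAGT, the value expected in the appended rows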
