From b0b156e22ffc35fd04e9ce818a1315644a3fd6d6 Mon Sep 17 00:00:00 2001 From: raviolli Date: Tue, 17 Oct 2017 16:34:16 -0400 Subject: [PATCH] Update transformers.py --- distkeras/transformers.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/distkeras/transformers.py b/distkeras/transformers.py index 4a9490a..ae7cb3d 100644 --- a/distkeras/transformers.py +++ b/distkeras/transformers.py @@ -283,9 +283,15 @@ def _transform(self, row): Only for internal use. """ label = row[self.input_column] - vector = to_one_hot_encoded_dense(label, self.output_dimensionality) - new_row = new_dataframe_row(row, self.output_column, vector.tolist()) + if(isinstance(label, types.ListType)): + vector = np.zeros((len(label), self.output_dimensionality)) + for i in range(len(label)): + vector[i] = to_one_hot_encoded_dense(label[i], self.output_dimensionality) + else: + vector = to_one_hot_encoded_dense(label, self.output_dimensionality) + + new_row = new_dataframe_row(row, self.output_column, vector.tolist()) return new_row def transform(self, dataframe):