diff --git a/supervised/preprocessing/text_transformer.py b/supervised/preprocessing/text_transformer.py index ceda959c..b8b3be46 100644 --- a/supervised/preprocessing/text_transformer.py +++ b/supervised/preprocessing/text_transformer.py @@ -1,5 +1,5 @@ import warnings - +import numpy as np import pandas as pd from sklearn.feature_extraction.text import TfidfVectorizer @@ -70,4 +70,4 @@ def from_json(self, data_json): ) self._vectorizer.vocabulary_ = vocabulary self._vectorizer.fixed_vocabulary_ = fixed_vocabulary - self._vectorizer.idf_ = idf + self._vectorizer.idf_ = np.array(idf) diff --git a/tests/tests_preprocessing/test_text_transformer.py b/tests/tests_preprocessing/test_text_transformer.py index 25191024..5ee385ea 100644 --- a/tests/tests_preprocessing/test_text_transformer.py +++ b/tests/tests_preprocessing/test_text_transformer.py @@ -23,6 +23,7 @@ def test_transformer(self): transf = TextTransformer() transf.fit(df, "col1") df = transf.transform(df) + self.assertTrue(df.shape[0] == 5) self.assertTrue("col1" not in df.columns)