From d06ee680ced04e694dd6fbb7efe2a999d5ab7a42 Mon Sep 17 00:00:00 2001 From: Piotr Date: Tue, 9 Jul 2024 14:35:23 +0200 Subject: [PATCH] fix text vectorizer from_json --- supervised/preprocessing/text_transformer.py | 4 ++-- tests/tests_preprocessing/test_text_transformer.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/supervised/preprocessing/text_transformer.py b/supervised/preprocessing/text_transformer.py index ceda959ca..b8b3be462 100644 --- a/supervised/preprocessing/text_transformer.py +++ b/supervised/preprocessing/text_transformer.py @@ -1,5 +1,5 @@ import warnings - +import numpy as np import pandas as pd from sklearn.feature_extraction.text import TfidfVectorizer @@ -70,4 +70,4 @@ def from_json(self, data_json): ) self._vectorizer.vocabulary_ = vocabulary self._vectorizer.fixed_vocabulary_ = fixed_vocabulary - self._vectorizer.idf_ = idf + self._vectorizer.idf_ = np.array(idf) diff --git a/tests/tests_preprocessing/test_text_transformer.py b/tests/tests_preprocessing/test_text_transformer.py index 25191024e..5ee385ea3 100644 --- a/tests/tests_preprocessing/test_text_transformer.py +++ b/tests/tests_preprocessing/test_text_transformer.py @@ -23,6 +23,7 @@ def test_transformer(self): transf = TextTransformer() transf.fit(df, "col1") df = transf.transform(df) + self.assertTrue(df.shape[0] == 5) self.assertTrue("col1" not in df.columns)