diff --git a/inference_schema/parameter_types/pandas_parameter_type.py b/inference_schema/parameter_types/pandas_parameter_type.py index c22b441..9b40e19 100644 --- a/inference_schema/parameter_types/pandas_parameter_type.py +++ b/inference_schema/parameter_types/pandas_parameter_type.py @@ -7,6 +7,7 @@ from .abstract_parameter_type import AbstractParameterType from ._util import get_swagger_for_list, get_swagger_for_nested_dict from ._constants import SWAGGER_FORMAT_CONSTANTS +from io import StringIO from warnings import warn @@ -78,7 +79,8 @@ def deserialize_input(self, input_data): if not isinstance(input_data, list) and not isinstance(input_data, dict): raise Exception("Error, unable to convert input of type {} into Pandas Dataframe".format(type(input_data))) - data_frame = pd.read_json(json.dumps(input_data), orient=self.orient) + string_stream = StringIO(json.dumps(input_data)) + data_frame = pd.read_json(string_stream, orient=self.orient) if self.apply_column_names: data_frame.columns = self.sample_input.columns.copy() diff --git a/tests/conftest.py b/tests/conftest.py index bae83d9..36e00a4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -63,6 +63,12 @@ def pandas_sample_input_int_column_labels(): return pd.DataFrame(data=pandas_input_data) +@pytest.fixture(scope="session") +def pandas_sample_input_with_url(): + pandas_input_data = {'state': ['WA'], 'website': ['http://wa.website.foo']} + return pd.DataFrame(data=pandas_input_data) + + @pytest.fixture(scope="session") def decorated_pandas_func(pandas_sample_input, pandas_sample_output): @input_schema('param', PandasParameterType(pandas_sample_input)) @@ -122,7 +128,6 @@ def pandas_split_orient_func(param): @pytest.fixture(scope="session") def decorated_pandas_func_int_column_labels(pandas_sample_input_int_column_labels): - @input_schema('param', PandasParameterType(pandas_sample_input_int_column_labels)) def pandas_int_column_labels_func(param): """ @@ -139,6 +144,23 @@ def pandas_int_column_labels_func(param): return pandas_int_column_labels_func +@pytest.fixture(scope="session") +def decorated_pandas_uri_func(pandas_sample_input_with_url): + @input_schema('param', PandasParameterType(pandas_sample_input_with_url)) + def pandas_url_func(param): + """ + + :param param: + :type param: pd.DataFrame + :return: + :rtype: string + """ + assert type(param) is pd.DataFrame + return param['website'][0] + + return pandas_url_func + + @pytest.fixture(scope="session") def decorated_spark_func(): spark_session = SparkSession.builder.config('spark.driver.host', '127.0.0.1').getOrCreate() diff --git a/tests/test_pandas_parameter_type.py b/tests/test_pandas_parameter_type.py index 56d9324..0f7b0da 100644 --- a/tests/test_pandas_parameter_type.py +++ b/tests/test_pandas_parameter_type.py @@ -31,6 +31,10 @@ def test_pandas_handling(self, decorated_pandas_func): assert '3.0' in version_list assert '3.1' in version_list + pandas_input = {'state': ['WA'], 'url': ['http://fakeurl.com']} + result = decorated_pandas_func(pandas_input) + assert_frame_equal(result, state) + def test_pandas_orient_handling(self, decorated_pandas_func_split_orient): pandas_input = {"columns": ["name", "state"], "index": [0], "data": [["Sarah", "WA"]]} state = pd.DataFrame(pd.read_json(json.dumps(pandas_input), orient='split')['state']) @@ -52,6 +56,17 @@ def test_pandas_int_column_labels(self, decorated_pandas_func_int_column_labels, result = decorated_pandas_func_int_column_labels(input) assert_frame_equal(result, pandas_sample_input_int_column_labels) + def test_pandas_url_handling(self, decorated_pandas_uri_func): + pandas_input = {'state': ['WA'], 'website': ['http://wa.website.foo']} + website = pandas_input['website'][0] + result = decorated_pandas_uri_func(pandas_input) + assert website == result + + pandas_input = {'state': ['WA'], 'website': ['This is an embedded url: http://wa.website.foo']} + website = pandas_input['website'][0] + result = decorated_pandas_uri_func(pandas_input) + assert website == result + class TestNestedType(object):