Skip to content

Commit

Permalink
Convert Pandas json input to StringIO to avoid bug (#67)
Browse files Browse the repository at this point in the history
Working around bug in Pandas with urls in the json string
  • Loading branch information
trangevi authored Jul 12, 2022
1 parent c0d5042 commit 5d95ac8
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 2 deletions.
4 changes: 3 additions & 1 deletion inference_schema/parameter_types/pandas_parameter_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .abstract_parameter_type import AbstractParameterType
from ._util import get_swagger_for_list, get_swagger_for_nested_dict
from ._constants import SWAGGER_FORMAT_CONSTANTS
from io import StringIO
from warnings import warn


Expand Down Expand Up @@ -78,7 +79,8 @@ def deserialize_input(self, input_data):
if not isinstance(input_data, list) and not isinstance(input_data, dict):
raise Exception("Error, unable to convert input of type {} into Pandas Dataframe".format(type(input_data)))

data_frame = pd.read_json(json.dumps(input_data), orient=self.orient)
string_stream = StringIO(json.dumps(input_data))
data_frame = pd.read_json(string_stream, orient=self.orient)

if self.apply_column_names:
data_frame.columns = self.sample_input.columns.copy()
Expand Down
24 changes: 23 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ def pandas_sample_input_int_column_labels():
return pd.DataFrame(data=pandas_input_data)


@pytest.fixture(scope="session")
def pandas_sample_input_with_url():
pandas_input_data = {'state': ['WA'], 'website': ['http://wa.website.foo']}
return pd.DataFrame(data=pandas_input_data)


@pytest.fixture(scope="session")
def decorated_pandas_func(pandas_sample_input, pandas_sample_output):
@input_schema('param', PandasParameterType(pandas_sample_input))
Expand Down Expand Up @@ -122,7 +128,6 @@ def pandas_split_orient_func(param):

@pytest.fixture(scope="session")
def decorated_pandas_func_int_column_labels(pandas_sample_input_int_column_labels):

@input_schema('param', PandasParameterType(pandas_sample_input_int_column_labels))
def pandas_int_column_labels_func(param):
"""
Expand All @@ -139,6 +144,23 @@ def pandas_int_column_labels_func(param):
return pandas_int_column_labels_func


@pytest.fixture(scope="session")
def decorated_pandas_uri_func(pandas_sample_input_with_url):
@input_schema('param', PandasParameterType(pandas_sample_input_with_url))
def pandas_url_func(param):
"""
:param param:
:type param: pd.DataFrame
:return:
:rtype: string
"""
assert type(param) is pd.DataFrame
return param['website'][0]

return pandas_url_func


@pytest.fixture(scope="session")
def decorated_spark_func():
spark_session = SparkSession.builder.config('spark.driver.host', '127.0.0.1').getOrCreate()
Expand Down
15 changes: 15 additions & 0 deletions tests/test_pandas_parameter_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def test_pandas_handling(self, decorated_pandas_func):
assert '3.0' in version_list
assert '3.1' in version_list

pandas_input = {'state': ['WA'], 'url': ['http://fakeurl.com']}
result = decorated_pandas_func(pandas_input)
assert_frame_equal(result, state)

def test_pandas_orient_handling(self, decorated_pandas_func_split_orient):
pandas_input = {"columns": ["name", "state"], "index": [0], "data": [["Sarah", "WA"]]}
state = pd.DataFrame(pd.read_json(json.dumps(pandas_input), orient='split')['state'])
Expand All @@ -52,6 +56,17 @@ def test_pandas_int_column_labels(self, decorated_pandas_func_int_column_labels,
result = decorated_pandas_func_int_column_labels(input)
assert_frame_equal(result, pandas_sample_input_int_column_labels)

def test_pandas_url_handling(self, decorated_pandas_uri_func):
pandas_input = {'state': ['WA'], 'website': ['http://wa.website.foo']}
website = pandas_input['website'][0]
result = decorated_pandas_uri_func(pandas_input)
assert website == result

pandas_input = {'state': ['WA'], 'website': ['This is an embedded url: http://wa.website.foo']}
website = pandas_input['website'][0]
result = decorated_pandas_uri_func(pandas_input)
assert website == result


class TestNestedType(object):

Expand Down

0 comments on commit 5d95ac8

Please sign in to comment.