Skip to content

Commit

Permalink
bugfix: FeatherFileProcessor set wrong index (#359)
Browse files Browse the repository at this point in the history
* fix: don't set index_name when index_name is 'None'

* chore: feather doesn't allow index name to be 'None'
  • Loading branch information
kitagry authored Mar 22, 2024
1 parent 005c92f commit 837742d
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
5 changes: 4 additions & 1 deletion gokart/file_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,9 @@ def load(self, file):
index_columns = [col_name for col_name in loaded_df.columns[::-1] if col_name[: len(self.INDEX_COLUMN_PREFIX)] == self.INDEX_COLUMN_PREFIX]
index_column = index_columns[0]
index_name = index_column[len(self.INDEX_COLUMN_PREFIX) :]
loaded_df.index = pd.Index(loaded_df[index_column], name=index_name)
if index_name == 'None':
index_name = None
loaded_df.index = pd.Index(loaded_df[index_column].values, name=index_name)
loaded_df = loaded_df.drop(columns={index_column})

return loaded_df
Expand All @@ -245,6 +247,7 @@ def dump(self, obj, file):
index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}'
assert index_column_name not in dump_obj.columns, f'column name {index_column_name} already exists in dump_obj. \
Consider not saving index by setting store_index_in_feather=False.'
assert dump_obj.index.name != 'None', 'index name is "None", which is not allowed in gokart. Consider setting another index name.'

dump_obj[index_column_name] = dump_obj.index
dump_obj = dump_obj.reset_index(drop=True)
Expand Down
48 changes: 47 additions & 1 deletion test/test_file_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd
from luigi import LocalTarget

from gokart.file_processor import CsvFileProcessor
from gokart.file_processor import CsvFileProcessor, FeatherFileProcessor


class TestCsvFileProcessor(unittest.TestCase):
Expand Down Expand Up @@ -65,3 +65,49 @@ def test_load_csv_with_cp932(self):
# read with cp932 to check if the file is dumped with cp932
loaded_df = processor.load(f)
pd.testing.assert_frame_equal(df, loaded_df)


class TestFeatherFileProcessor(unittest.TestCase):
    """Dump/load checks for FeatherFileProcessor with store_index_in_feather=True."""

    def test_feather_should_return_same_dataframe(self):
        # A frame built without an explicit index must come back equal after dump -> load.
        expected = pd.DataFrame({'a': [1]})
        feather = FeatherFileProcessor(store_index_in_feather=True)

        with tempfile.TemporaryDirectory() as workdir:
            target = LocalTarget(path=f'{workdir}/temp.feather', format=feather.format())

            with target.open('w') as sink:
                feather.dump(expected, sink)

            with target.open('r') as source:
                actual = feather.load(source)

            pd.testing.assert_frame_equal(expected, actual)

    def test_feather_should_save_index_name(self):
        # A named index must be restored with the same name and values.
        named_index = pd.Index([1], name='index_name')
        expected = pd.DataFrame({'a': [1]}, index=named_index)
        feather = FeatherFileProcessor(store_index_in_feather=True)

        with tempfile.TemporaryDirectory() as workdir:
            target = LocalTarget(path=f'{workdir}/temp.feather', format=feather.format())

            with target.open('w') as sink:
                feather.dump(expected, sink)

            with target.open('r') as source:
                actual = feather.load(source)

            pd.testing.assert_frame_equal(expected, actual)

    def test_feather_should_raise_error_index_name_is_None(self):
        # An index whose name is the literal string 'None' must be rejected by dump().
        bad_index = pd.Index([1], name='None')
        frame = pd.DataFrame({'a': [1]}, index=bad_index)
        feather = FeatherFileProcessor(store_index_in_feather=True)

        with tempfile.TemporaryDirectory() as workdir:
            target = LocalTarget(path=f'{workdir}/temp.feather', format=feather.format())

            with target.open('w') as sink:
                # Callable form: assertRaises invokes feather.dump(frame, sink) itself.
                self.assertRaises(AssertionError, feather.dump, frame, sink)

0 comments on commit 837742d

Please sign in to comment.