diff --git a/exetera/core/dataframe.py b/exetera/core/dataframe.py index a99af88..e736f17 100644 --- a/exetera/core/dataframe.py +++ b/exetera/core/dataframe.py @@ -998,8 +998,8 @@ def _write_groupby_keys(self, ddf: DataFrame, write_keys=True): Write groupby keys to ddf only if write_key = True """ if write_keys: - by_fields = np.asarray([self._columns[k] for k in self._by]) - for field in by_fields: + for k in self._by: + field = self._columns[k] newfld = field.create_like(ddf, field.name) if self._sorted_index is not None: diff --git a/exetera/core/fields.py b/exetera/core/fields.py index 0448091..05bc6fd 100644 --- a/exetera/core/fields.py +++ b/exetera/core/fields.py @@ -127,7 +127,7 @@ def get_spans(self): """ raise NotImplementedError("Please use get_spans() on specific fields, not the field base class.") - def apply_filter(self, filter_to_apply, dstfld=None): + def apply_filter(self, filter_to_apply, target=None, in_place=False): """ Apply filter on the field. """ @@ -143,6 +143,33 @@ def _ensure_valid(self): if not self._valid_reference: raise ValueError("This field no longer refers to a valid underlying field object") + def __getitem__(self, item:Union[list, tuple, np.ndarray]): + if isinstance(item, slice): + data = self.data[item] + memfield = self.create_like() + memfield.data.write(data) + return memfield + + elif isinstance(item, int): + data = self.data[item] + memfield = self.create_like() + memfield.data.write(np.array([data])) + return memfield + + elif isinstance(item, (list, tuple, np.ndarray)): + allBooleanFlag = True + for x in item: + if not isinstance(x, bool): + allBooleanFlag = False + break + + if allBooleanFlag: + filter_to_apply = np.array(item, dtype='bool') if not isinstance(item, np.ndarray) else item + return self.apply_filter(filter_to_apply, target=None, in_place=False) + else: + index_to_apply = np.array(item, dtype=np.int64) if not isinstance(item, np.ndarray) else item + return self.apply_index(index_to_apply, target=None, in_place=False) + class MemoryField(Field): @@ -210,7 +237,7 @@ def __bool__(self): # if f is not None: return True - def apply_filter(self, filter_to_apply, dstfld=None): + def apply_filter(self, filter_to_apply, target=None, in_place=False): """ Apply filter on the field. """ diff --git a/tests/test_fields.py b/tests/test_fields.py index 8d7709d..00dfd45 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -2288,3 +2288,71 @@ def test_argsort(self, creator, name, kwargs, data): else: with self.assertRaises(ValueError): fields.argsort(f) + + +ARRAY_DEREFERENCE_FILTER_TESTS = [ + ([True, False, True], "create_indexed_string", {}, ['a', 'bb', 'ccc']), + ([True, False, True], "create_fixed_string", {"length": 3}, ['a', 'b', 'c']), + ([True, False, True], "create_numeric", {"nformat": "int8"}, [20,30,40]), + ([True, False, True], "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, [1,2,3]) +] + +ARRAY_DEREFERENCE_INDEX_TESTS = [ + ([0, 2], "create_indexed_string", {}, ['a', 'bb', 'ccc']), + ([0, 2], "create_fixed_string", {"length": 3}, ['a', 'b', 'c']), + ([0, 2], "create_numeric", {"nformat": "int8"}, [20,30,40]), + ([0, 2], "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, [1,2,3]) +] + +ARRAY_SLICE_AND_INT_TESTS = [ + (slice(0,2,1), "create_indexed_string", {}, ['a', 'bb', 'ccc']), + (slice(0,3,1), "create_fixed_string", {"length": 3}, [b'a', b'b', b'c']), + (slice(0,2,2), "create_numeric", {"nformat": "int8"}, [20,30,40]), + (slice(0,3,2), "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, [1,2,3]), + (0, "create_indexed_string", {}, ['a', 'bb', 'ccc']), + (1, "create_fixed_string", {"length": 3}, [b'a', b'b', b'c']), + (2, "create_numeric", {"nformat": "int8"}, [20,30,40]), + (0, "create_categorical", {"nformat": "int32", "key": {"a": 1, "b": 2, "c": 3}}, [1,2,3]) +] +class TestArrayDereferenceFunctions(SessionTestCase): + + def assertIfMemFieldAndIfSameTypeAsField(self, memfield, field): + self.assertIsInstance(memfield, fields.MemoryField) + if not (isinstance(field, fields.IndexedStringField) and isinstance(memfield, fields.IndexedStringMemField)) \ + and not (isinstance(field, fields.FixedStringField) and isinstance(memfield, fields.FixedStringMemField)) \ + and not (isinstance(field, fields.NumericField) and isinstance(memfield, fields.NumericMemField)) \ + and not (isinstance(field, fields.CategoricalField) and isinstance(memfield, fields.CategoricalMemField)): + raise AssertionError(f"{type(memfield)} is not the MemField for {type(field)}") + + + @parameterized.expand(ARRAY_DEREFERENCE_FILTER_TESTS) + def test_field_filter_dereference(self, filter, creator, kwargs, data): + f = self.setup_field(self.df, creator, 'f', (), kwargs, data) + result = f[filter] + + filter_to_apply = filter if isinstance(filter, np.ndarray) else np.array(filter, dtype=np.int8) + expected_result = f.apply_filter(filter_to_apply, target=None, in_place=False) + + self.assertIfMemFieldAndIfSameTypeAsField(result, f) + np.testing.assert_array_equal(result.data[:], expected_result.data[:]) + + @parameterized.expand(ARRAY_DEREFERENCE_INDEX_TESTS) + def test_field_index_dereference(self, index, creator, kwargs, data): + f = self.setup_field(self.df, creator, 'f', (), kwargs, data) + result = f[index] + + index_to_apply = index if isinstance(index, np.ndarray) else np.array(index, dtype=np.int8) + expected_result = f.apply_index(index_to_apply, target=None, in_place=False) + + self.assertIfMemFieldAndIfSameTypeAsField(result, f) + np.testing.assert_array_equal(result.data[:], expected_result.data[:]) + + + @parameterized.expand(ARRAY_SLICE_AND_INT_TESTS) + def test_field_slice(self, slice, creator, kwargs, data): + f = self.setup_field(self.df, creator, 'f', (), kwargs, data) + result = f[slice] + expected_result = data[slice] + + self.assertIfMemFieldAndIfSameTypeAsField(result, f) + np.testing.assert_array_equal(result.data[:], expected_result) \ No newline at end of file