From 7769bfdf68d0008e1dbfb457e34b1d066b71178d Mon Sep 17 00:00:00 2001 From: Nikita Ofitserov Date: Sat, 8 Jul 2023 12:33:57 +0300 Subject: [PATCH] ENH: Add support for writing more ExtensionArray types --- pyogrio/_io.pyx | 12 ++++++++++-- pyogrio/geopandas.py | 17 +++++++++++------ pyogrio/raw.py | 2 ++ pyogrio/tests/test_geopandas_io.py | 26 ++++++++++++++++++++++++++ pyogrio/tests/test_raw_io.py | 24 ++++++++++++++++++++++++ 5 files changed, 73 insertions(+), 8 deletions(-) diff --git a/pyogrio/_io.pyx b/pyogrio/_io.pyx index 9065904c..3f5f4c9a 100644 --- a/pyogrio/_io.pyx +++ b/pyogrio/_io.pyx @@ -80,6 +80,8 @@ DTYPE_OGR_FIELD_TYPES = { 'float': (OFTReal, OFSTNone), 'float64': (OFTReal, OFSTNone), + 'string': (OFTString, OFSTNone), + 'datetime64[D]': (OFTDate, OFSTNone), 'datetime64': (OFTDateTime, OFSTNone), } @@ -1470,7 +1472,7 @@ cdef infer_field_types(list dtypes): # TODO: set geometry and field data as memory views? def ogr_write( - str path, str layer, str driver, geometry, fields, field_data, field_mask, + str path, str layer, str driver, geometry, fields, field_dtype, field_data, field_mask, str crs, str geometry_type, str encoding, object dataset_kwargs, object layer_kwargs, bint promote_to_multi=False, bint nan_as_null=True, bint append=False, dataset_metadata=None, layer_metadata=None @@ -1517,6 +1519,12 @@ def ogr_write( else: field_mask = [None] * len(field_data) + if field_dtype is not None: + if len(field_dtype) != len(field_data): + raise ValueError("field_dtype and field_data must be same length") + else: + field_dtype = [field.dtype for field in field_data] + path_b = path.encode('UTF-8') path_c = path_b @@ -1641,7 +1649,7 @@ def ogr_write( layer_options = NULL ### Create the fields - field_types = infer_field_types([field.dtype for field in field_data]) + field_types = infer_field_types(field_dtype) ### Create the fields if create_layer: diff --git a/pyogrio/geopandas.py b/pyogrio/geopandas.py index a69ddb14..fb45a5f0 100644 --- a/pyogrio/geopandas.py +++ b/pyogrio/geopandas.py @@ -323,22 +323,26 @@ def write_dataframe( geometry = df[geometry_column] fields = [c for c in df.columns if not c == geometry_column] - # TODO: may need to fill in pd.NA, etc - field_data = [] + field_data = [df[f].values for f in fields] + field_dtype = [] field_mask = [] - for name in fields: + for i, name in enumerate(fields): col = df[name].values if isinstance(col, pd.api.extensions.ExtensionArray): from pandas.arrays import IntegerArray, FloatingArray, BooleanArray if isinstance(col, (IntegerArray, FloatingArray, BooleanArray)): - field_data.append(col._data) + field_data[i] = col._data # Direct access optimization + field_dtype.append(col.dtype.numpy_dtype) field_mask.append(col._mask) else: - field_data.append(np.asarray(col)) + if hasattr(col.dtype, "numpy_dtype"): + field_dtype.append(col.dtype.numpy_dtype) + else: + field_dtype.append(col.dtype) field_mask.append(np.asarray(col.isna())) else: - field_data.append(col) + field_dtype.append(col.dtype) field_mask.append(None) # Determine geometry_type and/or promote_to_multi @@ -414,6 +418,7 @@ def write_dataframe( driver=driver, geometry=to_wkb(geometry.values), field_data=field_data, + field_dtype=field_dtype, field_mask=field_mask, fields=fields, crs=crs, diff --git a/pyogrio/raw.py b/pyogrio/raw.py index 27bc5ed4..1e593ec7 100644 --- a/pyogrio/raw.py +++ b/pyogrio/raw.py @@ -371,6 +371,7 @@ def write( field_data, fields, field_mask=None, + field_dtype=None, layer=None, driver=None, # derived from meta if roundtrip @@ -460,6 +461,7 @@ def write( geometry=geometry, geometry_type=geometry_type, field_data=field_data, + field_dtype=field_dtype, field_mask=field_mask, fields=fields, crs=crs, diff --git a/pyogrio/tests/test_geopandas_io.py b/pyogrio/tests/test_geopandas_io.py index 8e9fd23a..0f768d08 100644 --- a/pyogrio/tests/test_geopandas_io.py +++ b/pyogrio/tests/test_geopandas_io.py @@ -1062,6 +1062,32 @@ def test_write_nullable_dtypes(tmp_path): assert_geodataframe_equal(output_gdf, expected) +@pytest.mark.skipif( + not has_pyarrow, reason="PyArrow dtype support in Pandas requires PyArrow" +) +def test_write_pyarrow_dtypes(tmp_path): + path = tmp_path / "test_pyarrow_dtypes.gpkg" + test_data = { + "col1": pd.Series([1, 2, 3], dtype="int64[pyarrow]"), + "col2": pd.Series([1, 2, None], dtype="int64[pyarrow]"), + "col3": pd.Series([0.1, None, 0.3], dtype="float32[pyarrow]"), + "col4": pd.Series([True, False, None], dtype="boolean[pyarrow]"), + "col5": pd.Series(["a", None, "b"], dtype="string[pyarrow]"), + } + input_gdf = gp.GeoDataFrame(test_data, geometry=[Point(0, 0)] * 3, crs="epsg:31370") + write_dataframe(input_gdf, path) + output_gdf = read_dataframe(path) + # We read it back as default (non-nullable) numpy dtypes, so we cast + # to those for the expected result, explicitly filling the NA values in + expected = input_gdf.copy() + expected["col1"] = expected["col1"].astype("int64") + expected["col2"] = expected["col2"].astype(object).fillna(np.nan).astype("float64") + expected["col3"] = expected["col3"].astype(object).fillna(np.nan).astype("float32") + expected["col4"] = expected["col4"].astype(object).fillna(np.nan).astype("float64") + expected["col5"] = expected["col5"].astype(object) + assert_geodataframe_equal(output_gdf, expected) + + @pytest.mark.parametrize( "metadata_type", ["dataset_metadata", "layer_metadata", "metadata"] ) diff --git a/pyogrio/tests/test_raw_io.py b/pyogrio/tests/test_raw_io.py index cb342618..e114c497 100644 --- a/pyogrio/tests/test_raw_io.py +++ b/pyogrio/tests/test_raw_io.py @@ -845,3 +845,27 @@ def test_write_with_mask(tmp_path): field_mask = [np.array([False, True, False])] * 2 with pytest.raises(ValueError): write(filename, geometry, field_data, fields, field_mask, **meta) + + +def test_write_with_explicit_dtype(tmp_path): + # Point(0, 0), null + geometry = np.array( + [bytes.fromhex("010100000000000000000000000000000000000000")] * 3, + dtype=object, + ) + field_data = [np.array([1, 2, 3], dtype="int32")] + field_dtype = [np.dtype("float64")] + fields = ["col"] + meta = dict(geometry_type="Point", crs="EPSG:4326") + + filename = tmp_path / "test.geojson" + write(filename, geometry, field_data, fields, field_dtype=field_dtype, **meta) + result_geometry, result_fields = read(filename)[2:] + assert np.array_equal(result_geometry, geometry) + np.testing.assert_allclose(result_fields[0], np.array([1, 2, 3])) + assert result_fields[0].dtype.name == "float64" + + # wrong number of dtypes + field_dtype = [np.dtype("int32")] * 2 + with pytest.raises(ValueError): + write(filename, geometry, field_data, fields, field_dtype=field_dtype, **meta)