Skip to content

Commit

Permalink
ENH: Add support for writing more ExtensionArray types
Browse files Browse the repository at this point in the history
  • Loading branch information
himikof committed Jul 8, 2023
1 parent 6b07e7d commit 7769bfd
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 8 deletions.
12 changes: 10 additions & 2 deletions pyogrio/_io.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ DTYPE_OGR_FIELD_TYPES = {
'float': (OFTReal, OFSTNone),
'float64': (OFTReal, OFSTNone),

'string': (OFTString, OFSTNone),

'datetime64[D]': (OFTDate, OFSTNone),
'datetime64': (OFTDateTime, OFSTNone),
}
Expand Down Expand Up @@ -1470,7 +1472,7 @@ cdef infer_field_types(list dtypes):

# TODO: set geometry and field data as memory views?
def ogr_write(
str path, str layer, str driver, geometry, fields, field_data, field_mask,
str path, str layer, str driver, geometry, fields, field_dtype, field_data, field_mask,
str crs, str geometry_type, str encoding, object dataset_kwargs,
object layer_kwargs, bint promote_to_multi=False, bint nan_as_null=True,
bint append=False, dataset_metadata=None, layer_metadata=None
Expand Down Expand Up @@ -1517,6 +1519,12 @@ def ogr_write(
else:
field_mask = [None] * len(field_data)

if field_dtype is not None:
if len(field_dtype) != len(field_data):
raise ValueError("field_dtype and field_data must be same length")
else:
field_dtype = [field.dtype for field in field_data]

path_b = path.encode('UTF-8')
path_c = path_b

Expand Down Expand Up @@ -1641,7 +1649,7 @@ def ogr_write(
layer_options = NULL

### Create the fields
field_types = infer_field_types([field.dtype for field in field_data])
field_types = infer_field_types(field_dtype)

### Create the fields
if create_layer:
Expand Down
17 changes: 11 additions & 6 deletions pyogrio/geopandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,22 +323,26 @@ def write_dataframe(
geometry = df[geometry_column]
fields = [c for c in df.columns if not c == geometry_column]

# TODO: may need to fill in pd.NA, etc
field_data = []
field_data = [df[f].values for f in fields]
field_dtype = []
field_mask = []
for name in fields:
for i, name in enumerate(fields):
col = df[name].values
if isinstance(col, pd.api.extensions.ExtensionArray):
from pandas.arrays import IntegerArray, FloatingArray, BooleanArray

if isinstance(col, (IntegerArray, FloatingArray, BooleanArray)):
field_data.append(col._data)
field_data[i] = col._data # Direct access optimization
field_dtype.append(col.dtype.numpy_dtype)
field_mask.append(col._mask)
else:
field_data.append(np.asarray(col))
if hasattr(col.dtype, "numpy_dtype"):
field_dtype.append(col.dtype.numpy_dtype)
else:
field_dtype.append(col.dtype)
field_mask.append(np.asarray(col.isna()))
else:
field_data.append(col)
field_dtype.append(col.dtype)
field_mask.append(None)

# Determine geometry_type and/or promote_to_multi
Expand Down Expand Up @@ -414,6 +418,7 @@ def write_dataframe(
driver=driver,
geometry=to_wkb(geometry.values),
field_data=field_data,
field_dtype=field_dtype,
field_mask=field_mask,
fields=fields,
crs=crs,
Expand Down
2 changes: 2 additions & 0 deletions pyogrio/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ def write(
field_data,
fields,
field_mask=None,
field_dtype=None,
layer=None,
driver=None,
# derived from meta if roundtrip
Expand Down Expand Up @@ -460,6 +461,7 @@ def write(
geometry=geometry,
geometry_type=geometry_type,
field_data=field_data,
field_dtype=field_dtype,
field_mask=field_mask,
fields=fields,
crs=crs,
Expand Down
26 changes: 26 additions & 0 deletions pyogrio/tests/test_geopandas_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,32 @@ def test_write_nullable_dtypes(tmp_path):
assert_geodataframe_equal(output_gdf, expected)


@pytest.mark.skipif(
not has_pyarrow, reason="PyArrow dtype support in Pandas requires PyArrow"
)
def test_write_pyarrow_dtypes(tmp_path):
path = tmp_path / "test_pyarrow_dtypes.gpkg"
test_data = {
"col1": pd.Series([1, 2, 3], dtype="int64[pyarrow]"),
"col2": pd.Series([1, 2, None], dtype="int64[pyarrow]"),
"col3": pd.Series([0.1, None, 0.3], dtype="float32[pyarrow]"),
"col4": pd.Series([True, False, None], dtype="boolean[pyarrow]"),
"col5": pd.Series(["a", None, "b"], dtype="string[pyarrow]"),
}
input_gdf = gp.GeoDataFrame(test_data, geometry=[Point(0, 0)] * 3, crs="epsg:31370")
write_dataframe(input_gdf, path)
output_gdf = read_dataframe(path)
# We read it back as default (non-nullable) numpy dtypes, so we cast
# to those for the expected result, explicitly filling the NA values in
expected = input_gdf.copy()
expected["col1"] = expected["col1"].astype("int64")
expected["col2"] = expected["col2"].astype(object).fillna(np.nan).astype("float64")
expected["col3"] = expected["col3"].astype(object).fillna(np.nan).astype("float32")
expected["col4"] = expected["col4"].astype(object).fillna(np.nan).astype("float64")
expected["col5"] = expected["col5"].astype(object)
assert_geodataframe_equal(output_gdf, expected)


@pytest.mark.parametrize(
"metadata_type", ["dataset_metadata", "layer_metadata", "metadata"]
)
Expand Down
24 changes: 24 additions & 0 deletions pyogrio/tests/test_raw_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,3 +845,27 @@ def test_write_with_mask(tmp_path):
field_mask = [np.array([False, True, False])] * 2
with pytest.raises(ValueError):
write(filename, geometry, field_data, fields, field_mask, **meta)


def test_write_with_explicit_dtype(tmp_path):
# Point(0, 0), null
geometry = np.array(
[bytes.fromhex("010100000000000000000000000000000000000000")] * 3,
dtype=object,
)
field_data = [np.array([1, 2, 3], dtype="int32")]
field_dtype = [np.dtype("float64")]
fields = ["col"]
meta = dict(geometry_type="Point", crs="EPSG:4326")

filename = tmp_path / "test.geojson"
write(filename, geometry, field_data, fields, field_dtype=field_dtype, **meta)
result_geometry, result_fields = read(filename)[2:]
assert np.array_equal(result_geometry, geometry)
np.testing.assert_allclose(result_fields[0], np.array([1, 2, 3]))
assert result_fields[0].dtype.name == "float64"

# wrong number of dtypes
field_dtype = [np.dtype("int32")] * 2
with pytest.raises(ValueError):
write(filename, geometry, field_data, fields, field_dtype=field_dtype, **meta)

0 comments on commit 7769bfd

Please sign in to comment.