Skip to content

Commit d682e3d

Browse files
committed
ENH: Add support for writing more ExtensionArray types
1 parent 6b07e7d commit d682e3d

File tree

5 files changed

+71
-8
lines changed

5 files changed

+71
-8
lines changed

pyogrio/_io.pyx

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ DTYPE_OGR_FIELD_TYPES = {
8080
'float': (OFTReal, OFSTNone),
8181
'float64': (OFTReal, OFSTNone),
8282

83+
'string': (OFTString, OFSTNone),
84+
8385
'datetime64[D]': (OFTDate, OFSTNone),
8486
'datetime64': (OFTDateTime, OFSTNone),
8587
}
@@ -1470,7 +1472,7 @@ cdef infer_field_types(list dtypes):
14701472

14711473
# TODO: set geometry and field data as memory views?
14721474
def ogr_write(
1473-
str path, str layer, str driver, geometry, fields, field_data, field_mask,
1475+
str path, str layer, str driver, geometry, fields, field_dtype, field_data, field_mask,
14741476
str crs, str geometry_type, str encoding, object dataset_kwargs,
14751477
object layer_kwargs, bint promote_to_multi=False, bint nan_as_null=True,
14761478
bint append=False, dataset_metadata=None, layer_metadata=None
@@ -1517,6 +1519,12 @@ def ogr_write(
15171519
else:
15181520
field_mask = [None] * len(field_data)
15191521

1522+
if field_dtype is not None:
1523+
if len(field_dtype) != len(field_data):
1524+
raise ValueError("field_dtype and field_data must be same length")
1525+
else:
1526+
field_dtype = [field.dtype for field in field_data]
1527+
15201528
path_b = path.encode('UTF-8')
15211529
path_c = path_b
15221530

@@ -1641,7 +1649,7 @@ def ogr_write(
16411649
layer_options = NULL
16421650

16431651
### Create the fields
1644-
field_types = infer_field_types([field.dtype for field in field_data])
1652+
field_types = infer_field_types(field_dtype)
16451653

16461654
### Create the fields
16471655
if create_layer:

pyogrio/geopandas.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -323,22 +323,26 @@ def write_dataframe(
323323
geometry = df[geometry_column]
324324
fields = [c for c in df.columns if not c == geometry_column]
325325

326-
# TODO: may need to fill in pd.NA, etc
327-
field_data = []
326+
field_data = [df[f].values for f in fields]
327+
field_dtype = []
328328
field_mask = []
329-
for name in fields:
329+
for i, name in enumerate(fields):
330330
col = df[name].values
331331
if isinstance(col, pd.api.extensions.ExtensionArray):
332332
from pandas.arrays import IntegerArray, FloatingArray, BooleanArray
333333

334334
if isinstance(col, (IntegerArray, FloatingArray, BooleanArray)):
335-
field_data.append(col._data)
335+
field_data[i] = col._data # Direct access optimization
336+
field_dtype.append(col.dtype.numpy_dtype)
336337
field_mask.append(col._mask)
337338
else:
338-
field_data.append(np.asarray(col))
339+
if hasattr(col.dtype, 'numpy_dtype'):
340+
field_dtype.append(col.dtype.numpy_dtype)
341+
else:
342+
field_dtype.append(col.dtype)
339343
field_mask.append(np.asarray(col.isna()))
340344
else:
341-
field_data.append(col)
345+
field_dtype.append(col.dtype)
342346
field_mask.append(None)
343347

344348
# Determine geometry_type and/or promote_to_multi
@@ -414,6 +418,7 @@ def write_dataframe(
414418
driver=driver,
415419
geometry=to_wkb(geometry.values),
416420
field_data=field_data,
421+
field_dtype=field_dtype,
417422
field_mask=field_mask,
418423
fields=fields,
419424
crs=crs,

pyogrio/raw.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ def write(
371371
field_data,
372372
fields,
373373
field_mask=None,
374+
field_dtype=None,
374375
layer=None,
375376
driver=None,
376377
# derived from meta if roundtrip
@@ -460,6 +461,7 @@ def write(
460461
geometry=geometry,
461462
geometry_type=geometry_type,
462463
field_data=field_data,
464+
field_dtype=field_dtype,
463465
field_mask=field_mask,
464466
fields=fields,
465467
crs=crs,

pyogrio/tests/test_geopandas_io.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,6 +1062,30 @@ def test_write_nullable_dtypes(tmp_path):
10621062
assert_geodataframe_equal(output_gdf, expected)
10631063

10641064

1065+
@pytest.mark.skipif(not has_pyarrow, reason="PyArrow dtype support in Pandas requires PyArrow")
1066+
def test_write_pyarrow_dtypes(tmp_path):
1067+
path = tmp_path / "test_pyarrow_dtypes.gpkg"
1068+
test_data = {
1069+
"col1": pd.Series([1, 2, 3], dtype="int64[pyarrow]"),
1070+
"col2": pd.Series([1, 2, None], dtype="int64[pyarrow]"),
1071+
"col3": pd.Series([0.1, None, 0.3], dtype="float32[pyarrow]"),
1072+
"col4": pd.Series([True, False, None], dtype="boolean[pyarrow]"),
1073+
"col5": pd.Series(["a", None, "b"], dtype="string[pyarrow]"),
1074+
}
1075+
input_gdf = gp.GeoDataFrame(test_data, geometry=[Point(0, 0)] * 3, crs="epsg:31370")
1076+
write_dataframe(input_gdf, path)
1077+
output_gdf = read_dataframe(path)
1078+
# We read it back as default (non-nullable) numpy dtypes, so we cast
1079+
# to those for the expected result, explicitly filling the NA values in
1080+
expected = input_gdf.copy()
1081+
expected["col1"] = expected["col1"].astype("int64")
1082+
expected["col2"] = expected["col2"].astype(object).fillna(np.nan).astype("float64")
1083+
expected["col3"] = expected["col3"].astype(object).fillna(np.nan).astype("float32")
1084+
expected["col4"] = expected["col4"].astype(object).fillna(np.nan).astype("float64")
1085+
expected["col5"] = expected["col5"].astype(object)
1086+
assert_geodataframe_equal(output_gdf, expected)
1087+
1088+
10651089
@pytest.mark.parametrize(
10661090
"metadata_type", ["dataset_metadata", "layer_metadata", "metadata"]
10671091
)

pyogrio/tests/test_raw_io.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -845,3 +845,27 @@ def test_write_with_mask(tmp_path):
845845
field_mask = [np.array([False, True, False])] * 2
846846
with pytest.raises(ValueError):
847847
write(filename, geometry, field_data, fields, field_mask, **meta)
848+
849+
850+
def test_write_with_explicit_dtype(tmp_path):
851+
# Point(0, 0), null
852+
geometry = np.array(
853+
[bytes.fromhex("010100000000000000000000000000000000000000")] * 3,
854+
dtype=object,
855+
)
856+
field_data = [np.array([1, 2, 3], dtype="int32")]
857+
field_dtype = [np.dtype("float64")]
858+
fields = ["col"]
859+
meta = dict(geometry_type="Point", crs="EPSG:4326")
860+
861+
filename = tmp_path / "test.geojson"
862+
write(filename, geometry, field_data, fields, field_dtype=field_dtype, **meta)
863+
result_geometry, result_fields = read(filename)[2:]
864+
assert np.array_equal(result_geometry, geometry)
865+
np.testing.assert_allclose(result_fields[0], np.array([1, 2, 3]))
866+
assert result_fields[0].dtype.name == 'float64'
867+
868+
# wrong number of dtypes
869+
field_dtype = [np.dtype("int32")] * 2
870+
with pytest.raises(ValueError):
871+
write(filename, geometry, field_data, fields, field_dtype=field_dtype, **meta)

0 commit comments

Comments
 (0)