diff --git a/doc/source/development/extending.rst b/doc/source/development/extending.rst index e341dcb8318bc..89d43e8a43825 100644 --- a/doc/source/development/extending.rst +++ b/doc/source/development/extending.rst @@ -251,6 +251,48 @@ To use a test, subclass it: See https://github.com/pandas-dev/pandas/blob/master/pandas/tests/extension/base/__init__.py for a list of all the tests available. +.. _extending.extension.arrow: + +Compatibility with Apache Arrow +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +An ``ExtensionArray`` can support conversion to / from ``pyarrow`` arrays +(and thus support for example serialization to the Parquet file format) +by implementing two methods: ``ExtensionArray.__arrow_array__`` and +``ExtensionDtype.__from_arrow__``. + +The ``ExtensionArray.__arrow_array__`` ensures that ``pyarrow`` knowns how +to convert the specific extension array into a ``pyarrow.Array`` (also when +included as a column in a pandas DataFrame): + +.. code-block:: python + + class MyExtensionArray(ExtensionArray): + ... + + def __arrow_array__(self, type=None): + # convert the underlying array values to a pyarrow Array + import pyarrow + return pyarrow.array(..., type=type) + +The ``ExtensionDtype.__from_arrow__`` method then controls the conversion +back from pyarrow to a pandas ExtensionArray. This method receives a pyarrow +``Array`` or ``ChunkedArray`` as only argument and is expected to return the +appropriate pandas ``ExtensionArray`` for this dtype and the passed values: + +.. code-block:: none + + class ExtensionDtype: + ... + + def __from_arrow__(self, array: pyarrow.Array/ChunkedArray) -> ExtensionArray: + ... + +See more in the `Arrow documentation `__. + +Those methods have been implemented for the nullable integer and string extension +dtypes included in pandas, and ensure roundtrip to pyarrow and the Parquet file format. + .. _extension dtype dtypes: https://github.com/pandas-dev/pandas/blob/master/pandas/core/dtypes/dtypes.py .. _extension dtype source: https://github.com/pandas-dev/pandas/blob/master/pandas/core/dtypes/base.py .. _extension array source: https://github.com/pandas-dev/pandas/blob/master/pandas/core/arrays/base.py diff --git a/doc/source/user_guide/io.rst b/doc/source/user_guide/io.rst index 6e45d6748c2a5..fa47a5944f7bf 100644 --- a/doc/source/user_guide/io.rst +++ b/doc/source/user_guide/io.rst @@ -4716,6 +4716,9 @@ Several caveats. * The ``pyarrow`` engine preserves the ``ordered`` flag of categorical dtypes with string types. ``fastparquet`` does not preserve the ``ordered`` flag. * Non supported types include ``Period`` and actual Python object types. These will raise a helpful error message on an attempt at serialization. +* The ``pyarrow`` engine preserves extension data types such as the nullable integer and string data + type (requiring pyarrow >= 1.0.0, and requiring the extension type to implement the needed protocols, + see the :ref:`extension types documentation `). You can specify an ``engine`` to direct the serialization. This can be one of ``pyarrow``, or ``fastparquet``, or ``auto``. If the engine is NOT specified, then the ``pd.options.io.parquet.engine`` option is checked; if this is also ``auto``, diff --git a/doc/source/whatsnew/v1.0.0.rst b/doc/source/whatsnew/v1.0.0.rst index 0027343a13b60..de8273f890436 100644 --- a/doc/source/whatsnew/v1.0.0.rst +++ b/doc/source/whatsnew/v1.0.0.rst @@ -114,6 +114,9 @@ Other enhancements - Added ``encoding`` argument to :meth:`DataFrame.to_string` for non-ascii text (:issue:`28766`) - Added ``encoding`` argument to :func:`DataFrame.to_html` for non-ascii text (:issue:`28663`) - :meth:`Styler.background_gradient` now accepts ``vmin`` and ``vmax`` arguments (:issue:`12145`) +- Roundtripping DataFrames with nullable integer or string data types to parquet + (:meth:`~DataFrame.to_parquet` / :func:`read_parquet`) using the `'pyarrow'` engine + now preserve those data types with pyarrow >= 1.0.0 (:issue:`20612`). Build Changes ^^^^^^^^^^^^^ diff --git a/pandas/core/arrays/integer.py b/pandas/core/arrays/integer.py index af7755fb1373d..63296b4a26354 100644 --- a/pandas/core/arrays/integer.py +++ b/pandas/core/arrays/integer.py @@ -85,6 +85,35 @@ def construct_array_type(cls): """ return IntegerArray + def __from_arrow__(self, array): + """Construct IntegerArray from passed pyarrow Array/ChunkedArray""" + import pyarrow + + if isinstance(array, pyarrow.Array): + chunks = [array] + else: + # pyarrow.ChunkedArray + chunks = array.chunks + + results = [] + for arr in chunks: + buflist = arr.buffers() + data = np.frombuffer(buflist[1], dtype=self.type)[ + arr.offset : arr.offset + len(arr) + ] + bitmask = buflist[0] + if bitmask is not None: + mask = pyarrow.BooleanArray.from_buffers( + pyarrow.bool_(), len(arr), [None, bitmask] + ) + mask = np.asarray(mask) + else: + mask = np.ones(len(arr), dtype=bool) + int_arr = IntegerArray(data.copy(), ~mask, copy=False) + results.append(int_arr) + + return IntegerArray._concat_same_type(results) + def integer_array(values, dtype=None, copy=False): """ diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 7c487b227de20..8599b5e39f34a 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -85,6 +85,24 @@ def construct_array_type(cls) -> "Type[StringArray]": def __repr__(self) -> str: return "StringDtype" + def __from_arrow__(self, array): + """Construct StringArray from passed pyarrow Array/ChunkedArray""" + import pyarrow + + if isinstance(array, pyarrow.Array): + chunks = [array] + else: + # pyarrow.ChunkedArray + chunks = array.chunks + + results = [] + for arr in chunks: + # using _from_sequence to ensure None is convered to np.nan + str_arr = StringArray._from_sequence(np.array(arr)) + results.append(str_arr) + + return StringArray._concat_same_type(results) + class StringArray(PandasArray): """ diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index efe2b4e0b2deb..1ce62d8f8b3d9 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ b/pandas/tests/arrays/string_/test_string.py @@ -171,3 +171,19 @@ def test_arrow_array(): arr = pa.array(data) expected = pa.array(list(data), type=pa.string(), from_pandas=True) assert arr.equals(expected) + + +@td.skip_if_no("pyarrow", min_version="0.15.1.dev") +def test_arrow_roundtrip(): + # roundtrip possible from arrow 1.0.0 + import pyarrow as pa + + data = pd.array(["a", "b", None], dtype="string") + df = pd.DataFrame({"a": data}) + table = pa.table(df) + assert table.field("a").type == "string" + result = table.to_pandas() + assert isinstance(result["a"].dtype, pd.StringDtype) + tm.assert_frame_equal(result, df) + # ensure the missing value is represented by NaN and not None + assert np.isnan(result.loc[2, "a"]) diff --git a/pandas/tests/arrays/test_integer.py b/pandas/tests/arrays/test_integer.py index 025366e5b210b..443a0c7e71616 100644 --- a/pandas/tests/arrays/test_integer.py +++ b/pandas/tests/arrays/test_integer.py @@ -829,6 +829,18 @@ def test_arrow_array(data): assert arr.equals(expected) +@td.skip_if_no("pyarrow", min_version="0.15.1.dev") +def test_arrow_roundtrip(data): + # roundtrip possible from arrow 1.0.0 + import pyarrow as pa + + df = pd.DataFrame({"a": data}) + table = pa.table(df) + assert table.field("a").type == str(data.dtype.numpy_dtype) + result = table.to_pandas() + tm.assert_frame_equal(result, df) + + @pytest.mark.parametrize( "pandasmethname, kwargs", [ diff --git a/pandas/tests/io/test_parquet.py b/pandas/tests/io/test_parquet.py index 5dd671c659263..bcbbee3b86769 100644 --- a/pandas/tests/io/test_parquet.py +++ b/pandas/tests/io/test_parquet.py @@ -514,13 +514,19 @@ def test_additional_extension_arrays(self, pa): "b": pd.Series(["a", None, "c"], dtype="string"), } ) - # currently de-serialized as plain int / object - expected = df.assign(a=df.a.astype("int64"), b=df.b.astype("object")) + if LooseVersion(pyarrow.__version__) >= LooseVersion("0.15.1.dev"): + expected = df + else: + # de-serialized as plain int / object + expected = df.assign(a=df.a.astype("int64"), b=df.b.astype("object")) check_round_trip(df, pa, expected=expected) df = pd.DataFrame({"a": pd.Series([1, 2, 3, None], dtype="Int64")}) - # if missing values in integer, currently de-serialized as float - expected = df.assign(a=df.a.astype("float64")) + if LooseVersion(pyarrow.__version__) >= LooseVersion("0.15.1.dev"): + expected = df + else: + # if missing values in integer, currently de-serialized as float + expected = df.assign(a=df.a.astype("float64")) check_round_trip(df, pa, expected=expected)