Skip to content

WIP: Draft strawman implementation of draft strawman data frame "__dataframe__" interchange protocol for discussion #32908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
35 changes: 35 additions & 0 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
)
from pandas.core.ops.missing import dispatch_fill_zeros
from pandas.core.series import Series
from pandas.protocol.wrapper import DataFrame as DataFrameWrapper

from pandas.io.common import get_filepath_or_buffer
from pandas.io.formats import console, format as fmt
Expand All @@ -138,6 +139,7 @@
if TYPE_CHECKING:
from pandas.core.groupby.generic import DataFrameGroupBy
from pandas.io.formats.style import Styler
from pandas.wesm import dataframe as dataframe_protocol # noqa: F401

# ---------------------------------------------------------------------
# Docstring templates
Expand Down Expand Up @@ -435,6 +437,32 @@ def __init__(
if isinstance(data, DataFrame):
data = data._data

elif hasattr(data, "__dataframe__"):
# construct using dict of numpy arrays
# TODO(simonjayhawkins) index, columns, dtype and copy arguments
obj = cast("dataframe_protocol.DataFrame", data.__dataframe__)

def _get_column(col):
try:
return col.to_numpy()
except NotImplementedError:
return col.to_arrow()

data = {
column_name: _get_column(obj[column_name])
for column_name in obj.column_names
}

if not index:
try:
index = MultiIndex.from_tuples(obj.row_names)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, I assumed this would fail if the row names were not tuples.

>>> pd.MultiIndex.from_tuples([[0, 0], [0, 1]])
MultiIndex([(0, 0),
            (0, 1)],
           )

except TypeError:
index = obj.row_names
except NotImplementedError:
# It is not necessary to implement row_names in the
# dataframe interchange protocol
pass

if isinstance(data, BlockManager):
mgr = self._init_mgr(
data, axes=dict(index=index, columns=columns), dtype=dtype, copy=copy
Expand Down Expand Up @@ -520,6 +548,13 @@ def __init__(

NDFrame.__init__(self, mgr)

@property
def __dataframe__(self) -> DataFrameWrapper:
    """
    Expose this frame through the draft data frame interchange protocol.

    Returns
    -------
    DataFrameWrapper
        A protocol wrapper built around this DataFrame; a new wrapper
        object is constructed on every access.
    """
    wrapper = DataFrameWrapper(self)
    return wrapper

# ----------------------------------------------------------------------

@property
Expand Down
Empty file added pandas/protocol/__init__.py
Empty file.
91 changes: 91 additions & 0 deletions pandas/protocol/wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import TYPE_CHECKING, Any, Hashable, Iterable, Sequence

from pandas.wesm import dataframe as dataframe_protocol
from pandas.wesm.example_dict_of_ndarray import NumPyColumn

if TYPE_CHECKING:
import pandas as pd


class Column(NumPyColumn):
    """
    Protocol column backed by a pandas Series.

    The Series itself is kept on ``_ser``; its name and its ``to_numpy()``
    result are handed to the ``NumPyColumn`` initializer.

    Parameters
    ----------
    ser : pd.Series
        The Series to expose as a protocol column.
    """

    _ser: "pd.Series"

    def __init__(self, ser: "pd.Series"):
        # Keep a handle on the source Series before initializing the base.
        self._ser = ser
        column_name = ser.name
        column_values = ser.to_numpy()
        super().__init__(column_name, column_values)


class DataFrame(dataframe_protocol.DataFrame):
    """
    Interchange-protocol view over a pandas DataFrame.

    Every accessor delegates to the wrapped frame stored on ``_df``.

    Parameters
    ----------
    df : pd.DataFrame
        The pandas object to expose through the protocol.
    """

    _df: "pd.DataFrame"

    def __init__(self, df: "pd.DataFrame"):
        self._df = df

    # ------------------------------------------------------------------
    # Rendering: both forms defer to the wrapped pandas object

    def __repr__(self) -> str:
        wrapped = self._df
        return repr(wrapped)

    def __str__(self) -> str:
        wrapped = self._df
        return str(wrapped)

    # ------------------------------------------------------------------
    # Shape

    @property
    def num_rows(self) -> int:
        """
        Number of rows in the wrapped DataFrame.
        """
        return len(self._df.index)

    @property
    def num_columns(self) -> int:
        """
        Number of columns in the wrapped DataFrame.
        """
        _, n_columns = self._df.shape
        return n_columns

    # ------------------------------------------------------------------
    # Labels

    @property
    def column_names(self) -> Sequence[Any]:
        """
        Column labels of the wrapped DataFrame, materialized as a list.
        """
        labels = self._df.columns
        return labels.to_list()

    @property
    def row_names(self) -> Sequence[Any]:
        """
        Row labels (the index) of the wrapped DataFrame, materialized as a
        list. The protocol does not require this to be implemented.
        """
        labels = self._df.index
        return labels.to_list()

    def iter_column_names(self) -> Iterable[Any]:
        """
        Column labels as an iterable; this is the same materialized
        sequence that ``column_names`` returns.
        """
        return self.column_names

    # ------------------------------------------------------------------
    # Column access

    def column_by_index(self, i: int) -> dataframe_protocol.Column:
        """
        Wrap the column found at position ``i`` in a protocol ``Column``.
        """
        selected = self._df.iloc[:, i]
        return Column(selected)

    def column_by_name(self, key: Hashable) -> dataframe_protocol.Column:
        """
        Wrap the column labelled ``key`` in a protocol ``Column``.
        """
        selected = self._df[key]
        return Column(selected)
2 changes: 2 additions & 0 deletions pandas/tests/api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ class TestPDApi(Base):
"_tslib",
"_typing",
"_version",
"protocol",
"wesm",
]

def test_api(self):
Expand Down
99 changes: 99 additions & 0 deletions pandas/tests/test_downstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from pandas import DataFrame
import pandas._testing as tm
from pandas.protocol.wrapper import DataFrame as DataFrameWrapper
from pandas.wesm import dataframe as dataframe_protocol, example_dict_of_ndarray


def import_module(name):
Expand Down Expand Up @@ -147,3 +149,100 @@ def test_missing_required_dependency():
output = exc.value.stdout.decode()
for name in ["numpy", "pytz", "dateutil"]:
assert name in output


# -----------------------------------------------------------------------------
# DataFrame interchange protocol
# -----------------------------------------------------------------------------


class TestDataFrameProtocol:
    """Checks for the draft ``__dataframe__`` protocol and its constructor path."""

    def test_interface_smoketest(self):
        """The wrapper reports the types, shape, labels and data of a frame."""
        frame = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
        proto = frame.__dataframe__

        # Objects handed out by the wrapper are protocol types.
        assert isinstance(proto, dataframe_protocol.DataFrame)
        assert isinstance(proto["a"], dataframe_protocol.Column)
        assert isinstance(proto.column_by_index(0), dataframe_protocol.Column)
        assert isinstance(proto["a"].type, dataframe_protocol.DataType)

        # Shape and labels.
        assert proto.num_rows == 3
        assert proto.num_columns == 2
        assert proto.column_names == ["a", "b"]
        assert list(proto.iter_column_names()) == ["a", "b"]
        assert proto.row_names == [0, 1, 2]

        # Column data, reached both by label and by position.
        wanted_values = np.array([1, 2, 3], dtype=np.int64)
        tm.assert_numpy_array_equal(proto["a"].to_numpy(), wanted_values)
        tm.assert_numpy_array_equal(
            proto.column_by_index(0).to_numpy(), wanted_values
        )

        assert proto["a"].name == "a"
        assert proto.column_by_index(0).name == "a"

        wanted_type = dataframe_protocol.Int64()
        assert proto["a"].type == wanted_type
        assert proto.column_by_index(0).type == wanted_type

    def test_pandas_dataframe_constructor(self):
        """``DataFrame(...)`` accepts a frame, a wrapper, and a wrapper without rows."""
        # TODO(simonjayhawkins): move to test_constructors.py
        frame = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})

        from_frame = DataFrame(frame)
        tm.assert_frame_equal(from_frame, frame)
        assert from_frame is not frame

        from_protocol = DataFrame(frame.__dataframe__)
        tm.assert_frame_equal(from_protocol, frame)
        assert from_protocol is not frame

        # It is not necessary to implement row_names in the
        # dataframe interchange protocol

        # TODO(simonjayhawkins) how to monkeypatch property with pytest
        # raises AttributeError: can't set attribute

        class _DataFrameWrapper(DataFrameWrapper):
            @property
            def row_names(self):
                raise NotImplementedError("row_names")

        without_row_names = _DataFrameWrapper(frame)
        with pytest.raises(NotImplementedError, match="row_names"):
            without_row_names.row_names

        rebuilt = DataFrame(without_row_names)
        tm.assert_frame_equal(rebuilt, frame)

    def test_multiindex(self):
        """Tuple row and column labels survive the trip through the protocol."""
        flat = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
        frame = flat.reset_index().set_index(["index", "a"])
        proto = frame.__dataframe__

        assert proto.row_names == [(0, 1), (1, 2), (2, 3)]

        # TODO(simonjayhawkins) split this test and move to test_constructors.py
        rebuilt = DataFrame(proto)
        # index and column names are not available from the protocol api
        tm.assert_frame_equal(rebuilt, frame, check_names=False)

        frame = frame.unstack()
        proto = frame.__dataframe__

        assert proto.column_names == [("b", 1), ("b", 2), ("b", 3)]

        # TODO(simonjayhawkins) split this test and move to test_constructors.py
        rebuilt = DataFrame(proto)
        # index and column names are not available from the protocol api
        tm.assert_frame_equal(rebuilt, frame, check_names=False)

    def test_example_dict_of_ndarray(self):
        """The dict-of-ndarray example object builds the same frame as its data."""
        arrays, labels, example = example_dict_of_ndarray.get_example()
        built = DataFrame(example)
        wanted = DataFrame(arrays)
        tm.assert_frame_equal(built, wanted)
        assert built.columns.to_list() == labels
Empty file added pandas/wesm/__init__.py
Empty file.
Loading