diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ba4bde693d..bdb01caf1a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,7 +23,7 @@ This project adheres to [Semantic Versioning](http://semver.org/). this feature was anonymously sponsored: thank you to our sponsor! - Add `legend.xref` and `legend.yref` to enable container-referenced positioning of legends [[#6589](https://github.com/plotly/plotly.js/pull/6589)], with thanks to [Gamma Technologies](https://www.gtisoft.com/) for sponsoring the related development. - Add `colorbar.xref` and `colorbar.yref` to enable container-referenced positioning of colorbars [[#6593](https://github.com/plotly/plotly.js/pull/6593)], with thanks to [Gamma Technologies](https://www.gtisoft.com/) for sponsoring the related development. - - `px` methods now accept data-frame-like objects that support a `to_pandas()` method, such as polars, cudf, vaex etc + - `px` methods now accept data-frame-like objects that support a `to_pandas()` method, such as polars, cudf, vaex etc [[#4244](https://github.com/plotly/plotly.py/pull/4244)], [[#4286](https://github.com/plotly/plotly.py/pull/4286)] ### Fixed - Fixed another compatibility issue with Pandas 2.0, just affecting `px.*(line_close=True)` [[#4190](https://github.com/plotly/plotly.py/pull/4190)] diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 452c0a7ff79..937ef9b5da7 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1018,7 +1018,7 @@ def _get_reserved_col_names(args): return reserved_names -def _is_col_list(df_input, arg): +def _is_col_list(columns, arg): """Returns True if arg looks like it's a list of columns or references to columns in df_input, and False otherwise (in which case it's assumed to be a single column or reference to a column). @@ -1033,7 +1033,7 @@ def _is_col_list(df_input, arg): return False # not iterable for c in arg: if isinstance(c, str) or isinstance(c, int): - if df_input is None or c not in df_input.columns: + if columns is None or c not in columns: return False else: try: @@ -1059,8 +1059,8 @@ def _isinstance_listlike(x): return True -def _escape_col_name(df_input, col_name, extra): - while df_input is not None and (col_name in df_input.columns or col_name in extra): +def _escape_col_name(columns, col_name, extra): + while columns is not None and (col_name in columns or col_name in extra): col_name = "_" + col_name return col_name @@ -1307,6 +1307,7 @@ def build_dataframe(args, constructor): # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.) df_provided = args["data_frame"] is not None + needs_interchanging = False if df_provided and not isinstance(args["data_frame"], pd.DataFrame): if hasattr(args["data_frame"], "__dataframe__") and version.parse( pd.__version__ @@ -1314,30 +1315,28 @@ def build_dataframe(args, constructor): import pandas.api.interchange df_not_pandas = args["data_frame"] - try: - df_pandas = pandas.api.interchange.from_dataframe(df_not_pandas) - except (ImportError, NotImplementedError) as exc: - # temporary workaround; developers of third-party libraries themselves - # should try a different implementation, if available. For example: - # def __dataframe__(self, ...): - # if not some_condition: - # self.to_pandas(...) - if not hasattr(df_not_pandas, "to_pandas"): - raise exc - df_pandas = df_not_pandas.to_pandas() - args["data_frame"] = df_pandas + args["data_frame"] = df_not_pandas.__dataframe__() + columns = args["data_frame"].column_names() + needs_interchanging = True elif hasattr(args["data_frame"], "to_pandas"): args["data_frame"] = args["data_frame"].to_pandas() + columns = args["data_frame"].columns else: args["data_frame"] = pd.DataFrame(args["data_frame"]) + columns = args["data_frame"].columns + elif df_provided: + columns = args["data_frame"].columns + else: + columns = None + df_input = args["data_frame"] # now we handle special cases like wide-mode or x-xor-y specification # by rearranging args to tee things up for process_args_into_dataframe to work no_x = args.get("x") is None no_y = args.get("y") is None - wide_x = False if no_x else _is_col_list(df_input, args["x"]) - wide_y = False if no_y else _is_col_list(df_input, args["y"]) + wide_x = False if no_x else _is_col_list(columns, args["x"]) + wide_y = False if no_y else _is_col_list(columns, args["y"]) wide_mode = False var_name = None # will likely be "variable" in wide_mode @@ -1352,15 +1351,18 @@ def build_dataframe(args, constructor): ) if df_provided and no_x and no_y: wide_mode = True - if isinstance(df_input.columns, pd.MultiIndex): + if isinstance(columns, pd.MultiIndex): raise TypeError( "Data frame columns is a pandas MultiIndex. " "pandas MultiIndex is not supported by plotly express " "at the moment." ) - args["wide_variable"] = list(df_input.columns) - var_name = df_input.columns.name - if var_name in [None, "value", "index"] or var_name in df_input: + args["wide_variable"] = list(columns) + if isinstance(columns, pd.Index): + var_name = columns.name + else: + var_name = None + if var_name in [None, "value", "index"] or var_name in columns: var_name = "variable" if constructor == go.Funnel: wide_orientation = args.get("orientation") or "h" @@ -1371,12 +1373,12 @@ def build_dataframe(args, constructor): elif wide_x != wide_y: wide_mode = True args["wide_variable"] = args["y"] if wide_y else args["x"] - if df_provided and args["wide_variable"] is df_input.columns: - var_name = df_input.columns.name + if df_provided and args["wide_variable"] is columns: + var_name = columns.name if isinstance(args["wide_variable"], pd.Index): args["wide_variable"] = list(args["wide_variable"]) if var_name in [None, "value", "index"] or ( - df_provided and var_name in df_input + df_provided and var_name in columns ): var_name = "variable" if hist1d_orientation: @@ -1389,8 +1391,35 @@ def build_dataframe(args, constructor): wide_cross_name = "__x__" if wide_y else "__y__" if wide_mode: - value_name = _escape_col_name(df_input, "value", []) - var_name = _escape_col_name(df_input, var_name, []) + value_name = _escape_col_name(columns, "value", []) + var_name = _escape_col_name(columns, var_name, []) + + if needs_interchanging: + try: + if wide_mode or not hasattr(args["data_frame"], "select_columns_by_name"): + args["data_frame"] = pd.api.interchange.from_dataframe( + args["data_frame"] + ) + else: + # Save precious resources by only interchanging columns that are + # actually going to be plotted. + columns = [ + i for i in args.values() if isinstance(i, str) and i in columns + ] + args["data_frame"] = pd.api.interchange.from_dataframe( + args["data_frame"].select_columns_by_name(columns) + ) + except (ImportError, NotImplementedError) as exc: + # temporary workaround; developers of third-party libraries themselves + # should try a different implementation, if available. For example: + # def __dataframe__(self, ...): + # if not some_condition: + # self.to_pandas(...) + if not hasattr(df_not_pandas, "to_pandas"): + raise exc + args["data_frame"] = df_not_pandas.to_pandas() + + df_input = args["data_frame"] missing_bar_dim = None if ( diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_input.py index 1acbf3f1e64..fa0f1298fdc 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_input.py @@ -252,21 +252,58 @@ def test_build_df_with_index(): def test_build_df_using_interchange_protocol_mock( add_interchange_module_for_old_pandas, ): + class InterchangeDataFrame: + def __init__(self, columns): + self._columns = columns + + def column_names(self): + return self._columns + + interchange_dataframe = InterchangeDataFrame( + ["petal_width", "sepal_length", "sepal_width"] + ) + interchange_dataframe_reduced = InterchangeDataFrame( + ["petal_width", "sepal_length"] + ) + interchange_dataframe.select_columns_by_name = mock.MagicMock( + return_value=interchange_dataframe_reduced + ) + interchange_dataframe_reduced.select_columns_by_name = mock.MagicMock( + return_value=interchange_dataframe_reduced + ) + class CustomDataFrame: def __dataframe__(self): - pass + return interchange_dataframe + + class CustomDataFrameReduced: + def __dataframe__(self): + return interchange_dataframe_reduced input_dataframe = CustomDataFrame() - args = dict(data_frame=input_dataframe, x="petal_width", y="sepal_length") + input_dataframe_reduced = CustomDataFrameReduced() iris_pandas = px.data.iris() with mock.patch("pandas.__version__", "2.0.2"): + args = dict(data_frame=input_dataframe, x="petal_width", y="sepal_length") with mock.patch( "pandas.api.interchange.from_dataframe", return_value=iris_pandas ) as mock_from_dataframe: build_dataframe(args, go.Scatter) - mock_from_dataframe.assert_called_once_with(input_dataframe) + mock_from_dataframe.assert_called_once_with(interchange_dataframe_reduced) + interchange_dataframe.select_columns_by_name.assert_called_with( + ["petal_width", "sepal_length"] + ) + + args = dict(data_frame=input_dataframe_reduced, color=None) + with mock.patch( + "pandas.api.interchange.from_dataframe", + return_value=iris_pandas[["petal_width", "sepal_length"]], + ) as mock_from_dataframe: + build_dataframe(args, go.Scatter) + mock_from_dataframe.assert_called_once_with(interchange_dataframe_reduced) + interchange_dataframe_reduced.select_columns_by_name.assert_not_called() @pytest.mark.skipif(