Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions databricks/koalas/plot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ def _get_plot_backend(backend=None):
return KoalasPlotAccessor._backends[backend]

module = KoalasPlotAccessor._find_backend(backend)

if backend == "plotly":
from databricks.koalas.plot.plotly import plot_plotly

module.plot = plot_plotly(module.plot)

KoalasPlotAccessor._backends[backend] = module
return module

Expand Down Expand Up @@ -714,7 +720,7 @@ def area(self, x=None, y=None, **kwds):
elif isinstance(self.data, DataFrame):
return self(kind="area", x=x, y=y, **kwds)

def pie(self, y=None, **kwds):
def pie(self, **kwds):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we annotate the return type here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will add them all in a separate PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good!

"""
Generate a pie plot.

Expand All @@ -728,7 +734,7 @@ def pie(self, y=None, **kwds):
----------
y : int or label, optional
Label or position of the column to plot.
If not provided, ``subplots=True`` argument must be passed.
If not provided, ``subplots=True`` argument must be passed (matplotlib-only).
**kwds
Keyword arguments to pass on to :meth:`Koalas.Series.plot`.

Expand Down Expand Up @@ -764,9 +770,15 @@ def pie(self, y=None, **kwds):
return self(kind="pie", **kwds)
else:
# pandas will raise an error if y is None and subplots is not True
if isinstance(self.data, DataFrame) and y is None and not kwds.get("subplots", False):
raise ValueError("pie requires either y column or 'subplots=True'")
return self(kind="pie", y=y, **kwds)
if (
isinstance(self.data, DataFrame)
and kwds.get("y", None) is None
and not kwds.get("subplots", False)
):
raise ValueError(
"pie requires either y column or 'subplots=True' (matplotlib-only)"
)
return self(kind="pie", **kwds)

def scatter(self, x, y, **kwds):
"""
Expand Down
51 changes: 51 additions & 0 deletions databricks/koalas/plot/plotly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#
# Copyright (C) 2019 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pandas as pd


def plot_plotly(origin_plot):
    """
    Wrap a plotting backend's ``plot`` function with Koalas-specific handling.

    Parameters
    ----------
    origin_plot : callable
        The backend's original plot function, called as
        ``origin_plot(data, kind, **kwargs)``.

    Returns
    -------
    callable
        A function taking ``(data, kind, **kwargs)``. When ``kind`` is ``"pie"``
        it dispatches to :func:`plot_pie`; every other kind is forwarded
        unchanged to ``origin_plot``.
    """
    # Local import keeps this function self-contained.
    from functools import wraps

    # `wraps` keeps the backend function's name/docstring and exposes it via
    # `__wrapped__`, so the replacement stays introspectable.
    @wraps(origin_plot)
    def plot(data, kind, **kwargs):
        # Koalas specific plots
        if kind == "pie":
            return plot_pie(data, **kwargs)

        # Other plots.
        return origin_plot(data, kind, **kwargs)

    return plot


def plot_pie(data, **kwargs):
    """
    Draw a pie chart with ``plotly.express.pie`` from pandas data.

    Parameters
    ----------
    data : pd.Series or pd.DataFrame
        For a Series, its values become the slice sizes and its index the
        slice names. For a DataFrame, see ``**kwargs``.
    **kwargs
        Forwarded to ``plotly.express.pie``. For a DataFrame, ``y`` is consumed
        here: it is the default for ``values``, and when it is given the
        DataFrame's index is the default for ``names``. Explicit ``values`` /
        ``names`` take precedence over those defaults.

    Returns
    -------
    The object returned by ``plotly.express.pie``.

    Raises
    ------
    RuntimeError
        If ``data`` is neither a pandas Series nor a pandas DataFrame.
    """
    from plotly import express

    if isinstance(data, pd.Series):
        frame = data.to_frame()
        value_column = frame.columns[0]
        return express.pie(frame, values=value_column, names=frame.index, **kwargs)

    if not isinstance(data, pd.DataFrame):
        raise RuntimeError("Unexpected type: [%s]" % type(data))

    # DataFrame: `y` only seeds the defaults; explicit values/names win.
    y_column = kwargs.pop("y", None)
    fallback_names = data.index if y_column is not None else None
    pie_values = kwargs.pop("values", y_column)
    pie_names = kwargs.pop("names", fallback_names)
    return express.pie(data, values=pie_values, names=pie_names, **kwargs)
31 changes: 31 additions & 0 deletions databricks/koalas/tests/plot/test_frame_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import pandas as pd
import numpy as np
from plotly import express

from databricks import koalas as ks
from databricks.koalas.config import set_option, reset_option
Expand Down Expand Up @@ -143,3 +144,33 @@ def check_scatter_plot(pdf, kdf, x, y, c):
pdf1 = pd.DataFrame(np.random.rand(50, 4), columns=["a", "b", "c", "d"])
kdf1 = ks.from_pandas(pdf1)
check_scatter_plot(pdf1, kdf1, x="a", y="b", c="c")

def test_pie_plot(self):
    """Check DataFrame pie plots against figures built directly with plotly express."""

    def check_pie_plot(kdf):
        pdf = kdf.to_pandas()

        # Passing `y` is expected to match plotly's `values` with the index as names.
        plotted_with_y = kdf.plot(kind="pie", y=kdf.columns[0])
        expected_with_y = express.pie(pdf, values="a", names=pdf.index)
        self.assertEqual(plotted_with_y, expected_with_y)

        # Passing `values` directly is expected to match plotly without names.
        plotted_with_values = kdf.plot(kind="pie", values="a")
        expected_with_values = express.pie(pdf, values="a")
        self.assertEqual(plotted_with_values, expected_with_values)

    check_pie_plot(self.kdf1)

# TODO: support multi-index columns
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are not supported in plotly .. 😢

# columns = pd.MultiIndex.from_tuples([("x", "y"), ("y", "z")])
# kdf1.columns = columns
# check_pie_plot(kdf1)

# TODO: support multi-index
# kdf1 = ks.DataFrame(
# {
# "a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 50],
# "b": [2, 3, 4, 5, 7, 9, 10, 15, 34, 45, 49]
# },
# index=pd.MultiIndex.from_tuples([("x", "y")] * 11),
# )
# check_pie_plot(kdf1)
30 changes: 30 additions & 0 deletions databricks/koalas/tests/plot/test_series_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from distutils.version import LooseVersion

import pandas as pd
from plotly import express

from databricks import koalas as ks
from databricks.koalas.config import set_option, reset_option
Expand Down Expand Up @@ -98,3 +99,32 @@ def test_area_plot(self):

# just a sanity check for df.col type
self.assertEqual(pdf.sales.plot(kind="area"), kdf.sales.plot(kind="area"))

def test_pie_plot(self):
    """Check a Series pie plot against the figure built directly with plotly express."""
    koalas_frame = self.kdf1
    pandas_frame = koalas_frame.to_pandas()

    plotted = koalas_frame["a"].plot(kind="pie")
    expected = express.pie(
        pandas_frame, values=pandas_frame.columns[0], names=pandas_frame.index
    )
    self.assertEqual(plotted, expected)

# TODO: support multi-index columns
# columns = pd.MultiIndex.from_tuples([("x", "y")])
# kdf.columns = columns
# pdf.columns = columns
# self.assertEqual(
# kdf[("x", "y")].plot(kind="pie"),
# express.pie(pdf, values=pdf.iloc[:, 0].to_numpy(), names=pdf.index.to_numpy()),
# )

# TODO: support multi-index
# kdf = ks.DataFrame(
# {
# "a": [1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 50],
# "b": [2, 3, 4, 5, 7, 9, 10, 15, 34, 45, 49]
# },
# index=pd.MultiIndex.from_tuples([("x", "y")] * 11),
# )
# pdf = kdf.to_pandas()
# self.assertEqual(
# kdf["a"].plot(kind="pie"), express.pie(pdf, values=pdf.columns[0], names=pdf.index),
# )