Skip to content

Commit a1efa61

Browse files
charlesdong1991HyukjinKwon
authored andcommitted
1 parent 5537d71 commit a1efa61

File tree

3 files changed

+83
-13
lines changed

3 files changed

+83
-13
lines changed

databricks/koalas/plot.py

Lines changed: 62 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@ def _get_standard_kind(kind):
4141

4242
if LooseVersion(pd.__version__) < LooseVersion('0.25'):
4343
from pandas.plotting._core import _all_kinds, BarPlot, BoxPlot, HistPlot, MPLPlot, PiePlot, \
44-
AreaPlot, LinePlot, BarhPlot
44+
AreaPlot, LinePlot, BarhPlot, ScatterPlot
4545
else:
4646
from pandas.plotting._core import PlotAccessor
4747
from pandas.plotting._matplotlib import BarPlot, BoxPlot, HistPlot, PiePlot, AreaPlot, \
48-
LinePlot, BarhPlot
48+
LinePlot, BarhPlot, ScatterPlot
4949
from pandas.plotting._matplotlib.core import MPLPlot
5050
_all_kinds = PlotAccessor._all_kinds
5151

@@ -509,6 +509,16 @@ def _make_plot(self):
509509
super(KoalasBarhPlot, self)._make_plot()
510510

511511

512+
class KoalasScatterPlot(ScatterPlot, TopNPlot):
513+
514+
def __init__(self, data, x, y, **kwargs):
515+
super().__init__(self.get_top_n(data), x, y, **kwargs)
516+
517+
def _make_plot(self):
518+
self.set_result_text(self._get_ax(0))
519+
super(KoalasScatterPlot, self)._make_plot()
520+
521+
512522
_klasses = [
513523
KoalasHistPlot,
514524
KoalasBarPlot,
@@ -517,6 +527,7 @@ def _make_plot(self):
517527
KoalasAreaPlot,
518528
KoalasLinePlot,
519529
KoalasBarhPlot,
530+
KoalasScatterPlot,
520531
]
521532
_plot_klass = {getattr(klass, '_kind'): klass for klass in _klasses}
522533

@@ -651,15 +662,20 @@ def _plot(data, x=None, y=None, subplots=False,
651662
else:
652663
raise ValueError("%r is not a valid plot kind" % kind)
653664

654-
# check data type and do preprocess before applying plot
655-
if isinstance(data, DataFrame):
656-
if x is not None:
657-
data = data.set_index(x)
658-
# TODO: check if value of y is plottable
659-
if y is not None:
660-
data = data[y]
665+
# scatter and hexbin are inherited from PlanePlot which require x and y
666+
if kind in ('scatter', 'hexbin'):
667+
plot_obj = klass(data, x, y, subplots=subplots, ax=ax, kind=kind, **kwds)
668+
else:
669+
670+
# check data type and do preprocess before applying plot
671+
if isinstance(data, DataFrame):
672+
if x is not None:
673+
data = data.set_index(x)
674+
# TODO: check if value of y is plottable
675+
if y is not None:
676+
data = data[y]
661677

662-
plot_obj = klass(data, subplots=subplots, ax=ax, kind=kind, **kwds)
678+
plot_obj = klass(data, subplots=subplots, ax=ax, kind=kind, **kwds)
663679
plot_obj.generate()
664680
plot_obj.draw()
665681
return plot_obj.result
@@ -1082,8 +1098,41 @@ def box(self, bw_method=None, ind=None, **kwds):
10821098
def hist(self, bw_method=None, ind=None, **kwds):
10831099
return _unsupported_function(class_name='pd.DataFrame', method_name='hist')()
10841100

1085-
def scatter(self, bw_method=None, ind=None, **kwds):
1086-
return _unsupported_function(class_name='pd.DataFrame', method_name='scatter')()
1101+
def scatter(self, x, y, s=None, c=None, **kwds):
1102+
"""
1103+
Create a scatter plot with varying marker point size and color.
1104+
1105+
The coordinates of each point are defined by two dataframe columns and
1106+
filled circles are used to represent each point. This kind of plot is
1107+
useful to see complex correlations between two variables. Points could
1108+
be for instance natural 2D coordinates like longitude and latitude in
1109+
a map or, in general, any pair of metrics that can be plotted against
1110+
each other.
1111+
1112+
Parameters
1113+
----------
1114+
x : int or str
1115+
The column name or column position to be used as horizontal
1116+
coordinates for each point.
1117+
y : int or str
1118+
The column name or column position to be used as vertical
1119+
coordinates for each point.
1120+
s : scalar or array_like, optional
1121+
c : str, int or array_like, optional
1122+
1123+
**kwds: Optional
1124+
Keyword arguments to pass on to :meth:`databricks.koalas.DataFrame.plot`.
1125+
1126+
Returns
1127+
-------
1128+
:class:`matplotlib.axes.Axes` or numpy.ndarray of them
1129+
1130+
See Also
1131+
--------
1132+
matplotlib.pyplot.scatter : Scatter plot using multiple input data
1133+
formats.
1134+
"""
1135+
return self(kind="scatter", x=x, y=y, s=s, c=c, **kwds)
10871136

10881137

10891138
def plot_frame(data, x=None, y=None, kind='line', ax=None,
@@ -1116,6 +1165,7 @@ def plot_frame(data, x=None, y=None, kind='line', ax=None,
11161165
- 'density' : same as 'kde'
11171166
- 'area' : area plot
11181167
- 'pie' : pie plot
1168+
- 'scatter' : scatter plot
11191169
ax : matplotlib axes object
11201170
If not passed, uses gca()
11211171
x : label or position, default None

databricks/koalas/tests/test_frame_plot.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import matplotlib
55
from matplotlib import pyplot as plt
66
import pandas as pd
7+
import numpy as np
78

89
from databricks import koalas
910
from databricks.koalas.exceptions import PandasNotImplementedError
@@ -173,10 +174,28 @@ def test_pie_plot_error_message(self):
173174
error_message = "pie requires either y column or 'subplots=True'"
174175
self.assertTrue(error_message in str(context.exception))
175176

177+
def test_scatter_plot(self):
178+
# Use pandas scatter plot example
179+
pdf = pd.DataFrame(np.random.rand(50, 4), columns=['a', 'b', 'c', 'd'])
180+
kdf = koalas.from_pandas(pdf)
181+
182+
ax1 = pdf.plot.scatter(x='a', y='b')
183+
ax2 = kdf.plot.scatter(x='a', y='b')
184+
self.compare_plots(ax1, ax2)
185+
186+
ax1 = pdf.plot(kind='scatter', x='a', y='b')
187+
ax2 = kdf.plot(kind='scatter', x='a', y='b')
188+
self.compare_plots(ax1, ax2)
189+
190+
# check when keyword c is given as name of a column
191+
ax1 = pdf.plot.scatter(x='a', y='b', c='c', s=50)
192+
ax2 = kdf.plot.scatter(x='a', y='b', c='c', s=50)
193+
self.compare_plots(ax1, ax2)
194+
176195
def test_missing(self):
177196
ks = self.kdf1
178197

179-
unsupported_functions = ['box', 'density', 'hexbin', 'hist', 'kde', 'scatter']
198+
unsupported_functions = ['box', 'density', 'hexbin', 'hist', 'kde']
180199

181200
for name in unsupported_functions:
182201
with self.assertRaisesRegex(PandasNotImplementedError,

docs/source/reference/frame.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,3 +237,4 @@ specific plotting methods of the form ``DataFrame.plot.<kind>``.
237237
DataFrame.plot.barh
238238
DataFrame.plot.bar
239239
DataFrame.plot.pie
240+
DataFrame.plot.scatter

0 commit comments

Comments
 (0)