@@ -41,11 +41,11 @@ def _get_standard_kind(kind):
4141
4242if LooseVersion (pd .__version__ ) < LooseVersion ('0.25' ):
4343 from pandas .plotting ._core import _all_kinds , BarPlot , BoxPlot , HistPlot , MPLPlot , PiePlot , \
44- AreaPlot , LinePlot , BarhPlot
44+ AreaPlot , LinePlot , BarhPlot , ScatterPlot
4545else :
4646 from pandas .plotting ._core import PlotAccessor
4747 from pandas .plotting ._matplotlib import BarPlot , BoxPlot , HistPlot , PiePlot , AreaPlot , \
48- LinePlot , BarhPlot
48+ LinePlot , BarhPlot , ScatterPlot
4949 from pandas .plotting ._matplotlib .core import MPLPlot
5050 _all_kinds = PlotAccessor ._all_kinds
5151
@@ -509,6 +509,16 @@ def _make_plot(self):
509509 super (KoalasBarhPlot , self )._make_plot ()
510510
511511
class KoalasScatterPlot(ScatterPlot, TopNPlot):
    """
    Scatter plot variant that reduces its input with ``get_top_n`` before
    handing it to the base ``ScatterPlot`` initializer.
    """

    def __init__(self, data, x, y, **kwargs):
        """
        Build the plot from the reduced data.

        Parameters
        ----------
        data : object accepted by ``get_top_n``
            Input data; it is passed through ``self.get_top_n`` first
            (presumably provided by ``TopNPlot`` -- confirm).
        x, y :
            Forwarded positionally to the parent initializer.
        **kwargs :
            Forwarded unchanged to the parent initializer.
        """
        super().__init__(self.get_top_n(data), x, y, **kwargs)

    def _make_plot(self):
        """Call ``set_result_text`` on the axes at index 0, then draw."""
        # The text hook runs before the parent draws anything.
        self.set_result_text(self._get_ax(0))
        super(KoalasScatterPlot, self)._make_plot()
520+
521+
512522_klasses = [
513523 KoalasHistPlot ,
514524 KoalasBarPlot ,
@@ -517,6 +527,7 @@ def _make_plot(self):
517527 KoalasAreaPlot ,
518528 KoalasLinePlot ,
519529 KoalasBarhPlot ,
530+ KoalasScatterPlot ,
520531]
521532_plot_klass = {getattr (klass , '_kind' ): klass for klass in _klasses }
522533
@@ -651,15 +662,20 @@ def _plot(data, x=None, y=None, subplots=False,
651662 else :
652663 raise ValueError ("%r is not a valid plot kind" % kind )
653664
654- # check data type and do preprocess before applying plot
655- if isinstance (data , DataFrame ):
656- if x is not None :
657- data = data .set_index (x )
658- # TODO: check if value of y is plottable
659- if y is not None :
660- data = data [y ]
665+ # scatter and hexbin derive from PlanePlot, which requires both x and y
666+ if kind in ('scatter' , 'hexbin' ):
667+ plot_obj = klass (data , x , y , subplots = subplots , ax = ax , kind = kind , ** kwds )
668+ else :
669+
670+ # check data type and do preprocess before applying plot
671+ if isinstance (data , DataFrame ):
672+ if x is not None :
673+ data = data .set_index (x )
674+ # TODO: check if value of y is plottable
675+ if y is not None :
676+ data = data [y ]
661677
662- plot_obj = klass (data , subplots = subplots , ax = ax , kind = kind , ** kwds )
678+ plot_obj = klass (data , subplots = subplots , ax = ax , kind = kind , ** kwds )
663679 plot_obj .generate ()
664680 plot_obj .draw ()
665681 return plot_obj .result
@@ -1082,8 +1098,41 @@ def box(self, bw_method=None, ind=None, **kwds):
def hist(self, bw_method=None, ind=None, **kwds):
    """
    Stand-in for ``pd.DataFrame``'s ``hist`` plot method.

    None of the arguments are used. The call is delegated to the callable
    returned by ``_unsupported_function`` (presumably it raises an
    "unsupported" error -- confirm against its definition).
    """
    return _unsupported_function(class_name='pd.DataFrame', method_name='hist')()
10841100
1085- def scatter (self , bw_method = None , ind = None , ** kwds ):
1086- return _unsupported_function (class_name = 'pd.DataFrame' , method_name = 'scatter' )()
def scatter(self, x, y, s=None, c=None, **kwds):
    """
    Create a scatter plot with varying marker point size and color.

    The coordinates of each point are defined by two dataframe columns and
    filled circles are used to represent each point. This kind of plot is
    useful to see complex correlations between two variables. Points could
    be for instance natural 2D coordinates like longitude and latitude in
    a map or, in general, any pair of metrics that can be plotted against
    each other.

    Parameters
    ----------
    x : int or str
        The column name or column position to be used as horizontal
        coordinates for each point.
    y : int or str
        The column name or column position to be used as vertical
        coordinates for each point.
    s : scalar or array_like, optional
        Marker size. Forwarded unchanged as the ``s`` keyword of the
        underlying plot call; ``None`` is passed through as-is.
    c : str, int or array_like, optional
        Marker color. Forwarded unchanged as the ``c`` keyword of the
        underlying plot call; ``None`` is passed through as-is.
    **kwds : optional
        Keyword arguments to pass on to :meth:`databricks.koalas.DataFrame.plot`.

    Returns
    -------
    :class:`matplotlib.axes.Axes` or numpy.ndarray of them

    See Also
    --------
    matplotlib.pyplot.scatter : Scatter plot using multiple input data
        formats.
    """
    # The accessor object itself is callable; scatter is just a preset kind.
    return self(kind="scatter", x=x, y=y, s=s, c=c, **kwds)
10871136
10881137
10891138def plot_frame (data , x = None , y = None , kind = 'line' , ax = None ,
@@ -1116,6 +1165,7 @@ def plot_frame(data, x=None, y=None, kind='line', ax=None,
11161165 - 'density' : same as 'kde'
11171166 - 'area' : area plot
11181167 - 'pie' : pie plot
1168+ - 'scatter' : scatter plot
11191169 ax : matplotlib axes object
11201170 If not passed, uses gca()
11211171 x : label or position, default None
0 commit comments