Skip to content

API for adding labels: mpf.make_addplot(..., label="myLabel") #605

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 1, 2023
Merged
595 changes: 595 additions & 0 deletions examples/addplot_legends.ipynb

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions src/mplfinance/_arg_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import matplotlib as mpl
import warnings


def _check_and_prepare_data(data, config):
'''
Check and Prepare the data input:
Expand Down Expand Up @@ -94,6 +95,16 @@ def _check_and_prepare_data(data, config):

return dates, opens, highs, lows, closes, volumes


def _label_validator(label_value):
''' Validates the input of label for the added plots.
label_value may be a str or a list of str.
'''
if isinstance(label_value,str):
return True
return False


def _get_valid_plot_types(plottype=None):

_alias_types = {
Expand Down
2 changes: 1 addition & 1 deletion src/mplfinance/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version_info = (0, 12, 9, 'beta', 8)
version_info = (0, 12, 9, 'beta', 9)

_specifier_ = {'alpha': 'a','beta': 'b','candidate': 'rc','final': ''}

Expand Down
22 changes: 16 additions & 6 deletions src/mplfinance/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from mplfinance import _styles

from mplfinance._arg_validators import _check_and_prepare_data, _mav_validator
from mplfinance._arg_validators import _check_and_prepare_data, _mav_validator, _label_validator
from mplfinance._arg_validators import _get_valid_plot_types, _fill_between_validator
from mplfinance._arg_validators import _process_kwargs, _validate_vkwargs_dict
from mplfinance._arg_validators import _kwarg_not_implemented, _bypass_kwarg_validation
Expand Down Expand Up @@ -752,6 +752,8 @@ def plot( data, **kwargs ):

elif not _list_of_dict(addplot):
raise TypeError('addplot must be `dict`, or `list of dict`, NOT '+str(type(addplot)))

contains_legend_label=[] # a list of axes that contains legend labels

for apdict in addplot:

Expand Down Expand Up @@ -779,6 +781,10 @@ def plot( data, **kwargs ):
ydata = apdata.loc[:,column] if havedf else column
ax = _addplot_columns(panid,panels,ydata,apdict,xdates,config)
_addplot_apply_supplements(ax,apdict,xdates)
if apdict["label"]: # not supported for aptype == 'ohlc' or 'candle'
contains_legend_label.append(ax)
for ax in set(contains_legend_label): # there might be duplicates
ax.legend()

# fill_between is NOT supported for external_axes_mode
# (caller can easily call ax.fill_between() themselves).
Expand Down Expand Up @@ -1088,6 +1094,7 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config):
ax = apdict['ax']

aptype = apdict['type']
label = apdict['label']
if aptype == 'scatter':
size = apdict['markersize']
mark = apdict['marker']
Expand All @@ -1098,27 +1105,27 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config):

if isinstance(mark,(list,tuple,np.ndarray)):
_mscatter(xdates, ydata, ax=ax, m=mark, s=size, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths)
else:
ax.scatter(xdates, ydata, s=size, marker=mark, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths)
else:
ax.scatter(xdates, ydata, s=size, marker=mark, color=color, alpha=alpha, edgecolors=edgecolors, linewidths=linewidths,label=label)
elif aptype == 'bar':
width = 0.8 if apdict['width'] is None else apdict['width']
bottom = apdict['bottom']
color = apdict['color']
alpha = apdict['alpha']
ax.bar(xdates,ydata,width=width,bottom=bottom,color=color,alpha=alpha)
ax.bar(xdates,ydata,width=width,bottom=bottom,color=color,alpha=alpha,label=label)
elif aptype == 'line':
ls = apdict['linestyle']
color = apdict['color']
width = apdict['width'] if apdict['width'] is not None else 1.6*config['_width_config']['line_width']
alpha = apdict['alpha']
ax.plot(xdates,ydata,linestyle=ls,color=color,linewidth=width,alpha=alpha)
ax.plot(xdates,ydata,linestyle=ls,color=color,linewidth=width,alpha=alpha,label=label)
elif aptype == 'step':
stepwhere = apdict['stepwhere']
ls = apdict['linestyle']
color = apdict['color']
width = apdict['width'] if apdict['width'] is not None else 1.6*config['_width_config']['line_width']
alpha = apdict['alpha']
ax.step(xdates,ydata,where = stepwhere,linestyle=ls,color=color,linewidth=width,alpha=alpha)
ax.step(xdates,ydata,where = stepwhere,linestyle=ls,color=color,linewidth=width,alpha=alpha,label=label)
else:
raise ValueError('addplot type "'+str(aptype)+'" NOT yet supported.')

Expand Down Expand Up @@ -1371,6 +1378,9 @@ def _valid_addplot_kwargs():
'fill_between': { 'Default' : None, # added by Wen
'Description' : " fill region",
'Validator' : _fill_between_validator },
"label" : { 'Default' : None,
'Description' : 'Label for the added plot. One per added plot.',
'Validator' : _label_validator },

}

Expand Down