diff --git a/src/mplfinance/_arg_validators.py b/src/mplfinance/_arg_validators.py index dc0465b1..2ee5a3c7 100644 --- a/src/mplfinance/_arg_validators.py +++ b/src/mplfinance/_arg_validators.py @@ -57,6 +57,16 @@ def _check_and_prepare_data(data, config): return dates, opens, highs, lows, closes, volumes +def _label_validator(label_value): + if isinstance(label_value,str): + return True + elif not isinstance(label_value,tuple) and not isinstance(label_value,list): + return False + + for label in label_value: + if not isinstance(label,str): + return False + return True def _mav_validator(mav_value): ''' diff --git a/src/mplfinance/plotting.py b/src/mplfinance/plotting.py index 931cd87e..2d7a247f 100644 --- a/src/mplfinance/plotting.py +++ b/src/mplfinance/plotting.py @@ -29,7 +29,7 @@ from mplfinance import _styles -from mplfinance._arg_validators import _check_and_prepare_data, _mav_validator +from mplfinance._arg_validators import _check_and_prepare_data, _mav_validator, _label_validator from mplfinance._arg_validators import _process_kwargs, _validate_vkwargs_dict from mplfinance._arg_validators import _kwarg_not_implemented, _bypass_kwarg_validation from mplfinance._arg_validators import _hlines_validator, _vlines_validator @@ -148,7 +148,7 @@ def _valid_plot_kwargs(): 'Validator' : lambda value: mcolors.is_color_like(value) }, 'title' : { 'Default' : None, # Figure Title - 'Validator' : lambda value: isinstance(value,(str,dict)) }, + 'Validator' : lambda value: isinstance(value,str) }, 'axtitle' : { 'Default' : None, # Axes Title (subplot title) 'Validator' : lambda value: isinstance(value,str) }, @@ -624,28 +624,16 @@ def plot( data, **kwargs ): if external_axes_mode: volumeAxes.tick_params(axis='x',rotation=xrotation) volumeAxes.xaxis.set_major_formatter(formatter) - + if config['title'] is not None: if config['tight_layout']: # IMPORTANT: 0.89 is based on the top of the top panel # being at 0.18+0.7 = 0.88. See _panels.py # If the value changes there, then it needs to change here. - title_kwargs = dict(size='x-large',weight='semibold', va='bottom', y=0.89) - else: - title_kwargs = dict(size='x-large',weight='semibold', va='center') - if isinstance(config['title'],dict): - title_dict = config['title'] - if 'title' not in title_dict: - raise ValueError('Must have "title" entry in title dict') - else: - title = title_dict['title'] - del title_dict['title'] - title_kwargs.update(title_dict) # allows override default values set by mplfinance above + fig.suptitle(config['title'],size='x-large',weight='semibold', va='bottom', y=0.89) else: - title = config['title'] # config['title'] is a string - fig.suptitle(title,**title_kwargs) - - + fig.suptitle(config['title'],size='x-large',weight='semibold', va='center') + if config['axtitle'] is not None: axA1.set_title(config['axtitle']) @@ -787,22 +775,29 @@ def _addplot_columns(panid,panels,ydata,apdict,xdates,config): mark = apdict['marker'] color = apdict['color'] alpha = apdict['alpha'] + labels = apdict["labels"] if isinstance(mark,(list,tuple,np.ndarray)): - _mscatter(xdates,ydata,ax=ax,m=mark,s=size,color=color,alpha=alpha) + _mscatter(xdates,ydata,ax=ax,m=mark,s=size,color=color,alpha=alpha) #labels in this function needs to be added else: ax.scatter(xdates,ydata,s=size,marker=mark,color=color,alpha=alpha) + if apdict["labels"] is not None: + ax.legend(labels=apdict["labels"]) elif aptype == 'bar': width = 0.8 if apdict['width'] is None else apdict['width'] bottom = apdict['bottom'] color = apdict['color'] alpha = apdict['alpha'] ax.bar(xdates,ydata,width=width,bottom=bottom,color=color,alpha=alpha) + if apdict["labels"] is not None: + ax.legend(labels=apdict["labels"]) elif aptype == 'line': ls = apdict['linestyle'] color = apdict['color'] width = apdict['width'] if apdict['width'] is not None else 1.6*config['_width_config']['line_width'] alpha = apdict['alpha'] ax.plot(xdates,ydata,linestyle=ls,color=color,linewidth=width,alpha=alpha) + if apdict["labels"] is not None: + ax.legend(labels=apdict["labels"]) else: raise ValueError('addplot type "'+str(aptype)+'" NOT yet supported.') @@ -885,6 +880,9 @@ def _valid_addplot_kwargs(): valid_types = ('line','scatter','bar', 'ohlc', 'candle') vkwargs = { + "labels" : { "Default" : None, + "Validator" : _label_validator }, + 'scatter' : { 'Default' : False, 'Validator' : lambda value: isinstance(value,bool) },