diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index a908de65362..9d1af9a53b1 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -20,7 +20,6 @@ from .facetgrid import FacetGrid from xarray.core.pycompat import basestring - def _valid_numpy_subdtype(x, numpy_types): """ Is any dtype from numpy_types superior to the dtype of x? @@ -164,6 +163,10 @@ def line(darray, *args, **kwargs): ---------- darray : DataArray Must be 1 dimensional + x : string, optional + Coordinate for x axis. If None and y not None: 1D plot vertically oriented. + y : string, optional + Coordinate for y axis. If None: 1D plot horizontally oriented. figsize : tuple, optional A tuple (width, height) of the figure in inches. Mutually exclusive with ``size`` and ``ax``. @@ -193,28 +196,138 @@ def line(darray, *args, **kwargs): aspect = kwargs.pop('aspect', None) size = kwargs.pop('size', None) ax = kwargs.pop('ax', None) + x = kwargs.pop('x', None) + y = kwargs.pop('y', None) - ax = get_axis(figsize, size, aspect, ax) - - xlabel, = darray.dims - x = darray.coords[xlabel] + ax = get_axis(figsize, size, aspect, ax) - _ensure_plottable(x) + #if isinstance(darray, DataArray): + if y != None: + y_arr = darray.coords[y] + _ensure_plottable(y_arr) + primitive = ax.plot(darray, y_arr, *args, **kwargs) + ax.set_ylabel(y) + else: + if x != None: + x_arr = darray.coords[x] + _ensure_plottable(x_arr) + primitive = ax.plot(x_arr, darray, *args, **kwargs) + ax.set_xlabel(x) + else: + xlabel, = darray.dims + x_arr = darray.coords[xlabel] + _ensure_plottable(x_arr) + primitive = ax.plot(x_arr, darray, *args, **kwargs) + ax.set_xlabel(xlabel) + + # Rotate dates on xlabels + if np.issubdtype(x_arr.dtype, np.datetime64): + plt.gcf().autofmt_xdate() + + """ + elif isinstance(darray, Dataset): + if x != None and y != None: + xval = darray[x].values + yval = darray[y].values + _ensure_plottable(xval, yval) + primitive = ax.plot(xval, yval, *args, **kwargs) + ax.set_xlabel(darray[x].name) + ax.set_ylabel(darray[y].name) + else: + raise ValueError('Two variables are necessary to lineplot Dataset') + - primitive = ax.plot(x, darray, *args, **kwargs) + else: + raise ValueError('Only DataArray and Dataset is acceptable') - ax.set_xlabel(xlabel) + """ ax.set_title(darray._title_for_slice()) if darray.name is not None: - ax.set_ylabel(darray.name) - - # Rotate dates on xlabels - if np.issubdtype(x.dtype, np.datetime64): - plt.gcf().autofmt_xdate() + if y != None: + ax.set_xlabel(darray.name) + else: + ax.set_ylabel(darray.name) return primitive +def errorbar(darray, *args, **kwargs): + """ + Errorbar plot of 1 dimensional DataArray index against values + + Wraps matplotlib.pyplot.plot + + Parameters + ---------- + darray : DataArray + Must be 1 dimensional + x : string, optional + Coordinate for x axis. If None and y not None: 1D plot vertically oriented. + y : string, optional + Coordinate for y axis. If None: 1D plot horizontally oriented. + figsize : tuple, optional + A tuple (width, height) of the figure in inches. + Mutually exclusive with ``size`` and ``ax``. + aspect : scalar, optional + Aspect ratio of plot, so that ``aspect * size`` gives the width in + inches. Only used if a ``size`` is provided. + size : scalar, optional + If provided, create a new figure for the plot with the given size. + Height (in inches) of each plot. See also: ``aspect``. + ax : matplotlib axes object, optional + Axis on which to plot this figure. By default, use the current axis. + Mutually exclusive with ``size`` and ``figsize``. + *args, **kwargs : optional + Additional arguments to matplotlib.pyplot.plot + + """ + plt = import_matplotlib_pyplot() + ndims = len(darray.dims) + if ndims != 1: + raise ValueError('Line plots are for 1 dimensional DataArrays. ' + 'Passed DataArray has {ndims} ' + 'dimensions'.format(ndims=ndims)) + + # Ensures consistency with .plot method + figsize = kwargs.pop('figsize', None) + aspect = kwargs.pop('aspect', None) + size = kwargs.pop('size', None) + ax = kwargs.pop('ax', None) + x = kwargs.pop('x', None) + y = kwargs.pop('y', None) + + ax = get_axis(figsize, size, aspect, ax) + + #if isinstance(darray, DataArray): + if y != None: + y_arr = darray.coords[y] + _ensure_plottable(y_arr) + primitive = ax.errorbar(darray, y_arr.values, *args, **kwargs) + ax.set_ylabel(y) + else: + if x != None: + x_arr = darray.coords[x] + _ensure_plottable(x_arr) + primitive = ax.errorbar(x_arr.values, darray, *args, **kwargs) + ax.set_xlabel(x) + else: + xlabel, = darray.dims + x_arr = darray.coords[xlabel] + _ensure_plottable(x_arr) + primitive = ax.errorbar(x_arr.values, darray, *args, **kwargs) + ax.set_xlabel(xlabel) + + # Rotate dates on xlabels + if np.issubdtype(x_arr.dtype, np.datetime64): + plt.gcf().autofmt_xdate() + + if darray.name is not None: + if y != None: + ax.set_xlabel(darray.name) + else: + ax.set_ylabel(darray.name) + + return primitive def hist(darray, figsize=None, size=None, aspect=None, ax=None, **kwargs): """ @@ -301,6 +414,9 @@ def hist(self, ax=None, **kwargs): def line(self, *args, **kwargs): return line(self._da, *args, **kwargs) + @functools.wraps(errorbar) + def errorbar(self, *args, **kwargs): + return errorbar(self._da, *args, **kwargs) def _plot2d(plotfunc): """