# Source code for xrview.core.plot

import numpy as np
import pandas as pd
from bokeh.layouts import gridplot
from bokeh.models import FactorRange, Glyph, HoverTool
from bokeh.plotting import figure

from xrview.core.panel import BasePanel
from xrview.elements import Element
from xrview.glyphs import get_glyph_list
from xrview.handlers import DataHandler
from xrview.mappers import _get_overlay_figures, map_figures_and_glyphs
from xrview.palettes import RGB
from xrview.utils import clone_models, is_dataarray, is_dataset, rsetattr


class BasePlot(BasePanel):
    """ Base class for plots.

    Wraps an xarray Dataset and turns it into a bokeh grid layout.
    ``make_layout`` runs the individual build steps (handlers, figure and
    glyph maps, figures, glyphs, annotations, tooltips) in order and returns
    the resulting layout. Extra figures, overlays, annotations and figure
    modifiers are registered beforehand through ``add_figure``,
    ``add_overlay``, ``add_annotation`` and ``modify_figures``.
    """

    # Class instantiated by ``add_figure`` and ``add_overlay`` to wrap the
    # data that is added to the layout.
    element_type = Element
    # Class instantiated by ``_make_handlers`` around the collected DataFrame.
    handler_type = DataHandler
    # Tool string used when the constructor is called with ``tools=None``;
    # "hover," is appended to it there when tooltips are provided.
    default_tools = "pan,wheel_zoom,save,reset,"

    def __init__(
        self,
        data,
        x,
        overlay="dims",
        coords=None,
        glyphs="line",
        title=None,
        share_y=False,
        tooltips=None,
        tools=None,
        toolbar_location="right",
        figsize=(600, 300),
        ncols=1,
        palette=None,
        ignore_index=False,
        theme=None,
        **fig_kwargs,
    ):
        """ Constructor.

        Parameters
        ----------
        data : xarray DataArray or Dataset
            The data to display. A DataArray is converted to a Dataset; an
            unnamed DataArray is stored under the variable name ``"Data"``.

        x : str
            The name of the dimension in ``data`` that contains the x-axis
            values.

        overlay : 'dims' or 'data_vars', default 'dims'
            If 'dims', make one figure for each data variable and overlay the
            dimensions. If 'data_vars', make one figure for each dimension and
            overlay the data variables. In the latter case, all variables must
            have the same dimensions.

        coords : iterable of str or True, optional
            Names of coordinates to collect alongside the data variables.
            If True, every coordinate that has ``x`` as a dimension (except
            ``x`` itself) is collected.

        glyphs : str, BaseGlyph or iterable, default 'line'
            The glyph to use for plotting.

        title : str, optional
            Title that is handed to the figure/glyph mapping.

        share_y : bool, default False
            If True, every figure after the first one reuses the y range of
            the first figure.

        tooltips : dict, optional
            Names of tooltips mapping to glyph properties or source columns,
            e.g. ``{'datetime': '$x{%F %T.%3N}'}``.

        tools : str, optional
            bokeh tool string. If not provided, ``default_tools`` is used and
            the hover tool is added when ``tooltips`` is given.

        toolbar_location : str, default 'right'
            Location of the toolbar, passed on to the grid layout.

        figsize : iterable, default (600, 300)
            The size of the figure in pixels.

        ncols : int, default 1
            The number of columns of the layout.

        palette : iterable, optional
            The palette to use when overlaying multiple glyphs. Defaults to
            the ``RGB`` palette.

        ignore_index : bool, default False
            If True, replace the x-axis values of the data by an appropriate
            evenly spaced index.

        theme : optional
            Stored on the instance as ``self.theme``.
            NOTE(review): not read anywhere in this class - presumably
            consumed by a subclass or by the panel base class; confirm.

        **fig_kwargs
            Additional keyword arguments that are handed to the figure/glyph
            mapping (presumably ending up as bokeh figure arguments).

        Raises
        ------
        ValueError
            If ``data`` is neither a DataArray nor a Dataset, if ``x`` is not
            a dimension of the data or if ``overlay`` has an invalid value.
        """
        super().__init__()

        # check data
        if is_dataarray(data):
            if data.name is None:
                # an unnamed DataArray needs an explicit variable name
                self.data = data.to_dataset(name="Data")
            else:
                self.data = data.to_dataset()
        elif is_dataset(data):
            self.data = data
        else:
            raise ValueError("data must be xarray DataArray or Dataset.")

        # check x
        if x not in self.data.dims:
            # format instead of concatenating so that a non-string ``x``
            # (e.g. None) still produces the intended ValueError
            raise ValueError(
                f"{x} is not a dimension of the provided dataset."
            )
        self.x = x

        # check overlay
        if overlay not in ("dims", "data_vars"):
            raise ValueError('overlay must be "dims" or "data_vars"')
        self.overlay = overlay

        self.coords = coords

        self.glyphs = get_glyph_list(glyphs)

        self.tooltips = tooltips

        # layout parameters
        self.title = title
        self.share_y = share_y
        self.ncols = ncols
        self.figsize = figsize
        self.fig_kwargs = fig_kwargs
        self.theme = theme

        self.palette = RGB if palette is None else palette

        if tools is None:
            self.tools = self.default_tools
            # the hover tool is only added automatically to the default
            # tools; a user-provided tool string is taken as is
            if self.tooltips is not None:
                self.tools += "hover,"
        else:
            self.tools = tools

        self.toolbar_location = toolbar_location

        self.ignore_index = ignore_index

    def _collect_data(self, data, coords=None):
        """ Base method for _collect.

        Parameters
        ----------
        data : xarray Dataset
            The dataset whose variables are collected.

        coords : iterable of str, str or True, optional
            Names of coordinates to collect in addition to the data
            variables. If True, all coordinates that have ``self.x`` as a
            dimension (except ``self.x`` itself) are collected.

        Returns
        -------
        pandas.DataFrame
            One column per one-dimensional variable and one column named
            ``<variable>_<value>`` per value of the second dimension of a
            two-dimensional variable. For a MultiIndex x dimension, one
            additional column per index level is included.

        Raises
        ------
        ValueError
            If a variable does not have ``self.x`` as a dimension or has more
            than two dimensions.
        """
        plot_data = dict()

        if coords is True:
            coords = [
                c
                for c in data.coords
                if self.x in data[c].dims and c != self.x
            ]
        elif isinstance(coords, str):
            # a single name, do not split it into characters below
            coords = [coords]

        # list() accepts any iterable of names, e.g. a tuple
        names = list(data.data_vars) + list(coords or [])

        for v in names:
            dims = data[v].dims
            if self.x not in dims:
                raise ValueError(f"{self.x} is not a dimension of {v}")
            elif len(dims) == 1:
                plot_data[v] = data[v].values
            elif len(dims) == 2:
                # one column for each value along the non-x dimension
                dim = [d for d in dims if d != self.x][0]
                for d in data[dim].values:
                    plot_data[v + "_" + str(d)] = (
                        data[v].sel(**{dim: d}).values
                    )
            else:
                raise ValueError(f"{v} has too many dimensions")

        x_index = data.indexes[self.x]

        # TODO: doesn't work for irregularly sampled data
        if self.ignore_index:
            if isinstance(x_index, pd.DatetimeIndex):
                if x_index.freq is None:
                    # fall back to the spacing of the first two samples
                    freq = x_index[1] - x_index[0]
                else:
                    freq = x_index.freq
                index = pd.date_range(
                    start=0, freq=freq, periods=data.sizes[self.x]
                )
            else:
                index = np.arange(data.sizes[self.x])
        else:
            if isinstance(x_index, pd.MultiIndex):
                # Build the index from ``data`` (not ``self.data``) so that
                # it matches the collected columns when ``data`` has been
                # transformed before being passed in.
                index = [
                    tuple(str(i) for i in idx) for idx in x_index.tolist()
                ]
                for n in x_index.names:
                    plot_data[n] = x_index.get_level_values(n)
            else:
                index = data[self.x].values

        return pd.DataFrame(plot_data, index=index)

    def _collect(self, hooks=None, coords=None):
        """ Gather the plottable data of this plot in a pandas DataFrame.

        Every callable in ``hooks`` is applied in order to the dataset, each
        one receiving the result of the previous one, before the data is
        handed to ``_collect_data`` together with ``coords``.
        """
        transformed = self.data
        for hook in () if hooks is None else hooks:
            transformed = hook(transformed)
        return self._collect_data(transformed, coords=coords)

    def _attach_elements(self):
        """ Attach all added figures and overlays to this layout. """
        extra_elements = self.added_figures + self.added_overlays
        for extra in extra_elements:
            extra.attach(self)

    def _make_handlers(self):
        """ Create the list of data handlers.

        The first handler wraps the data of the plot itself, followed by the
        handlers of the added figures and then those of the added overlays.
        """
        own_frame = self._collect(coords=self.coords)
        self.handlers = [self.handler_type(own_frame)]
        self.handlers.extend(
            extra.handler
            for extra in self.added_figures + self.added_overlays
        )

    def _make_maps(self):
        """ Build the figure map and the glyph map of the layout. """
        # positional arguments, in the order map_figures_and_glyphs is called
        mapping_args = (
            self.data,
            self.x,
            self.handlers,
            self.glyphs,
            self.overlay,
            self.fig_kwargs,
            self.added_figures,
            self.added_overlays,
            self.added_overlay_figures,
            self.palette,
            self.title,
        )
        maps = map_figures_and_glyphs(*mapping_args)
        self.figure_map, self.glyph_map = maps

    def _make_figures(self):
        """ Make figures.

        Creates one bokeh figure for each row of ``self.figure_map`` and
        stores them in ``self.figures``. All figures share the x range of the
        first figure and, if ``self.share_y`` is set, also its y range.

        If ``self.figsize`` is given, it is treated as the size of the whole
        grid: each figure gets the width of one column and the height of one
        row, where the number of rows is the number of figures divided by
        ``self.ncols``, rounded up.
        """
        # TODO: check if we can put this in self.figure_map.figure
        self.figures = []

        # properties of the x index do not change between figures
        x_index = self.data.indexes[self.x]
        is_datetime = isinstance(x_index, pd.DatetimeIndex)
        is_multi = isinstance(x_index, pd.MultiIndex)
        n_figures = self.figure_map.shape[0]

        size_kwargs = {}
        if self.figsize is not None and n_figures > 0:
            # Number of grid rows (ceiling division). Dividing the height by
            # the number of figures and multiplying by ncols overshoots the
            # requested height when the last row is not full.
            n_rows = -(-n_figures // self.ncols)
            # NOTE(review): plot_width/plot_height are the names used by
            # older bokeh releases; newer ones call them width/height.
            size_kwargs = dict(
                plot_width=self.figsize[0] // self.ncols,
                plot_height=self.figsize[1] // n_rows,
            )

        for _, f in self.figure_map.iterrows():
            # adjust x axis type for datetime x values
            if is_datetime:
                f.fig_kwargs["x_axis_type"] = "datetime"

            # set axis ranges
            if not self.figures:
                if is_multi:
                    # categorical range with one factor per index entry
                    f.fig_kwargs["x_range"] = FactorRange(
                        *(
                            tuple(str(i) for i in idx)
                            for idx in x_index.tolist()
                        ),
                        range_padding=0.1,
                    )
            else:
                # link the ranges to those of the first figure
                f.fig_kwargs["x_range"] = self.figures[0].x_range
                if self.share_y:
                    f.fig_kwargs["y_range"] = self.figures[0].y_range

            f.fig_kwargs.update(size_kwargs)

            self.figures.append(figure(tools=self.tools, **f.fig_kwargs))

    def _add_glyphs(self):
        """ Draw every entry of the glyph map onto its figure. """
        for _, row in self.glyph_map.iterrows():
            kwargs = clone_models(row.glyph_kwargs)
            target = self.figures[row.figure]
            source = row.handler.source

            if isinstance(row.method, str):
                # name of a glyph method of the figure
                draw = getattr(target, row.method)
                draw(source=source, **kwargs)
            else:
                # callable whose result is attached as a layout
                target.add_layout(row.method(source=source, **kwargs))

            # add an invisible circle glyph to make glyph selectable
            if row.method != "circle":
                target.circle(
                    source=source,
                    size=0,
                    x=kwargs[row.x_arg],
                    y=kwargs[row.y_arg],
                )

    def _add_annotations(self):
        """ Put every added annotation on its target figures.

        bokeh ``Glyph`` instances are added with ``add_glyph``, everything
        else with ``add_layout``.
        """
        for position, annotation in enumerate(self.added_annotations):
            onto = self.added_annotation_figures[position]
            is_glyph = isinstance(annotation, Glyph)
            for fig_idx in _get_overlay_figures(onto, self.figure_map):
                target = self.figures[fig_idx]
                if is_glyph:
                    target.add_glyph(annotation)
                else:
                    target.add_layout(annotation)

    def _add_tooltips(self):
        """ Configure the hover tools of all figures with the tooltips. """
        if self.tooltips is None:
            return

        entries = list(self.tooltips.items())
        for fig in self.figures:
            hover = fig.select(HoverTool)
            hover.tooltips = entries
            # datetime x values need the datetime formatter for "$x"
            if isinstance(self.data.indexes[self.x], pd.DatetimeIndex):
                hover.formatters = {"$x": "datetime"}

    def _finalize_layout(self):
        """ Arrange the figures in a grid and store it as ``self.layout``. """
        grid_options = {
            "ncols": self.ncols,
            "toolbar_location": self.toolbar_location,
        }
        self.layout = gridplot(self.figures, **grid_options)

    def _modify_figures(self):
        """ Apply all registered modifiers to their figures.

        Each entry of ``self.modifiers`` pairs a figure selection with a set
        of modifiers. The selection is either None (all figures), a single
        index or an iterable of indexes.
        """
        for selection, modifiers in self.modifiers:
            if selection is None:
                targets = self.figures
            elif isinstance(selection, int):
                targets = [self.figures[selection]]
            else:
                targets = [self.figures[position] for position in selection]

            for target in targets:
                self._modify_figure(modifiers, target)

    def _modify_figure(self, modifiers, f):
        """ Set every attribute named in ``modifiers`` on the figure ``f``.

        The keys are passed to ``rsetattr`` together with their values.
        """
        for attribute, value in modifiers.items():
            rsetattr(f, attribute, value)

    def make_layout(self):
        """ Build the layout by running all build steps and return it. """
        build_steps = (
            self._attach_elements,
            self._make_handlers,
            self._make_maps,
            self._make_figures,
            self._modify_figures,
            self._add_glyphs,
            self._add_annotations,
            self._add_tooltips,
            self._finalize_layout,
        )
        # the order matters: later steps use what earlier steps created
        for step in build_steps:
            step()

        return self.layout

    def add_figure(self, data, glyphs="line", coords=None, name=None):
        """ Register an additional figure for the layout.

        Parameters
        ----------
        data : xarray.DataArray
            The data shown in the new figure.

        glyphs : str, BaseGlyph or iterable thereof, default 'line'
            The glyph (or glyphs) used to draw the data.

        coords : iterable of str, optional
            The coordinates of the DataArray to include. Composite glyphs
            such as BoxWhisker need this.

        name : str, optional
            The title of the figure. If not provided, the name of the
            DataArray is used.
        """
        self.added_figures.append(
            self.element_type(glyphs, data, coords, name)
        )

    def add_overlay(
        self, data, glyphs="line", coords=None, name=None, onto=None
    ):
        """ Register an overlay for one or all figures of the layout.

        Parameters
        ----------
        data : xarray.DataArray
            The data shown in the overlay.

        glyphs : str, BaseGlyph or iterable thereof, default 'line'
            The glyph (or glyphs) used to draw the data.

        coords : iterable of str, optional
            The coordinates of the DataArray to include. Composite glyphs
            such as BoxWhisker need this.

        name : str, optional
            The name identifying the overlay. If not provided, the name of
            the DataArray is used.

        onto : str or int, optional
            Title or index of the figure that receives the overlay. By
            default, the overlay is drawn on all figures.
        """
        overlay = self.element_type(glyphs, data, coords, name)
        # both lists are kept in step: entry i of the targets belongs to
        # overlay i
        self.added_overlays.append(overlay)
        self.added_overlay_figures.append(onto)

    def add_annotation(self, annotation, onto=None):
        """ Register an annotation for one or all figures of the layout.

        Parameters
        ----------
        annotation : bokeh model
            The annotation to add. When the layout is built, bokeh ``Glyph``
            instances are added to the figures as glyphs, any other object is
            added as a layout.

        onto : str or int, optional
            Title or index of the figure that receives the annotation. By
            default, the annotation is put on all figures.
        """
        # both lists are kept in step: entry i of the targets belongs to
        # annotation i
        annotations, targets = (
            self.added_annotations,
            self.added_annotation_figures,
        )
        annotations.append(annotation)
        targets.append(onto)

    def modify_figures(self, modifiers, figures=None):
        """ Register attribute modifications for one or more figures.

        The modifications are stored and applied when the layout is built.

        Parameters
        ----------
        modifiers : dict
            The attributes to modify mapped to their new values. Keys can
            reference sub-attributes, e.g. 'xaxis.axis_label'.

        figures : int or iterable of int, optional
            The index(es) of the figure(s) to modify. By default, all figures
            are modified.
        """
        entry = (figures, modifiers)
        self.modifiers.append(entry)