#!/usr/bin/env python
# Part of the psychopy_ext library
# Copyright 2010-2014 Jonas Kubilius
# The program is distributed under the terms of the GNU General Public License,
# either version 3 of the License, or (at your option) any later version.
"""
A wrapper of matplotlib for producing pretty plots by default. As `pandas`
evolves, some of these improvements will hopefully be merged into it.

Usage::

    import plot
    plt = plot.Plot(nrows_ncols=(1,2))
    plt.plot(data)  # plots data on the first subplot
    plt.plot(data2)  # plots data on the second subplot
    plt.show()
"""
import fractions
import numpy as np
import scipy.stats
import pandas
import pandas.tools.plotting  # for rcParams
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib.gridspec as gridspec
from matplotlib.patches import Rectangle
from matplotlib.ticker import MultipleLocator
# seaborn is optional: when it is importable its default style is used,
# otherwise matplotlib is restyled below.
try:
    import seaborn as sns   # hope you have it
    _has_seaborn = True
except Exception:  # ok, stick to your ugly matplotlib then
    # Importing can fail for reasons other than the package being absent,
    # so any ordinary error counts as "not available"; KeyboardInterrupt and
    # SystemExit are deliberately not swallowed.
    _has_seaborn = False  # was left undefined on this path before
    # but I'm still gonna improve it using the ggplot style
    # from https://gist.github.com/huyng/816622
    # inspiration from mpltools
    rc_params = pandas.tools.plotting.mpl_stylesheet
    rc_params['interactive'] = False  # doesn't display otherwise
    plt.rcParams.update(rc_params)
import stats
class Plot(object):
[docs]    def __init__(self, kind='', gridtype='', figsize=None, nrows=1, ncols=1, **kwargs):
        self._create_subplots(kind=kind, gridtype=gridtype, figsize=figsize,
                              nrows=nrows, ncols=ncols, **kwargs)
 
    def _create_subplots(self, kind='', gridtype='', figsize=None,
                         nrows=1, ncols=1, **kwargs):
        """
        Create (or re-create) the figure and its grid of axes.

        :Kwargs:
            - kind (str, default: '')
                The kind of plot. `matrix` lays the axes out with
                `ImageGrid`, exactly like `gridtype='imagegrid'`; any other
                value only gets stored in `self.kind`.
            - gridtype ({'', 'imagegrid', 'gridspec'}, default: '')
                How the axes are created (case-insensitive): `imagegrid`
                for matrices or images (`matplotlib.pyplot.imshow`),
                `gridspec` for customizing subplot location and aspect
                ratios, otherwise leave blank for simple subplots.
            - figsize (tuple, default: None)
                Size of the figure.
            - nrows (int, default: 1)
                Number of subplots vertically.
            - ncols (int, default: 1)
                Number of subplots horizontally.
            - **kwargs
                Keyword arguments passed on to `ImageGrid`,
                `matplotlib.gridspec.GridSpec` or
                `matplotlib.pyplot.subplots`. Differences:
                    - `nrows_ncols`, if given, overrides `nrows`/`ncols`
                    - for `ImageGrid`, `rect` is a keyword argument here
                      (default: 111), `sharex` and `sharey` both being
                      true is turned into `share_all = True`, and the
                      defaults are `label_mode = 'L'`, `axes_pad = .5`,
                      `share_all = True` and, when `share_all` is on,
                      `cbar_mode = 'single'`
                    - for `GridSpec`, `width_ratios` and `height_ratios`
                      default to None

        :Returns:
            The figure and the axes (an `ImageGrid`, a list, or a flat
            array, depending on the layout).
        """
        # NOTE: this is the very dict that is modified below, so what gets
        # saved includes the defaults filled in here and lacks the keys that
        # are removed (`nrows_ncols`, `rect`, ...).
        self._subplots_kwargs = kwargs
        self.figsize = figsize

        # an explicit (nrows, ncols) tuple wins over the separate arguments
        nrows_ncols = kwargs.pop('nrows_ncols', (nrows, ncols))

        # when a figure already exists, clear it and reuse its number
        try:
            num = self.fig.number
            self.fig.clf()
        except Exception:  # no usable figure yet
            num = None

        if kind == 'matrix' or gridtype.lower() == 'imagegrid':
            # ImageGrid takes a single `share_all` switch
            if kwargs.get('sharex') and kwargs.get('sharey'):
                kwargs['share_all'] = True
                del kwargs['sharex']
                del kwargs['sharey']
            self.fig = self.figure(figsize=figsize, num=num)
            kwargs.setdefault('label_mode', "L")
            kwargs.setdefault('axes_pad', .5)
            kwargs.setdefault('share_all', True)
            if kwargs['share_all']:
                kwargs.setdefault('cbar_mode', "single")
            rect = kwargs.pop('rect', 111)
            self.axes = ImageGrid(self.fig, rect,
                                  nrows_ncols=nrows_ncols,
                                  **kwargs
                                  )
            self.naxes = len(self.axes.axes_all)
        elif gridtype.lower() == 'gridspec':
            # useful for specifying subplot composition and
            # no sharex, sharey support
            kwargs.setdefault('width_ratios', None)
            kwargs.setdefault('height_ratios', None)
            self.fig = self.figure(figsize=figsize, num=num)
            gs = gridspec.GridSpec(nrows_ncols[0], nrows_ncols[1], **kwargs)
            self.axes = [plt.subplot(s) for s in gs]
            self.naxes = len(self.axes)
        else:
            self.fig, self.axes = plt.subplots(
                nrows=nrows_ncols[0],
                ncols=nrows_ncols[1],
                figsize=figsize,
                num=num,
                **kwargs
                )
            try:
                self.axes = self.axes.ravel()  # turn axes into a flat array
            except AttributeError:  # got a single axis, not an array
                self.axes = [self.axes]
            self.naxes = len(self.axes)

        self.kind = kind
        self.subplotno = -1  # will get +1 after the plot command
        self.nrows = nrows_ncols[0]
        self.ncols = nrows_ncols[1]
        # read after the ImageGrid branch may have removed these keys
        self.sharex = kwargs.get('sharex', False)
        self.sharey = kwargs.get('sharey', False)
        self.rcParams = plt.rcParams
        return (self.fig, self.axes)
    def __getattr__(self, name):
        """Pass on a `seaborn` or `matplotlib` function that we haven't modified
        """
        def method(*args, **kwargs):
            try:
                return getattr(sns, name)(*args, **kwargs)
            except:
                try:
                    return getattr(plt, name)(*args, **kwargs)
                except:
                    return None
        meth = method  # is it a function?
        if meth is None:  # maybe it's just a self variable
            return getattr(self, name)
        else:
            return meth
    def __getitem__(self, key):
        """Allow to get axes as Plot()[key]
        """
        if key > self.naxes:
            raise IndexError
        if key < 0:
            key += self.naxes
        return self.axes[key]
[docs]    def get_ax(self, subplotno=None):
        """
        Returns the current or the requested axis from the current figure.
        .. note: The :class:`Plot()` is indexable so you should access axes as
                 `Plot()[key]` unless you want to pass a list like (row, col).
        :Kwargs:
            subplotno (int, default: None)
                Give subplot number explicitly if you want to get not the
                current axis
        :Returns:
            ax
        """
        if subplotno is None:
            no = self.subplotno
        else:
            no = subplotno
        if isinstance(no, int):
            try:
                ax = self.axes[no]
            except:  # a single subplot
                ax = self.axes
        else:
            if no[0] < 0: no += len(self.axes._nrows)
            if no[1] < 0: no += len(self.axes._ncols)
            if isinstance(self.axes, ImageGrid):  # axes are a list
                if self.axes._direction == 'row':
                    no = self.axes._ncols * no[0] + no[1]
                else:
                    no = self.axes._nrows * no[0] + no[1]
            else:  # axes are a grid
                no = self.axes._ncols * no[0] + no[1]
            ax = self.axes[no]
        return ax
 
[docs]    def next(self):
        """
        Returns the next axis.
        This is useful when a plotting function is not implemented by
        :mod:`plot` and you have to instead rely on matplotlib's plotting
        which does not advance axes automatically.
        """
        self.subplotno += 1
        return self.get_ax()
 
[docs]    def sample_paired(self, ncolors=2):
        """
        Returns colors for matplotlib.cm.Paired.
        """
        if ncolors <= 12:
            colors_full = [mpl.cm.Paired(i * 1. / 11) for i in range(1, 12, 2)]
            colors_pale = [mpl.cm.Paired(i * 1. / 11) for i in range(10, -1, -2)]
            colors = colors_full + colors_pale
            return colors[:ncolors]
        else:
            return [mpl.cm.Paired(c) for c in np.linspace(0,ncolors)]
 
[docs]    def get_colors(self, ncolors=2, cmap='Paired'):
        """
        Get a list of nice colors for plots.
        FIX: This function is happy to ignore the ugly settings you may have in
        your matplotlibrc settings.
        TODO: merge with mpltools.color
        :Kwargs:
            ncolors (int, default: 2)
                Number of colors required. Typically it should be the number of
                entries in the legend.
            cmap (str or matplotlib.cm, default: 'Paired')
                A colormap to sample from when ncolors > 12
        :Returns:
            a list of colors
        """
        colorc = plt.rcParams['axes.color_cycle']
        if ncolors <= len(colorc):
            colors = colorc[:ncolors]
        elif ncolors <= 12:
            colors = self.sample_paired(ncolors=ncolors)
        else:
            thisCmap = mpl.cm.get_cmap(cmap)
            norm = mpl.colors.Normalize(0, 1)
            z = np.linspace(0, 1, ncolors + 2)
            z = z[1:-1]
            colors = thisCmap(norm(z))
        return colors
 
    def pivot_plot(self,df,rows=None,cols=None,values=None,yerr=None,
                   **kwargs):
        agg = self.aggregate(df, rows=rows, cols=cols,
                                 values=values, yerr=yerr)
        if yerr is None:
            no_yerr = True
        else:
            no_yerr = False
        return self._plot(agg, no_yerr=no_yerr,**kwargs)
    def _plot(self, agg, ax=None,
                   title='', kind='bar', xtickson=True, ytickson=True,
                   no_yerr=False, numb=False, autoscale=True, **kwargs):
        """DEPRECATED plotting function"""
        print "plot._plot() has been DEPRECATED; please don't use it anymore"
        self.plot(agg, ax=ax,
                   title=title, kind=kind, xtickson=xtickson, ytickson=ytickson,
                   no_yerr=no_yerr, numb=numb, autoscale=autoscale, **kwargs)
[docs]    def plot(self, agg, kind='bar', subplots=None, subplots_order=None,
            autoscale=True, title=None, errkind='sem', within=None,
            xlim=None, ylim=None, xlabel=None, ylabel=None, popmean=0,
            numb=False, **kwargs):
        """
        The main plotting function.
        :Args:
            agg (`pandas.DataFrame` or similar)
                A structured input, preferably a `pandas.DataFrame`, but in
                principle accepts anything that can be converted into it.
        :Kwargs:
            - subplots (None, True, or False; default=None)
                Whether you want to split data into subplots or not. If True,
                the top level is treated as a subplot. If None, detects
                automatically based on `agg.columns.names` -- the first entry
                to start with `subplots.` will be used. This is the default
                output from `stats.aggregate` and is recommended.
            - kwargs
                Keyword arguments for plotting
        :Returns:
            A list of axes of all plots.
        """
        #if isinstance(agg, (list, tuple)):
            #agg = np.array(agg)
        try:
            values_name = agg.names
        except:
            values_name = ''
        if len(self.get_fignums()) == 0:
            self.draw()
        #if not isinstance(agg, pandas.DataFrame):
            #agg = pandas.DataFrame(agg)
            #if agg.shape[1] == 1:  # Series
                #agg = pandas.DataFrame(agg).T
        #else:
            #agg = pandas.DataFrame(agg)
        axes = []
        if subplots_order is not None:
            sbp = subplots_order
        elif 'subplots' in agg._splits and subplots!=False:
            sbp = agg.columns.get_level_values(agg._splits['subplots'][0]).unique()
        else:
            sbp = None
            #try:
                #s_idx = [s for s,n in enumerate(agg.columns.names) if n.startswith('subplots.')]
            #except:
                #s_idx = None
            #if s_idx is not None:  # subplots implicit in agg
                #try:
                    #sbp = agg.columns.get_level_values(s_idx[0]).unique() #agg.columns.levels[s_idx[0]]
                #except:
                    #if len(s_idx) > 0:
                        #sbp = agg.columns
                    #else:
                        #sbp = None
            #elif subplots:  # get subplots from the top level column
                #sbp = agg.columns.get_level_values(0).unique() #agg.columns.levels[0]
            #else:
                #sbp = None
        #import pdb; pdb.set_trace()
        if sbp is None:
            axes = self._plot_ax(agg, kind=kind, errkind=errkind, within=within, **kwargs)
            agg.names = values_name
            axes.agg = agg
            #axes, xmin, xmax, ymin, ymax = self._label_ax(agg, mean, p_yerr, axes, kind=kind,
                                    #autoscale=autoscale, **kwargs)
            axes = [axes]
            if title is not None:
                axes[0].set_title(title)
            #if 'title' in kwargs:
                #if kwargs['title'] is not None:
                    #axes[0].set_title(kwargs['title'])
            #if 'title' not in kwargs:
                #kwargs['title']
            #else:
                #title = ''
        else:
            # if we haven't made any plots yet...
            #import pdb; pdb.set_trace()
            if self.subplotno == -1:
                num_subplots = len(sbp)
                # ...can still adjust the number of subplots
                if num_subplots > self.naxes:
                    if 'sharex' not in self._subplots_kwargs:
                        self._subplots_kwargs['sharex'] = True
                    if 'sharey' not in self._subplots_kwargs:
                        self._subplots_kwargs['sharey'] = True
                    self._create_subplots(ncols=num_subplots, kind=kind,
                        figsize=self.figsize, **self._subplots_kwargs)
            #if 'share_all' in kwargs:
                #if kwargs['share_all']:
                    #norm = (
            for no, subname in enumerate(sbp):
                if title is None:
                    sbp_title = subname
                else:
                    sbp_title = title
                #if 'title' not in kwargs:
                    #title = subname
                #else:
                    #title = kwargs['title']
                    #if title is None:
                        #title = subname
                split = agg[subname]
                split._splits = agg._splits
                ax = self._plot_ax(split, kind=kind, errkind=errkind,
                                   within=within, **kwargs)
                #ax, xmin, xmax, ymin, ymax = self._label_ax(agg[subname],
                                        #mean, p_yerr, ax, kind=kind,
                                        #legend=legend, autoscale=autoscale, **kwargs)
                ax.agg = agg[subname]
                ax.agg.names = values_name
                ax.set_title(sbp_title)
                #ax.mean = mean
                #xmins.append(xmin)
                #xmaxs.append(xmax)
                #ymins.append(ymin)
                #ymaxs.append(ymax)
                axes.append(ax)
        #if 'ylabel' not in kwargs:
            #kwargs['ylabel'] = values_name
        self.decorate(axes, kind=kind, xlim=xlim, ylim=ylim,
                 xlabel=xlabel, ylabel=ylabel, popmean=popmean, numb=numb,
                 within=within, errkind=errkind)
        if len(axes) == 1:
            return axes[0]
        else:
            return axes
 
[docs]    def show(self, tight=True):
        if tight and self.kind != 'matrix':
            self.tight_layout()
        plt.show()
 
    def decorate(self, axes, kind='bar', xlim=None, ylim=None,
                 xlabel=None, ylabel=None, popmean=0,
                 within=None, errkind='sem',
                 numb=False):
        """
        Set limits, ticks, labels, legends and significance marks on axes
        that have already been drawn by :func:`_plot_ax`.

        :Args:
            axes (list)
                Axes to decorate; each must carry the `mean`, `p_yerr` and
                `agg` attributes.
        :Kwargs:
            - kind (str, default: 'bar')
                Kind of plot that was drawn.
            - xlim, ylim (pair, default: None)
                Explicit axis limits; override the computed ones.
            - xlabel, ylabel (str, default: None)
                Explicit axis labels; by default they are derived from the
                data.
            - popmean, within, errkind
                Passed on to `draw_sig` (bar plots only).
            - numb (bool, default: False)
                Whether to add the subplot number as an inner title.
        :Returns:
            `axes`. `self.sig_t` and `self.sig_p` are also set when the
            axes provide significance tables.
        """
        # per axis: row 0 holds the x limits, row 1 the y limits and, for
        # matrices, row 2 the range of the plotted values
        if kind == 'matrix':
            lims = np.zeros((len(axes),3,2))
        else:
            lims = np.zeros((len(axes),2,2))

        for i, ax in enumerate(axes):
            if kind in ['scatter', 'mds']:
                lims[i,0] = ax.get_xlim()
                lims[i,1] = ax.get_ylim()
                # pad by one tick interval on every side
                xticks = ax.get_xticks()
                min_space = abs(xticks[1] - xticks[0])
                lims[i,0,0] -= min_space
                lims[i,0,1] += min_space
                yticks = ax.get_yticks()
                min_space = abs(yticks[1] - yticks[0])
                lims[i,1,0] -= min_space
                lims[i,1,1] += min_space
            elif kind in ['histogram', 'bean']:
                lims[i,0] = ax.get_xlim()
                lims[i,1] = ax.get_ylim()
            elif kind == 'matrix':
                lims[i,0] = ax.get_xlim()
                lims[i,1] = ax.get_ylim()  # used to store the x limits again
                lims[i,2,0] = ax.mean.min().min()
                lims[i,2,1] = ax.mean.max().max()
            else:
                lims[i,0] = ax.get_xlim()
                if kind == 'bar':
                    lims[i,0,0] -= .25  # add some padding from both sides
                    lims[i,0,1] += .25
                lims[i,1] = self._autoscale(ax, ax.mean, ax.p_yerr, kind=kind)

        if xlim is not None:
            lims[:,0,:] = xlim
        if ylim is not None:
            lims[:,1,:] = ylim

        if kind == 'matrix':
            if len(axes) > 0:
                # one color scale for all images
                vmin = np.min(lims[:,2,0])
                vmax = np.max(lims[:,2,1])
                for ax in axes:
                    for im in ax.images:
                        im.set_norm(mpl.colors.Normalize(vmin=vmin, vmax=vmax))
        else:
            xlim = [np.min(lims[:,0,0]), np.max(lims[:,0,1])]
            ylim = [np.min(lims[:,1,0]), np.max(lims[:,1,1])]
            if kind in ['line', 'scatter', 'mds']:
                if self.sharex:  # set x-axis limits globally
                    axes[0].set_xlim(xlim)
                    majorLocator = self._space_ticks(axes[0].get_xticks(), kind)
                    if majorLocator is not None:
                        axes[0].xaxis.set_major_locator(majorLocator)
                else:
                    for i, ax in enumerate(axes):
                        ax.set_xlim(lims[i,0])
                        majorLocator = self._space_ticks(ax.get_xticks(), kind)
                        if majorLocator is not None:
                            ax.xaxis.set_major_locator(majorLocator)
            if self.sharey:  # set y-axis limits globally
                axes[0].set_ylim(ylim)
                majorLocator = self._space_ticks(axes[0].get_yticks(), kind)
                if majorLocator is not None:
                    axes[0].yaxis.set_major_locator(majorLocator)
            else:
                for i, ax in enumerate(axes):
                    ax.set_ylim(lims[i,1])
                    majorLocator = self._space_ticks(ax.get_yticks(), kind)
                    if majorLocator is not None:
                        ax.yaxis.set_major_locator(majorLocator)

        for axno, ax in enumerate(axes):
            # put x labels only at the bottom of the subplots figure
            # (`//`: the row number must be an integer whatever the
            # division semantics in force)
            if axno // self.ncols == self.nrows-1 or not self.sharex:
                self._label_x(ax.mean, ax.p_yerr, ax, kind=kind, xlabel=xlabel)
            # and y labels only in the first column
            if axno % self.ncols == 0 or not self.sharey:
                self._label_y(ax.mean, ax.p_yerr, ax, kind=kind, ylabel=ylabel)
            if kind == 'bar':
                try:  # not always possible to compute significance
                    self.draw_sig(ax.agg, ax, popmean=popmean,
                                  within=within, errkind=errkind)
                except Exception:
                    pass
            # with several subplots only the first one gets a legend
            if len(axes) > 1 and axno > 0:
                legend = False
            else:
                legend = None
            if kind not in ['matrix', 'scatter', 'mds']:
                self._draw_legend(ax, visible=legend, data=ax.mean, kind=kind)
            if numb:
                # NOTE(review): `self.subplotno` is the same for every axis
                # in this loop -- confirm that `axno` was not intended
                self.add_inner_title(ax, title='%s' % self.subplotno, loc=2)
            if kind in ['scatter', 'mds']:
                ax.set_aspect('equal')

        if self.sharex or len(axes) == 1:
            try:  # axes only have these tables if draw_sig provided them
                self.sig_t = pandas.concat([ax.sig_t for ax in axes], axis=1)
                self.sig_p = pandas.concat([ax.sig_p for ax in axes], axis=1)
            except Exception:
                pass
        return axes
    def _space_ticks(self, ticks, kind=None):
        """
        Build a tick locator that keeps only a few of the given ticks.

        :Args:
            ticks (sequence)
                Current tick locations.
        :Kwargs:
            kind (str, default: None)
                Not used.
        :Returns:
            A `matplotlib.ticker.FixedLocator` over `ticks`, or None if it
            could not be created.
        """
        try:
            from math import gcd  # newer Pythons dropped fractions.gcd
        except ImportError:
            gcd = fractions.gcd
        nticks = len(ticks)
        if nticks <= 5:
            nbins = nticks
        else:
            # largest of the divisors 1..5 shared with the tick count...
            largest = max(gcd(nticks + 1, i + 1) for i in range(5))
            if largest == 1:  # ...retry without the +1 if nothing fits
                largest = max(gcd(nticks, i + 1) for i in range(5))
            nbins = largest + 1
        try:
            majorLocator = mpl.ticker.FixedLocator(ticks, nbins=nbins)
        except Exception:
            majorLocator = None
        return majorLocator
[docs]    def printfig(self):
        try:
            #import pdb; pdb.set_trace()
            for ax in self.axes:
                print ax.get_title()
                print ax.mean
        except:
            pass
 
    def _plot_ax(self, agg, ax=None, kind='bar', order=None, errkind='sem',
                 within=None, **kwargs):
        """
        Draw a single plot on a single axis.

        :Args:
            agg
                Data to plot; must carry a `_splits` mapping.
        :Kwargs:
            - ax (default: None)
                Axis to draw on; by default the next axis of the figure.
            - kind ({'bar', 'line', 'bean', 'histogram', 'matrix',
              'scatter', 'mds'}, default: 'bar')
                Kind of plot.
            - order (default: None)
                Passed on to `bean_plot` only.
            - errkind, within
                Passed on to `stats.confidence`.
            - kwargs
                Passed on to the plotting function (not for `matrix` and
                `scatter`).
        :Returns:
            The axis, with the plotted means and error values attached as
            `ax.mean` and `ax.p_yerr`.
        :Raises:
            ValueError if `kind` is not recognized.
        """
        if ax is None:
            ax = self.next()
        # compute means -- so that's a Series
        mean, p_yerr = stats.confidence(agg, kind=errkind, within=within)

        # unstack data into rows and cols, if possible
        if 'rows' in agg._splits and 'cols' in agg._splits:
            mean = stats.unstack(mean, level=agg._splits['cols'])
            p_yerr = stats.unstack(p_yerr, level=agg._splits['cols'])
        if isinstance(mean, pandas.Series):
            if 'rows' in agg._splits:
                mean = pandas.DataFrame(mean).T
                p_yerr = pandas.DataFrame(p_yerr).T
            else:
                mean = pandas.DataFrame(mean)
                p_yerr = pandas.DataFrame(p_yerr)

        if isinstance(agg, pandas.Series) and kind=='bean':
            kind = 'bar'
            print('WARNING: Beanplot not available for a single measurement')

        if kind == 'bar':
            self.bar_plot(mean, yerr=p_yerr, ax=ax, **kwargs)
        elif kind == 'line':
            self.line_plot(mean, yerr=p_yerr, ax=ax, **kwargs)
        elif kind == 'bean':
            ax = self.bean_plot(agg, ax=ax, order=order, **kwargs)
        elif kind == 'histogram':
            ax = self.histogram(agg, ax=ax, **kwargs)
        elif kind == 'matrix':
            ax = self.matrix_plot(mean, ax=ax)
        elif kind == 'scatter':
            ax = self.scatter_plot(mean, ax=ax)
        elif kind == 'mds':
            ax, mean = self.mds_plot(mean, ax=ax, **kwargs)
        else:
            # ValueError is still caught by `except Exception`
            raise ValueError('%s plot not recognized. Choose from '
                             '{bar, line, bean, histogram, matrix, scatter, '
                             'mds}.' % kind)
        ax.mean = mean
        ax.p_yerr = p_yerr
        return ax
    def _label_ax_old(self, agg, mean, p_yerr, ax, title='', kind='bar', legend=None,
                   autoscale=True, **kwargs):
        if kind not in ['scatter', 'mds']:
            if self.subplotno / self.ncols == self.ncols or not self.sharex:
                self._label_x(mean, p_yerr, ax, kind=kind, **kwargs)
            if self.subplotno % self.ncols == 0 or not self.sharey:
                self._label_y(mean, p_yerr, ax, kind=kind, **kwargs)
        if kind == 'bar':
            self.draw_sig(agg, ax)
        #if not self.sharex and kind in ['scatter', 'mds']:
            ## set x-axis limits
            #if 'xlim' in kwargs:
                #ax.set_xlim(kwargs['xlim'])
            #else:
                #ax.set_xlim([xmin, xmax])
        #if not self.sharey:
            ## set y-axis limits
            #if 'ylim' in kwargs:
                #ax.set_ylim(kwargs['ylim'])
            #elif autoscale and kind in ['line', 'bar']:
                #ax.set_ylim([ymin, ymax])
        #if title is not None: ax.set_title(title)
        if kind not in ['matrix', 'scatter', 'mds']:
            self._draw_legend(ax, visible=legend, data=mean, **kwargs)
        if 'numb' in kwargs:
            if kwargs['numb'] == True:
                self.add_inner_title(ax, title='%s' % self.subplotno, loc=2)
        return ax
    def _label_x(self, mean, p_yerr, ax, kind='bar', xtickson=True,
                   rotate=True, xlabel=None):
        if kind not in ['histogram', 'scatter', 'mds']:
            labels = ax.get_xticklabels()
            new_labels = self._format_labels(labels=mean.index)
            if len(labels) > len(mean):
                new_labels = [''] + new_labels
            if kind == 'line':
                try:  # don't set labels for number
                    mean.index[0] + 1
                except:
                    ax.set_xticklabels(new_labels)
                else:
                    loc = map(int, ax.xaxis.get_majorticklocs())
                    try:
                        new_labels = [loc[0]] + mean.index[np.array(loc[1:])].tolist()
                    except:
                        pass
                    else:
                        ax.set_xticklabels(new_labels)
            else:
                ax.set_xticklabels(new_labels)
        labels = ax.get_xticklabels()
        if len(labels) > 0:
            max_len = max([len(label.get_text()) for label in labels])
            if max_len > 10 or (kind == 'matrix' and max_len > 2):
                if rotate :  #FIX to this: http://stackoverflow.com/q/5320205
                    for label in labels:
                        label.set_ha('right')
                        label.set_rotation(30)
            else:
                for label in labels:
                    label.set_rotation(0)
        #if 'xlabel' in kwargs:
            #xlabel = kwargs['xlabel']
        #else:
            #xlabel = None
        if xlabel is None:
            if kind in ['scatter', 'mds']:
                xlabel = mean.columns[0] #'.'.join(mean.columns.names[0].split('.')[1:])
            else:
                xlabel = self._get_title(mean, 'rows')
        ax.set_xlabel(xlabel)
        return ax
    def _label_y(self, mean, p_yerr, ax, kind='bar', ytickson=True, ylabel=None):
        if kind == 'matrix':
            ax.set_yticklabels(self._format_labels(labels=mean.columns))
        if not ytickson:
            ax.set_yticklabels(['']*len(ax.get_yticklabels()))
        #if 'ylabel' in kwargs:
            #ylabel = kwargs['ylabel']
        #else:
            #ylabel = None
        if ylabel is None:
            if kind == 'matrix':
                ylabel = self._get_title(mean, 'cols')
            elif kind in ['scatter', 'mds']:
                ylabel = mean.columns[1] #'.'.join(mean.columns.names[0].split('.')[1:])
            else:
                try:
                    ylabel = ax.agg.names
                except:
                    ylabel = ''
        ax.set_ylabel(ylabel)
        return ax
    def _autoscale(self, ax, mean, p_yerr, kind='bar'):
        #mean_array = np.asarray(mean)
        r = mean.max().max() - mean.min().min()
        ebars = np.where(np.isnan(p_yerr), 0, p_yerr)
        if np.all(ebars == 0):  # basically no error bars
            ymin = mean.min().min()
            ymax = (mean + r/3.).max().max()  # give some space above the bars
        else:
            ymin = (mean - ebars).min().min()
            ymax = (mean + ebars).max().max()
        if kind == 'bar':  # for barplots, 0 must be included
            if ymin > 0:
                ymin = 0
            if ymax < 0:
                ymax = 0
        xyrange = ymax - ymin
        if ymin != 0:
            ymin -= xyrange / 3.
        if ymax != 0:
            ymax += xyrange / 3.
        yticks = ax.get_yticks()
        min_space = abs(yticks[1] - yticks[0])
        ymin =  np.round(ymin/min_space) * min_space
        ymax =  np.round(ymax/min_space) * min_space
        return ymin, ymax
    def _get_title(self, data, pref):
        if pref == 'cols':
            dnames = data.columns.names
            try:
                dlevs = data.columns.levels
            except:
                dlevs = [data.columns]
        else:
            dnames = data.index.names
            try:
                dlevs = data.index.levels
            except:
                dlevs = [data.index]
        if len(dnames) == 0 or dnames[0] == None: dnames = ['']
        title = [n.split('.',1)[1] for n in dnames if n.startswith(pref+'.')]
        levels = [l for l,n in zip(dlevs,dnames) if n.startswith(pref+'.')]
        title = [n for n,l in zip(title,levels) if len(l) > 1]
        title = ', '.join(title)
        return title
[docs]    def draw_sig(self, agg, ax, popmean=0, errkind='sem', within=None):
        # find out how many columns per group there are
        try:
            cols = [i for i,n in enumerate(agg.columns.names) if n.startswith('cols.')]
        except:  # no cols -> everything considered to be separate
            vals_len = 1
        else:
            if len(cols) == 0:
                vals_len = 1
            else:
                try:
                    vals_len = np.max([len(agg.columns.levels[col]) for col in cols])
                except:
                    vals_len = len(agg.columns)
        if vals_len <= 2:  # if >2, cannot compute stats
            if isinstance(agg, pandas.DataFrame):
                mean, p_yerr = stats.confidence(agg, kind=errkind, within=within)
            else:
                mean = agg
            r = mean.max().max() - mean.min().min()
            ebars = np.where(np.isnan(p_yerr), r/3., p_yerr)
            if isinstance(p_yerr, pandas.Series):
                ebars = pandas.Series(ebars, index=p_yerr.index)
            else:
                ebars = pandas.DataFrame(ebars, columns=p_yerr.columns, index=p_yerr.index)
            #eb = np.max([r/6., 1.5*np.max(ebars)/2])
            ylim = ax.get_ylim()
            eb_gap = abs(ylim[0] - ylim[1]) / 8.
            try:  # not guaranteed that columns have names
                inds = [i for i,n in enumerate(agg.columns.names) if n.startswith('rows.')]
            except:
                inds = []
                rlabels = agg.columns
            else:
                #rlevel = inds[-1] + 1
                #try:
                    #rlabels = agg.columns.levels[inds[-1]]
                #except:
                    #if len(inds) == 1:
                        #rlabels = agg.columns
                    #else:
                        #rlabels = 1
                # all columns names start with 'row.'
                if len(inds) == len(agg.columns.names):
                    rlabels = agg.columns
                elif len(inds) == 0:  # no rows at all
                    rlabels = [None]
                else:
                    rlabels = stats.unstack(agg.mean(), level=inds).columns
            ticks = ax.get_xticks()
            sig_t = []
            sig_p = []
            for rno, rlab in enumerate(rlabels):
                if len(inds) == 0:  # no rows at all
                    d = agg
                    m = mean
                    e = ebars
                else:
                    d = _get_multi(agg, rlab, dim='columns')
                    m = _get_multi(mean, rlab, dim='rows')
                    e = _get_multi(ebars, rlab, dim='rows')
                    #d = agg.copy()
                    #m = mean.copy()
                    #e = ebars.copy()
                    #for r in rlab:
                        #d = d[r]
                        #m = m[r]
                        #e = e[r]
                if d.ndim == 1:
                    d = d.dropna() #d[pandas.notnull(d)]
                    t, p = scipy.stats.ttest_1samp(d, popmean=popmean)
                    ypos = m + np.sign(m) * (e + eb_gap)
                    ax.text(ticks[rno], ypos,
                            stats.get_star(p), ha='center')
                elif d.ndim == 2 and d.shape[1] == 2:
                    d1 = d.iloc[:,0].dropna()
                    d2 = d.iloc[:,1].dropna()
                    # two-tailed paired-samples t-test
                    t, p = scipy.stats.ttest_rel(d1, d2)
                    mn = m + np.sign(m) * (e + eb_gap)
                    #import pdb; pdb.set_trace()
                    ax.text(ticks[rno], mn.max(), stats.get_star(p), ha='center')
                if rlab is None:
                    rlab = ''
                try:
                    sig_t.append((rlab, t))
                    sig_p.append((rlab, p))
                except:
                    pass
            if len(sig_t) > 0:
                ax.sig_t = pandas.DataFrame(sig_t)
                ax.sig_p = pandas.DataFrame(sig_p)
 
    def _draw_legend(self, ax, visible=None, data=None, kind=None):
        """
        Creates (if needed) and formats the legend of an axis.

        :Args:
            - ax
                Axis whose legend is formatted.

        :Kwargs:
            - visible (bool or None, default: None)
                Forces the legend to be shown or hidden. If None, the
                legend is hidden when it has exactly one entry and shown
                otherwise.
            - data (`pandas.DataFrame`, default: None)
                Plotted data. Its columns provide the legend title and
                entry labels; if None, title and labels are left as they
                are.
            - kind (str, default: None)
                For `'line'`, the legend is rebuilt from the first element
                of each handle so that error bars are left out.
        """
        leg = ax.get_legend()  # get an existing legend
        if leg is None:  # create a new legend
            leg = ax.legend()
        if leg is None:
            return

        if kind == 'line':
            handles, labels = ax.get_legend_handles_labels()
            # remove the errorbars
            handles = [h[0] for h in handles]
            leg = ax.legend(handles, labels)
        leg.get_frame().set_alpha(0.5)

        if data is not None:
            try:  # may or may not have any columns
                leg.set_title(self._get_title(data, 'cols'))
            except Exception:
                pass
            new_texts = self._format_labels(data.columns)
            for text, new_text in zip(leg.get_texts(), new_texts):
                text.set_text(new_text)

        if visible is not None:
            leg.set_visible(visible)
        else:
            # showing a single legend entry is useless
            leg.set_visible(len(leg.get_texts()) != 1)
[docs]    def set_legend_pos(self, subplot=1, loc=6,#'center left',
                        bbox_to_anchor=(1.1, 0.5)):
        #for ax in self.axes:
            ##import pdb; pdb.set_trace()
            #leg = ax.get_legend()
            #if leg is not None: break
        leg = self.axes[subplot-1].get_legend()
        if leg is not None:
            #leg.set_axes(self.axes[subplot-1])
            leg._set_loc(loc)
            leg.set_bbox_to_anchor(bbox_to_anchor)
            leg.set_visible(True)
        # frameon=False
 
    def _format_labels(self, labels='', names=''):
        """Formats labels to avoid uninformative (singular) entries
        """
        if len(labels) > 1:
            try:
                labels.levels
            except:
                new_labs = [str(l) for l in labels]
            else:
                sel = [i for i,l in enumerate(labels.levels) if len(l) > 1]
                new_labs = []
                for r in labels:
                    label = [l for i,l in enumerate(r) if i in sel]
                    if len(label) == 1:
                        label = label[0]
                    else:
                        label = ', '.join([str(lab) for lab in label])
                    new_labs.append(label)
        else:
            new_labs = ['']
        return new_labs
[docs]    def hide_plots(self, nums):
        """
        Hides an axis.
        :Args:
            nums (int, tuple or list of ints)
                Which axes to hide.
        """
        if isinstance(nums, int) or isinstance(nums, tuple):
            nums = [nums]
        for num in nums:
            ax = self.get_ax(num)
            ax.axis('off')
 
[docs]    def display_legend(self, nums, show=False):
        """
        Shows or hides (default) legends on given axes.
        :Args:
            nums (int, tuple or list of ints)
                Axes numbers that need their legend hidden.
        :Kwargs:
            show (bool, default: False)
                Whether legends should be shown or hidden
        """
        if isinstance(nums, int) or isinstance(nums, tuple):
            nums = [nums]
        for num in nums:
            ax = self.get_ax(num)
            leg = ax.get_legend()
            leg.set_visible(show)
 
[docs]    def bar_plot(self, data, yerr=None, ax=None, **kwargs):
        """
        Plots a bar plot.
        :Args:
            data (`pandas.DataFrame` or any other array accepted by it)
                A data frame where rows go to the x-axis and columns go to the
                legend.
        """
        data = pandas.DataFrame(data)
        if yerr is None:
            yerr = np.empty(data.shape)
            yerr = yerr.reshape(data.shape)  # force this shape
            yerr = np.nan
        if ax is None:
            self.subplotno += 1
            ax = self.get_ax()
        colors = self.get_colors(len(data.columns))
        if not 'ecolor' in kwargs:
            kwargs['ecolor'] = 'black'
        n = len(data.columns)
        idx = np.arange(len(data))
        width = .75 / n
        rects = []
        for i, (label, column) in enumerate(data.iteritems()):
            rect = ax.bar(idx + i*width - .75/2, column, width,
                label=str(label), yerr=yerr[label].tolist(),
                color = colors[i], **kwargs)
            # TODO: yerr indexing might need fixing
            rects.append(rect)
        #gap = .25
        #xlim = ax.get_xlim()
        #if xlim[0] != np.min(idx) - gap:  # if sharex, this might have been set
            #import pdb; pdb.set_trace()
            #ax.set_xlim((xlim[0] - gap, xlim[1] + gap))
            #ax.set_xticks(idx)
        ax.set_xticks(idx)# + width*n/2 + width/2)
        ax.legend(rects, data.columns.tolist())
        ax.axhline(color=plt.rcParams['axes.edgecolor'])
        return ax
 
[docs]    def line_plot(self, data, yerr=None, ax=None, **kwargs):
        """
        Plots a bar plot.
        :Args:
            data (`pandas.DataFrame` or any other array accepted by it)
                A data frame where rows go to the x-axis and columns go to the
                legend.
        """
        data = pandas.DataFrame(data)
        if yerr is None:
            yerr = np.empty(data.shape)
            yerr = yerr.reshape(data.shape)  # force this shape
            yerr = np.nan
        if ax is None:
            self.subplotno += 1
            ax = self.get_ax()
        #if not 'fmt' in kwargs:
            #kwargs['fmt'] = None
        #if not 'ecolor' in kwargs:
            #kwargs['ecolor'] = 'black'
        #colors = self.get_colors(len(data.columns))
        try:
            data.index[0] + 1
        except:
            x = range(len(data))
        else:
            x = data.index.tolist()
        #import pdb; pdb.set_trace()
        for i, (label, column) in enumerate(data.iteritems()):
            ax.errorbar(x, column, yerr=yerr[label].tolist(),
                        label=str(label), fmt='-', ecolor='black', **kwargs)
        step = np.ptp(x) / (len(x) - 1.)
        xlim = ax.get_xlim()
        if xlim[0] != np.min(x) - step/2:  # if sharex, this might have been set
            ax.set_xlim((np.min(x) - step/2, np.max(x) + step/2))
            ax.set_xticks(x)
        #try:
            #data.index[0] + 1
        #except:
            #pass
        #else:
            #ax.set_xticklabels(data.index)
        #xticklabels = self._format_labels(labels=data.index)
        #old_labels = ax.get_xticklabels()
        #if len(xticklabels) < old_labels:
            #xticklabels = [''] + xticklabels
        #ax.set_xticklabels(xticklabels)
        return ax
 
[docs]    def scatter_plot(self, data, ax=None, labels=None, **kwargs):
        return self._scatter(data.iloc[:,0], data.iloc[:,1], labels=data.index,
                             ax=ax, **kwargs)
 
    def _scatter(self, x, y, labels=None, ax=None, **kwargs):
        """
        Draws a scatter plot.

        This is very similar to `matplotlib.pyplot.scatter` but additionally
        accepts labels (for labeling points on the plot) and an axis where
        the plot should be drawn.

        :Args:
            - x (an iterable object)
                An x-coordinate of data
            - y (an iterable object)
                A y-coordinate of data

        :Kwargs:
            - labels (iterable, default: None)
                Labels written next to the points, paired with the points
                in order. Pairing stops at the shortest of `labels`, `x`
                and `y`.
            - ax (default: None)
                An axis to plot in. If None, the next subplot is used.
            - kwargs
                Additional keyword arguments for the axis' `scatter` method

        :Return:
            Current axis for further manipulation.
        """
        if ax is None:
            self.subplotno += 1
            ax = self.get_ax()
        if labels is not None:
            for label, pointx, pointy in zip(labels, x, y):
                ax.text(pointx, pointy, label, backgroundcolor=(1,1,1,.5))
        ax.scatter(x, y, marker='o', color=self.get_colors()[0], **kwargs)
        return ax
[docs]    def histogram(self, data, ax=None, bins=100, **kwargs):
        data = pandas.DataFrame(data)
        if ax is None:
            self.subplotno += 1
            ax = self.get_ax()
        #data.ndim == 1
        ax.hist(np.array(data), bins=bins, **kwargs)
        return ax
 
[docs]    def matrix_plot(self, data, ax=None, normalize='auto', **kwargs):
        """
        Plots a matrix.
        .. warning:: Not tested yet
        :Args:
            matrix
        :Kwargs:
            - ax (default: None)
                An axis to plot on.
            - title (str, default: '')
                Plot title
            - kwargs
                Keyword arguments to pass to :func:`matplotlib.pyplot.imshow`
        """
        #if ax is None: ax = self.next()
        #from mpl_toolkits.axes_grid1 import make_axes_locatable
        import matplotlib.colors
        if normalize != 'auto':
            norm = matplotlib.colors.normalize(vmin=data.min().min(),
                                               vmax=data.max().max())
        else:
            norm = None
        data = _unstack_levels(data, 'cols')
        im = ax.imshow(data, norm=norm, interpolation='none',
                       cmap='coolwarm', **kwargs)
        #minor_ticks = np.linspace(-.5, nvars - 1.5, nvars)
        #ax.set_xticks(minor_ticks, True)
        #ax.set_yticks(minor_ticks, True)
        ax.set_xticks(np.arange(data.shape[1])-.5, True)
        ax.set_yticks(np.arange(data.shape[0])+.5, True)
        ax.set_xticks(np.arange(data.shape[1]))
        ax.set_yticks(np.arange(data.shape[0]))
        ax.grid(False, which="major")
        ax.grid(True, which="minor", linestyle="-")
        self.axes.cbar_axes[self.subplotno].colorbar(im)
        return ax
 
[docs]    def add_inner_title(self, ax, title, loc=2, size=None, **kwargs):
        from matplotlib.offsetbox import AnchoredText
        from matplotlib.patheffects import withStroke
        if size is None:
            size = dict(size=plt.rcParams['legend.fontsize'])
        at = AnchoredText(title, loc=loc, prop=size,
                          pad=0., borderpad=0.5,
                          frameon=False, **kwargs)
        ax.add_artist(at)
        at.txt._text.set_path_effects([withStroke(foreground="w", linewidth=3)])
        return at
 
[docs]    def mds_plot(self, mean, fonts='freesansbold.ttf', ax=None, ndim=2, **kwargs):
        """Plots Multidimensional scaling results"""
        # plot each point with a name
        #dims = results.ndim
        #try:
            #if results.shape[1] == 1:
                #dims = 1
        #except:
            #pass
        #import pdb; pdb.set_trace()
        res = stats.mds(mean, ndim=ndim)
        if ndim == 1:
            ax = self.bar_plot(res, ax=ax)
            #df = pandas.DataFrame(results, index=labels, columns=['data'])
            #df = df.sort(columns='data')
            #self._plot(df)
        elif ndim == 2:
            #x, y = results.T
            self.scatter_plot(res, ax=ax)
            ##for c, coord in enumerate(results):
                ##ax.plot(coord[0], coord[1], 'o', color=mpl.cm.Paired(.5))
                ##ax.text(coord[0], coord[1], labels[c], fontproperties=fonts[c])
        else:
            print 'Cannot plot more than 2 dims'
        return ax, res
 
    def _violin_plot(self, data, pos, rlabels, ax=None, bp=False, cut=None, **kwargs):
        """
        Make a violin plot of each dataset in the `data` sequence.

        Based on `code by Teemu Ikonen
        <http://matplotlib.1069221.n5.nabble.com/Violin-and-bean-plots-tt27791.html>`_
        which was based on `code by Flavio Codeco Coelho
        <http://pyinsci.blogspot.com/2009/09/violin-plot-with-matplotlib.html>`)

        :Args:
            - data
                Data whose columns are selected with each entry of
                `rlabels`; a selection must be one- or two-dimensional.
            - pos (list or None)
                x position of each violin. If None, `[0, 1]` is used.
            - rlabels
                Column labels, one per violin. A `None` entry selects the
                whole of `data`.

        :Kwargs:
            - ax
                Axis to draw on.
            - bp
                Not used.
            - cut (tuple or None, default: None)
                `(low, high)` limits of the violins. If None, the limits
                are extended beyond the data until each density tail holds
                no more than 0.1% of the mass.
            - kwargs
                Not used.

        :Returns:
            The axis that was drawn on.

        :Raises:
            Exception if a selection is neither one- nor two-dimensional.
        """
        def draw_density(p, low, high, k1, k2, ncols=2):
            # draws the left (k1) and right (k2) density at x position `p`
            x = np.linspace(low, high, 100) # support for violin
            if k1 is not None:
                v1 = k1.evaluate(x) # violin profile (density curve)
                v1 = w*v1/v1.max() # scaling the violin to the available space
            if k2 is not None:
                v2 = k2.evaluate(x) # violin profile (density curve)
                v2 = w*v2/v2.max() # scaling the violin to the available space
            if ncols == 2:
                if k1 is not None:
                    ax.fill_betweenx(x, -v1 + p, p, facecolor='black', edgecolor='black')
                if k2 is not None:
                    ax.fill_betweenx(x, p, p + v2, facecolor='grey', edgecolor='gray')
            else:
                ax.fill_betweenx(x, -v1 + p, p + v2, facecolor='black',
                                     edgecolor='black')

        if pos is None:
            pos = [0,1]
        w = .75/4  # maximal half-width of a violin

        for rno, rlabel in enumerate(rlabels):
            p = pos[rno]
            if rlabel is None:
                d = data
            else:
                d = _get_multi(data, rlabel, dim='columns')

            # a density is only estimated from at least two observations
            if d.ndim == 1:
                d1 = d
                d1 = d1[pandas.notnull(d1)]
                d2 = d1
                if len(d1) > 1:
                    k1 = scipy.stats.gaussian_kde(d1)  # calculates kernel density
                else:
                    k1 = None
                k2 = k1
            elif d.ndim == 2:
                d1 = d.iloc[:,0]
                d1 = d1[pandas.notnull(d1)]
                if len(d1) > 1:
                    k1 = scipy.stats.gaussian_kde(d1)  # calculates kernel density
                else:
                    k1 = None
                d2 = d.iloc[:,1]
                d2 = d2[pandas.notnull(d2)]
                if len(d2) > 1:
                    k2 = scipy.stats.gaussian_kde(d2)  # calculates kernel density
                else:
                    k2 = None
            else:
                raise Exception('beanplots are only available for one or two '
                    'columns, but we detected %d columns' % d.ndim)

            if k1 is not None and k2 is not None:
                cutoff = .001
                if cut is None:
                    high = max(d1.max(),d2.max())
                    low = min(d1.min(),d2.min())
                    # float division: with integer data a floored step of 0
                    # would keep the searches below from ever finishing
                    stepsize = (high - low) / 100.
                    area_low1 = 1  # max cdf value
                    area_low2 = 1  # max cdf value
                    while area_low1 > cutoff or area_low2 > cutoff:
                        area_low1 = k1.integrate_box_1d(-np.inf, low)
                        area_low2 = k2.integrate_box_1d(-np.inf, low)
                        low -= stepsize
                    area_high1 = 1  # max cdf value
                    area_high2 = 1  # max cdf value
                    while area_high1 > cutoff or area_high2 > cutoff:
                        area_high1 = k1.integrate_box_1d(high, np.inf)
                        area_high2 = k2.integrate_box_1d(high, np.inf)
                        high += stepsize
                else:
                    low, high = cut
                draw_density(p, low, high, k1, k2, ncols=d.ndim)

        # a work-around for generating a legend for the PolyCollection
        # from http://matplotlib.org/users/legend_guide.html#using-proxy-artist
        left = Rectangle((0, 0), 1, 1, fc="black", ec='black')
        right = Rectangle((0, 0), 1, 1, fc="gray", ec='gray')
        if d.ndim == 1:
            ax.legend((left,), [''])
        else:
            ax.legend((left, right), d.columns.tolist())
        return ax
    def _stripchart(self, data, pos, rlabels, ax=None,
        mean=False, median=False, width=None, discrete=True, bins=30):
        """Plot samples given in `data` as horizontal lines.

        :Args:
            - data
                Data to draw. All its columns are histogrammed over one
                common range to find the largest bin count.
            - pos
                x position for each entry of `rlabels`; also used for the
                x ticks and x limits.
            - rlabels
                Labels selecting columns of `data` (via `_get_multi`), one
                per x position. A `None` entry selects the whole of `data`.
                A selection must be one- or two-dimensional.

        :Kwargs:
            - ax
                Axis to draw on.
            - mean
                If True, the mean of each dataset is drawn as a thicker
                black line. It is also drawn whenever a dataset has fewer
                than two values.
            - median
                If True, the median of each dataset with more than one
                value is marked with a white cross.
            - width
                Horizontal width of a single dataset plot (default: .75/4).
            - discrete
                If True, lines are drawn at the lower bin edges and their
                length is scaled by the bin counts; otherwise a line of
                fixed length is drawn at every data value.
            - bins
                Passed to `numpy.histogram`.

        :Returns:
            The axis that was drawn on.

        :Raises:
            Exception if a selection is neither one- nor two-dimensional.
        """
        def draw_lines(p, d, maxcount, hist, bin_edges, sides=None):
            # Draws one dataset at x position `p`; `sides` holds the left and
            # right extent multipliers ([-1,1] both sides, [-1,0] left half,
            # [0,1] right half).
            d = d[pandas.notnull(d)]
            if discrete:
                bin_edges = bin_edges[:-1]  # upper edges not needed
                # half-widths are proportional to the bin counts
                hw = hist * w / (2.*maxcount)
            else:
                bin_edges = d
                hw = w / 2.
            if mean or len(d) < 2:  # draws a longer black line
                ax.hlines(np.mean(d), sides[0]*2*w + p, sides[1]*2*w + p,
                    lw=2, color='black')
            if len(d) > 1:
                ax.hlines(bin_edges, sides[0]*hw + p, sides[1]*hw + p, color='white')
            if median and len(d) > 1:  # marks the median with a white cross
                ax.plot(p, np.median(d), 'x', color='white', mew=2)
        if width is not None:
            w = width
        else:
            w = .75/4
        # one common range so that the bins of all columns are comparable
        rng = (data.min().min(), data.max().max())
        # bin counts of every column, pooled to find the overall maximum
        hists = []
        for dno, d in data.iteritems():
            d = d[pandas.notnull(d)]
            hist, bin_ed = np.histogram(d, bins=bins, range=rng)
            hists.extend(hist.tolist())
        maxcount = np.max(hists)
        for rno, rlab in enumerate(rlabels):
            if rlab is None:
                d = data
            else:
                d = _get_multi(data, rlab, dim='columns')
            # awful repetition of hist
            # until I figure out something better
            if d.ndim == 1:
                # a single dataset is drawn on both sides of its position
                d = d[pandas.notnull(d)]
                hist, bin_edges = np.histogram(d, bins=bins, range=rng)
                draw_lines(pos[rno], d, maxcount, hist, bin_edges, sides=[-1,1])
            elif d.ndim == 2:
                # first column goes to the left half, second to the right half
                d1 = d[pandas.notnull(d.iloc[:,0])].iloc[:,0]
                hist, bin_edges = np.histogram(d1, bins=bins, range=rng)
                draw_lines(pos[rno], d1, maxcount, hist, bin_edges, sides=[-1,0])
                d2 = d[pandas.notnull(d.iloc[:,1])].iloc[:,1]
                hist, bin_edges = np.histogram(d2, bins=bins, range=rng)
                draw_lines(pos[rno], d2, maxcount, hist, bin_edges, sides=[ 0,1])
            else:
                raise Exception('beanplots are only available for one or two '
                    'columns, but we detected %d columns' % d.ndim)
        ax.set_xlim(min(pos)-3*w, max(pos)+3*w)
        ax.set_xticks(pos)
        return ax
[docs]    def bean_plot(self, data, ax=None, pos=None, mean=True, median=True, cut=None,
        order=None, discrete=True, **kwargs):
        """Make a bean plot of each dataset in the `data` sequence.
        Reference: `<http://www.jstatsoft.org/v28/c01/paper>`_
        """
        #data_tr, pos, rlabels, sel = self._beanlike_setup(data, ax, order)
        #data_mean = self._stack_levels(data_tr, 'cols')
        #data_mean = self._unstack_levels(data_mean, 'yerr')
        #data_mean = data_mean.mean(1)
        try:  # not guaranteed that columns have names and levels
            len(data.columns.names)
        except:
            rlabels = data.columns
        else:
            rowdata = _stack_levels(data, 'cols')
            if rowdata.shape[1] > 1:
                inds = [i for i,n in enumerate(rowdata.columns.names) if n.startswith('rows.')]
            else:
                inds = []
            if len(inds) == 0:
                rlabels = [None]  # no rows
            else:#elif len(inds) >= 1:
                #import pdb; pdb.set_trace()
                labs = [rowdata.columns.get_level_values(i) for i in inds]
                rlabels = list(zip(*labs))
        pos = range(len(rlabels))
        #import pdb; pdb.set_trace()
        dist = np.max(pos) - np.min(pos)
        #w = min(0.15*max(dist,1.0),0.5) * .5
        w = .75/4 #dist * .75/4
        #import pdb; pdb.set_trace()
        ax = self._stripchart(data, pos, rlabels, ax=ax, mean=mean, median=median,
            width=w, discrete=discrete)
        ax = self._violin_plot(data, pos, rlabels, ax=ax, bp=False, cut=cut)
        #ax = self._stripchart(data_tr, pos, rlabels, sel, ax=ax, mean=mean, median=median,
            #width=0.8*w, discrete=discrete)
        #ax = self._violinplot(data_tr, pos, rlabels, sel, ax=ax, bp=False, cut=cut)
        return ax
 
def _unstack_levels(data, pref):
    try:
        levels = [n for n in data.index.names if n.startswith(pref+'.')]
    except:
        unstacked = data
    else:
        if len(levels) == 0:
            unstacked = pandas.DataFrame(data)
        else:
            try:
                clevs = data.columns.names + levels
            except:
                clevs = levels
            try:
                rlevs = [n for n in data.index.names if n not in levels]
            except:
                rlevs = None #['']#levels
            unstacked = stats.unstack(data,level=levels[0])
            if len(levels) > 1:
                for lev in levels[1:]:
                    unstacked = stats.unstack(unstacked, level=lev)
            if isinstance(unstacked, pandas.Series):
                unstacked = pandas.DataFrame(unstacked).T
            #unused = [n for n in data.index.names if not n.startswith(pref+'.')]
            for lev in clevs:
                try:
                    order = data.columns.get_level_values(lev).unique()
                except:
                    pass
                else:
                    unstacked = stats.reorder(unstacked, order=order, level=lev, dim='columns')
            for lev in rlevs:
                order = data.index.get_level_values(lev).unique()
                unstacked = stats.reorder(unstacked, order=order, level=lev, dim='index')
            for lev in levels:
                order = data.index.get_level_values(lev).unique()
                unstacked = stats.reorder(unstacked, order=order, level=lev, dim='columns')
            try:
                unstacked.columns.names = clevs
            except:
                import pdb; pdb.set_trace()
            if rlevs is not None and len(rlevs) > 0:
                try:
                    unstacked.index.names = rlevs
                except:
                    import pdb; pdb.set_trace()
    return unstacked
def _stack_levels(data, pref):
    try:
        levels = [n for n in data.columns.names if n.startswith(pref+'.')]
    except:
        stacked = data
    else:
        if len(levels) == 0:
            stacked = pandas.DataFrame(data)
        else:
            try:
                clevs = [n for n in data.columns.names if n not in levels]
            except:
                clevs = None #levels
            try:
                rlevs = data.index.names + levels #[n for n in data.index.names if n not in levels]
            except:
                rlevs = levels
            #if len(levels) == 1:
                #stacked = data.stack(levels)
            #else:
            stacked = stats.stack(data, level=levels)
            stacked = pandas.DataFrame(stacked)
            #unused = [n for n in data.index.names if not n.startswith(pref+'.')]
            if clevs is not None:
                for lev in clevs:
                    order = data.columns.get_level_values(lev).unique()
                    stacked = stats.reorder(stacked, order=order, level=lev, dim='columns')
            for lev in rlevs:
                try:
                    order = data.index.get_level_values(lev).unique()
                except:
                    pass
                else:
                    stacked = stats.reorder(stacked, order=order, level=lev, dim='index')
            for lev in levels:
                order = data.columns.get_level_values(lev).unique()
                stacked = stats.reorder(stacked, order=order, level=lev, dim='index')
            if clevs is not None and len(clevs) != 0:
                stacked.columns.names = clevs
            stacked.index.names = rlevs
            #stacked = data.stack(levels)
    return stacked
def _get_multi(data, labels, dim='columns'):
    if not isinstance(labels, (tuple, list)):
        labels = [labels]
    d = data.copy()
    for label in labels:
        if dim == 'columns':
            d = d.loc[:,label]
        else:
            d = d.loc[label]
    return d
if __name__ == '__main__':
    # Demo: build a synthetic data set, aggregate it and plot it on two
    # subplots.
    n = 8  # number of subjects
    nsampl = 10  # samples per subject
    k = n * nsampl
    data = {
        'subplots': ['session1']*k*18 + ['session2']*k*18,
        'cond': [1]*k*9 + [2]*k*9 + [1]*k*9 + [2]*k*9,
        'name': (['one', 'one', 'one']*k + ['two', 'two', 'two']*k +
                ['three', 'three', 'three']*k) * 4,
        'levels': (['small']*k + ['medium']*k + ['large']*k)*12,
        'subjID': ['subj%d' % (i+1) for i in np.repeat(range(n),nsampl)] * 36,
        # an explicit list: a Python 3 `range` cannot be repeated with `*`
        'RT': list(range(k))*36,
        'accuracy': np.random.randn(36*k)
        }
    df = pandas.DataFrame(data, columns = ['subplots','cond','name','levels','subjID','RT',
        'accuracy'])
    agg = stats.aggregate(df, subplots='subplots', rows=['cond', 'name'],
        cols='levels', yerr='subjID', values='RT')
    fig = Plot(ncols=2)
    fig.plot(agg, subplots=True)
    fig.show()