from __future__ import absolute_import import warnings from .graphutil import cltext from .plotbackend import plotbackend from time import gmtime, strftime import numpy as np from scipy.integrate.quadrature import cumtrapz # @UnresolvedImport from scipy import interpolate from scipy import integrate __all__ = ['PlotData', 'AxisLabels'] def empty_copy(obj): class Empty(obj.__class__): def __init__(self): pass newcopy = Empty() newcopy.__class__ = obj.__class__ return newcopy def _set_seed(iseed): if iseed is not None: try: np.random.set_state(iseed) except: np.random.seed(iseed) def now(): ''' Return current date and time as a string ''' return strftime("%a, %d %b %Y %H:%M:%S", gmtime()) class PlotData(object): ''' Container class for data with interpolation and plotting methods Member variables ---------------- data : array_like args : vector for 1D, list of vectors for 2D, 3D, ... labels : AxisLabels children : list of PlotData objects plot_args_children : list of arguments to the children plots plot_kwds_children : dict of keyword arguments to the children plots plot_args : list of arguments to the main plot plot_kwds : dict of keyword arguments to the main plot Member methods -------------- copy : return a copy of object eval_points : interpolate data at given points and return the result plot : plot data on given axis and the object handles Example ------- >>> import numpy as np >>> x = np.arange(-2, 2, 0.2) # Plot 2 objects in one call >>> d2 = PlotData(np.sin(x), x, xlab='x', ylab='sin', title='sinus') >>> h = d2.plot() >>> h1 = d2() Plot with confidence interval >>> d3 = PlotData(np.sin(x), x) >>> d3.children = [PlotData(np.vstack([np.sin(x)*0.9, np.sin(x)*1.2]).T,x)] >>> d3.plot_args_children=[':r'] >>> h = d3.plot() ''' def __init__(self, data=None, args=None, *args2, **kwds): self.data = data self.args = args self.date = now() self.plotter = kwds.pop('plotter', None) self.children = None self.plot_args_children = kwds.pop('plot_args_children', []) self.plot_kwds_children = kwds.pop('plot_kwds_children', {}) self.plot_args = kwds.pop('plot_args', []) self.plot_kwds = kwds.pop('plot_kwds', {}) self.labels = AxisLabels(**kwds) if not self.plotter: self.setplotter(kwds.get('plotmethod', None)) def copy(self): newcopy = empty_copy(self) newcopy.__dict__.update(self.__dict__) return newcopy def eval_points(self, *points, **kwds): ''' Interpolate data at points Parameters ---------- points : ndarray of float, shape (..., ndim) Points where to interpolate data at. method : {'linear', 'nearest', 'cubic'} method : {'linear', 'nearest', 'cubic'} Method of interpolation. One of - ``nearest``: return the value at the data point closest to the point of interpolation. - ``linear``: tesselate the input point set to n-dimensional simplices, and interpolate linearly on each simplex. - ``cubic`` (1-D): return the value detemined from a cubic spline. - ``cubic`` (2-D): return the value determined from a piecewise cubic, continuously differentiable (C1), and approximately curvature-minimizing polynomial surface. fill_value : float, optional Value used to fill in for requested points outside of the convex hull of the input points. If not provided, then the default is ``nan``. This option has no effect for the 'nearest' method. Examples -------- >>> import numpy as np >>> x = np.arange(-2, 2, 0.4) >>> xi = np.arange(-2, 2, 0.1) >>> d = PlotData(np.sin(x), x, xlab='x', ylab='sin', title='sinus', ... plot_args=['r.']) >>> di = PlotData(d.eval_points(xi), xi) >>> hi = di.plot() >>> h = d.plot() See also -------- scipy.interpolate.griddata ''' options = dict(method='linear') options.update(**kwds) if isinstance(self.args, (list, tuple)): # Multidimensional data ndim = len(self.args) if ndim < 2: msg = ''' Unable to determine plotter-type, because len(self.args)<2. If the data is 1D, then self.args should be a vector! If the data is 2D, then length(self.args) should be 2. If the data is 3D, then length(self.args) should be 3. Unless you fix this, the interpolation will not work!''' warnings.warn(msg) else: xi = np.meshgrid(*self.args) return interpolate.griddata( xi, self.data.ravel(), points, **options) else: # One dimensional data return interpolate.griddata( self.args, self.data, points, **options) def integrate(self, a, b, **kwds): ''' >>> x = np.linspace(0,5,60) >>> d = PlotData(np.sin(x), x) >>> d.dataCI = np.vstack((d.data*.9,d.data*1.1)).T >>> d.integrate(0,np.pi/2, return_ci=True) array([ 0.99940055, 0.85543644, 1.04553343]) ''' method = kwds.pop('method', 'trapz') fun = getattr(integrate, method) if isinstance(self.args, (list, tuple)): # Multidimensional data raise NotImplementedError('integration for ndim>1 not implemented') # ndim = len(self.args) # if ndim < 2: # msg = '''Unable to determine plotter-type, because # len(self.args)<2. # If the data is 1D, then self.args should be a vector! # If the data is 2D, then length(self.args) should be 2. # If the data is 3D, then length(self.args) should be 3. # Unless you fix this, the plot methods will not work!''' # warnings.warn(msg) # else: # return interpolate.griddata(self.args, self.data.ravel(), **kwds) else: # One dimensional data return_ci = kwds.pop('return_ci', False) x = self.args ix = np.flatnonzero((a < x) & (x < b)) xi = np.hstack((a, x.take(ix), b)) fi = np.hstack( (self.eval_points(a), self.data.take(ix), self.eval_points(b))) res = fun(fi, xi, **kwds) if return_ci: return np.hstack( (res, fun(self.dataCI[ix, :].T, xi[1:-1], **kwds))) return res def plot(self, *args, **kwds): axis = kwds.pop('axis', None) if axis is None: axis = plotbackend.gca() tmp = None default_plotflag = self.plot_kwds.get('plotflag', None) plotflag = kwds.get('plotflag', default_plotflag) if not plotflag and self.children is not None: axis.hold('on') tmp = [] child_args = kwds.pop( 'plot_args_children', tuple( self.plot_args_children)) child_kwds = dict(self.plot_kwds_children).copy() child_kwds.update(kwds.pop('plot_kwds_children', {})) child_kwds['axis'] = axis for child in self.children: tmp1 = child(*child_args, **child_kwds) if tmp1 is not None: tmp.append(tmp1) if len(tmp) == 0: tmp = None main_args = args if len(args) else tuple(self.plot_args) main_kwds = dict(self.plot_kwds).copy() main_kwds.update(kwds) main_kwds['axis'] = axis tmp2 = self.plotter.plot(self, *main_args, **main_kwds) return tmp2, tmp def setplotter(self, plotmethod=None): ''' Set plotter based on the data type: data_1d, data_2d, data_3d or data_nd ''' if isinstance(self.args, (list, tuple)): # Multidimensional data ndim = len(self.args) if ndim < 2: msg = ''' Unable to determine plotter-type, because len(self.args)<2. If the data is 1D, then self.args should be a vector! If the data is 2D, then length(self.args) should be 2. If the data is 3D, then length(self.args) should be 3. Unless you fix this, the plot methods will not work!''' warnings.warn(msg) elif ndim == 2: self.plotter = Plotter_2d(plotmethod) else: warnings.warn('Plotter method not implemented for ndim>2') else: # One dimensional data self.plotter = Plotter_1d(plotmethod) def show(self, *args, **kwds): self.plotter.show(*args, **kwds) __call__ = plot interpolate = eval_points class AxisLabels: def __init__(self, title='', xlab='', ylab='', zlab='', **kwds): self.title = title self.xlab = xlab self.ylab = ylab self.zlab = zlab def __repr__(self): return self.__str__() def __str__(self): return '%s\n%s\n%s\n%s\n' % ( self.title, self.xlab, self.ylab, self.zlab) def copy(self): newcopy = empty_copy(self) newcopy.__dict__.update(self.__dict__) return newcopy def labelfig(self, axis=None): if axis is None: axis = plotbackend.gca() try: h = [] for fun, txt in zip( ('set_title', 'set_xlabel', 'set_ylabel', 'set_ylabel'), (self.title, self.xlab, self.ylab, self.zlab)): if txt: if fun.startswith('set_title'): title0 = axis.get_title() if title0.lower().strip() != txt.lower().strip(): txt = title0 + '\n' + txt h.append(getattr(axis, fun)(txt)) return h except: pass class Plotter_1d(object): """ Parameters ---------- plotmethod : string defining type of plot. Options are: bar : bar plot with rectangles barh : horizontal bar plot with rectangles loglog : plot with log scaling on the *x* and *y* axis semilogx : plot with log scaling on the *x* axis semilogy : plot with log scaling on the *y* axis plot : Plot lines and/or markers (default) stem : Stem plot step : stair-step plot scatter : scatter plot """ def __init__(self, plotmethod='plot'): self.plotfun = None if plotmethod is None: plotmethod = 'plot' self.plotmethod = plotmethod self.plotbackend = plotbackend # try: # self.plotfun = getattr(plotbackend, plotmethod) # except: # pass def show(self, *args, **kwds): plotbackend.show(*args, **kwds) def plot(self, wdata, *args, **kwds): axis = kwds.pop('axis', None) if axis is None: axis = plotbackend.gca() plotflag = kwds.pop('plotflag', False) if plotflag: h1 = self._plot(axis, plotflag, wdata, *args, **kwds) else: if isinstance(wdata.data, (list, tuple)): vals = tuple(wdata.data) else: vals = (wdata.data,) if isinstance(wdata.args, (list, tuple)): args1 = tuple((wdata.args)) + vals + args else: args1 = tuple((wdata.args,)) + vals + args plotfun = getattr(axis, self.plotmethod) h1 = plotfun(*args1, **kwds) h2 = wdata.labels.labelfig(axis) return h1, h2 def _plot(self, axis, plotflag, wdata, *args, **kwds): x = wdata.args data = transformdata(x, wdata.data, plotflag) dataCI = getattr(wdata, 'dataCI', ()) h1 = plot1d(axis, x, data, dataCI, plotflag, *args, **kwds) return h1 __call__ = plot def plot1d(axis, args, data, dataCI, plotflag, *varargin, **kwds): plottype = np.mod(plotflag, 10) if plottype == 0: # % No plotting return [] elif plottype == 1: H = axis.plot(args, data, *varargin, **kwds) elif plottype == 2: H = axis.step(args, data, *varargin, **kwds) elif plottype == 3: H = axis.stem(args, data, *varargin, **kwds) elif plottype == 4: H = axis.errorbar( args, data, yerr=[ dataCI[ :, 0] - data, dataCI[ :, 1] - data], *varargin, **kwds) elif plottype == 5: H = axis.bar(args, data, *varargin, **kwds) elif plottype == 6: level = 0 if np.isfinite(level): H = axis.fill_between(args, data, level, *varargin, **kwds) else: H = axis.fill_between(args, data, *varargin, **kwds) elif plottype == 7: H = axis.plot(args, data, *varargin, **kwds) H = axis.fill_between( args, dataCI[ :, 0], dataCI[ :, 1], alpha=0.2, color='r') scale = plotscale(plotflag) logXscale = 'x' in scale logYscale = 'y' in scale logZscale = 'z' in scale if logXscale: axis.set(xscale='log') if logYscale: axis.set(yscale='log') if logZscale: axis.set(zscale='log') transFlag = np.mod(plotflag // 10, 10) logScale = logXscale or logYscale or logZscale if logScale or (transFlag == 5 and not logScale): ax = list(axis.axis()) fmax1 = data.max() if transFlag == 5 and not logScale: ax[3] = 11 * np.log10(fmax1) ax[2] = ax[3] - 40 else: ax[3] = 1.15 * fmax1 ax[2] = ax[3] * 1e-4 axis.axis(ax) if np.any(dataCI) and plottype < 3: axis.hold(True) plot1d(axis, args, dataCI, (), plotflag, 'r--') return H def plotscale(plotflag): ''' Return plotscale from plotflag CALL scale = plotscale(plotflag) plotflag = integer defining plotscale. Let scaleId = floor(plotflag/100). If scaleId < 8 then: 0 'linear' : Linear scale on all axes. 1 'xlog' : Log scale on x-axis. 2 'ylog' : Log scale on y-axis. 3 'xylog' : Log scale on xy-axis. 4 'zlog' : Log scale on z-axis. 5 'xzlog' : Log scale on xz-axis. 6 'yzlog' : Log scale on yz-axis. 7 'xyzlog' : Log scale on xyz-axis. otherwise if (mod(scaleId,10)>0) : Log scale on x-axis. if (mod(floor(scaleId/10),10)>0) : Log scale on y-axis. if (mod(floor(scaleId/100),10)>0) : Log scale on z-axis. scale = string defining plotscale valid options are: 'linear', 'xlog', 'ylog', 'xylog', 'zlog', 'xzlog', 'yzlog', 'xyzlog' Example plotscale(100) % xlog plotscale(200) % xlog plotscale(1000) % ylog See also plotscale ''' scaleId = plotflag // 100 if scaleId > 7: logXscaleId = np.mod(scaleId, 10) > 0 logYscaleId = (np.mod(scaleId // 10, 10) > 0) * 2 logZscaleId = (np.mod(scaleId // 100, 10) > 0) * 4 scaleId = logYscaleId + logXscaleId + logZscaleId scales = [ 'linear', 'xlog', 'ylog', 'xylog', 'zlog', 'xzlog', 'yzlog', 'xyzlog'] return scales[scaleId] def transformdata(x, f, plotflag): transFlag = np.mod(plotflag // 10, 10) if transFlag == 0: data = f elif transFlag == 1: data = 1 - f elif transFlag == 2: data = cumtrapz(f, x) elif transFlag == 3: data = 1 - cumtrapz(f, x) if transFlag in (4, 5): if transFlag == 4: data = -np.log1p(-cumtrapz(f, x)) else: if any(f < 0): raise ValueError('Invalid plotflag: Data or dataCI is ' + 'negative, but must be positive') data = 10 * np.log10(f) return data class Plotter_2d(Plotter_1d): """ Parameters ---------- plotmethod : string defining type of plot. Options are: contour (default) contourf mesh surf """ def __init__(self, plotmethod='contour'): if plotmethod is None: plotmethod = 'contour' super(Plotter_2d, self).__init__(plotmethod) def _plot(self, axis, plotflag, wdata, *args, **kwds): h1 = plot2d(axis, wdata, plotflag, *args, **kwds) return h1 def plot2d(axis, wdata, plotflag, *args, **kwds): f = wdata if isinstance(wdata.args, (list, tuple)): args1 = tuple((wdata.args)) + (wdata.data,) + args else: args1 = tuple((wdata.args,)) + (wdata.data,) + args if plotflag in (1, 6, 7, 8, 9): isPL = False # check if contour levels is submitted if hasattr(f, 'clevels') and len(f.clevels) > 0: CL = f.clevels isPL = hasattr(f, 'plevels') and f.plevels is not None if isPL: PL = f.plevels # levels defines quantile levels? 0=no 1=yes else: dmax = np.max(f.data) dmin = np.min(f.data) CL = dmax - (dmax - dmin) * \ (1 - np.r_[0.01, 0.025, 0.05, 0.1, 0.2, 0.4, 0.5, 0.75]) clvec = np.sort(CL) if plotflag in [1, 8, 9]: h = axis.contour(*args1, levels=CL, **kwds) # else: # [cs hcs] = contour3(f.x{:},f.f,CL,sym); if plotflag in (1, 6): ncl = len(clvec) if ncl > 12: ncl = 12 warnings.warn( 'Only the first 12 levels will be listed in table.') clvals = PL[:ncl] if isPL else clvec[:ncl] unused_axcl = cltext( clvals, percent=isPL) # print contour level text elif any(plotflag == [7, 9]): axis.clabel(h) else: axis.clabel(h) elif plotflag == 2: h = axis.mesh(*args1, **kwds) elif plotflag == 3: # shading interp % flat, faceted % surfc h = axis.surf(*args1, **kwds) elif plotflag == 4: h = axis.waterfall(*args1, **kwds) elif plotflag == 5: h = axis.pcolor(*args1, **kwds) # %shading interp % flat, faceted elif plotflag == 10: h = axis.contourf(*args1, **kwds) axis.clabel(h) plotbackend.colorbar(h) else: raise ValueError('unknown option for plotflag') # if any(plotflag==(2:5)) # shading(shad); # end # pass def test_plotdata(): plotbackend.ioff() x = np.arange(-2, 2, 0.4) xi = np.arange(-2, 2, 0.1) d = PlotData(np.sin(x), x, xlab='x', ylab='sin', title='sinus', plot_args=['r.']) di = PlotData(d.eval_points(xi, method='cubic'), xi) unused_hi = di.plot() unused_h = d.plot() d.show() def test_docstrings(): import doctest print('Testing docstrings in %s' % __file__) doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE) def main(): pass if __name__ == '__main__': test_docstrings() # test_plotdata() # main()