From a669e556b739f8330593cb948e3b8075c817d69b Mon Sep 17 00:00:00 2001 From: pbrod Date: Fri, 30 Dec 2016 03:41:42 +0100 Subject: [PATCH] Simplified PlotData in containers.py --- wafo/containers.py | 201 ++++++++++++++++++++++++--------------------- 1 file changed, 106 insertions(+), 95 deletions(-) diff --git a/wafo/containers.py b/wafo/containers.py index cbb38d8..f86f050 100644 --- a/wafo/containers.py +++ b/wafo/containers.py @@ -7,6 +7,7 @@ import numpy as np from scipy.integrate.quadrature import cumtrapz # @UnresolvedImport from scipy import interpolate from scipy import integrate +from _warnings import warn __all__ = ['PlotData', 'AxisLabels'] @@ -161,6 +162,24 @@ class PlotData(object): cdf = np.hstack((0, cumtrapz(self.data, self.args))) return PlotData(cdf, np.copy(self.args), xlab='x', ylab='F(x)') + def _get_fi_xi(self, a, b): + x = self.args + if a is None: + a = x[0] + if b is None: + b = x[-1] + ix = np.flatnonzero((a < x) & (x < b)) + xi = np.hstack((a, x.take(ix), b)) + + if self.data.ndim > 1: + fi = np.vstack((self.eval_points(a), + self.data[ix, :], + self.eval_points(b))).T + else: + fi = np.hstack((self.eval_points(a), self.data.take(ix), + self.eval_points(b))) + return fi, xi + def integrate(self, a=None, b=None, **kwds): ''' >>> x = np.linspace(0,5,60) @@ -180,20 +199,7 @@ class PlotData(object): raise NotImplementedError('integration for ndim>1 not implemented') # One dimensional data return_ci = kwds.pop('return_ci', False) - x = self.args - if a is None: - a = x[0] - if b is None: - b = x[-1] - ix = np.flatnonzero((a < x) & (x < b)) - xi = np.hstack((a, x.take(ix), b)) - if self.data.ndim > 1: - fi = np.vstack((self.eval_points(a), - self.data[ix, :], - self.eval_points(b))).T - else: - fi = np.hstack((self.eval_points(a), self.data.take(ix), - self.eval_points(b))) + fi, xi = self._get_fi_xi(a, b) res = fun(fi, xi, **kwds) if return_ci: res_ci = [child.integrate(a, b, method=method) @@ -201,13 +207,8 @@ class PlotData(object): return np.hstack((res, np.ravel(res_ci))) return res - def plot(self, *args, **kwds): - axis = kwds.pop('axis', None) - if axis is None: - axis = plt.gca() + def _plot_children(self, axis, plotflag, kwds): tmp = None - default_plotflag = self.plot_kwds.get('plotflag') - plotflag = kwds.get('plotflag', default_plotflag) if not plotflag and self.children is not None: axis.hold('on') tmp = [] @@ -220,8 +221,18 @@ class PlotData(object): tmp1 = child.plot(*child_args, **child_kwds) if tmp1 is not None: tmp.append(tmp1) - if len(tmp) == 0: - tmp = None + if tmp: + return tmp + return None + + def plot(self, *args, **kwargs): + kwds = kwargs.copy() + axis = kwds.pop('axis', None) + if axis is None: + axis = plt.gca() + default_plotflag = self.plot_kwds.get('plotflag') + plotflag = kwds.get('plotflag', default_plotflag) + tmp = self._plot_children(axis, plotflag, kwds) main_args = args if len(args) else tuple(self.plot_args) main_kwds = dict(self.plot_kwds).copy() main_kwds.update(kwds) @@ -279,23 +290,30 @@ class AxisLabels: newcopy.__dict__.update(self.__dict__) return newcopy + def _add_title_if_fun_is_set_title(self, txt, title0, fun): + if fun.startswith('set_title'): + if title0.lower().strip() != txt.lower().strip(): + txt = title0 + '\n' + txt + return txt + + def _labelfig(self, axis): + h = [] + title0 = axis.get_title() + for fun, txt in zip( + ('set_title', 'set_xlabel', 'set_ylabel', 'set_ylabel'), + (self.title, self.xlab, self.ylab, self.zlab)): + if txt: + txt = self._add_title_if_fun_is_set_title(txt, title0, fun) + h.append(getattr(axis, fun)(txt)) + return h + def labelfig(self, axis=None): if axis is None: axis = plt.gca() try: - h = [] - for fun, txt in zip( - ('set_title', 'set_xlabel', 'set_ylabel', 'set_ylabel'), - (self.title, self.xlab, self.ylab, self.zlab)): - if txt: - if fun.startswith('set_title'): - title0 = axis.get_title() - if title0.lower().strip() != txt.lower().strip(): - txt = title0 + '\n' + txt - h.append(getattr(axis, fun)(txt)) - return h - except: - pass + return self._labelfig(axis) + except Exception as err: + warnings.warn(str(err)) class Plotter_1d(object): @@ -358,28 +376,29 @@ class Plotter_1d(object): __call__ = plot +def set_axis(axis, f_max, trans_flag, log_scale): + ax = list(axis.axis()) + if trans_flag == 8 and not log_scale: + ax[3] = 11 * np.log10(f_max) + ax[2] = ax[3] - 40 + else: + ax[3] = 1.15 * f_max + ax[2] = ax[3] * 1e-4 + axis.axis(ax) + + def set_plot_scale(axis, f_max, plotflag): scale = plotscale(plotflag) - log_x_scale = 'x' in scale - log_y_scale = 'y' in scale - log_z_scale = 'z' in scale - if log_x_scale: - axis.set(xscale='log') - if log_y_scale: - axis.set(yscale='log') - if log_z_scale: - axis.set(zscale='log') + log_scale = False + for dim in ['x', 'y', 'z']: + if dim in scale: + log_scale = True + opt = {'{}scale'.format(dim): 'log'} + axis.set(**opt) + trans_flag = np.mod(plotflag // 10, 10) - log_scale = log_x_scale or log_y_scale or log_z_scale if log_scale or (trans_flag == 5 and not log_scale): - ax = list(axis.axis()) - if trans_flag == 8 and not log_scale: - ax[3] = 11 * np.log10(f_max) - ax[2] = ax[3] - 40 - else: - ax[3] = 1.15 * f_max - ax[2] = ax[3] * 1e-4 - axis.axis(ax) + set_axis(axis, f_max, trans_flag, log_scale) def plot1d(axis, args, data, dataCI, plotflag, *varargin, **kwds): @@ -536,30 +555,42 @@ class Plotter_2d(Plotter_1d): return h1 +def _get_contour_levels(f): + isPL = False + PL = None + # check if contour levels is submitted + if hasattr(f, 'clevels') and len(f.clevels) > 0: + CL = f.clevels + isPL = hasattr(f, 'plevels') and f.plevels is not None + if isPL: + PL = f.plevels # levels defines quantile levels? 0=no 1=yes + else: + dmax = np.max(f.data) + dmin = np.min(f.data) + CL = dmax - (dmax - dmin) * \ + (1 - np.r_[0.01, 0.025, 0.05, 0.1, 0.2, 0.4, 0.5, 0.75]) + clvec = np.sort(CL) + return clvec, PL + + def plot2d(axis, wdata, plotflag, *args, **kwds): f = wdata if isinstance(wdata.args, (list, tuple)): args1 = tuple((wdata.args)) + (wdata.data,) + args else: args1 = tuple((wdata.args,)) + (wdata.data,) + args + + pltfun = [None, axis.contour, axis.mesh, axis.surf, axis.waterfal, + axis.pcolor, axis.contour, axis.contour, axis.contour, + axis.contour, axis.contourf][plotflag] + if plotflag in (1, 6, 7, 8, 9): - isPL = False - # check if contour levels is submitted - if hasattr(f, 'clevels') and len(f.clevels) > 0: - CL = f.clevels - isPL = hasattr(f, 'plevels') and f.plevels is not None - if isPL: - PL = f.plevels # levels defines quantile levels? 0=no 1=yes - else: - dmax = np.max(f.data) - dmin = np.min(f.data) - CL = dmax - (dmax - dmin) * \ - (1 - np.r_[0.01, 0.025, 0.05, 0.1, 0.2, 0.4, 0.5, 0.75]) - clvec = np.sort(CL) + clvec, PL = _get_contour_levels(f) if plotflag in [1, 8, 9]: - h = axis.contour(*args1, levels=clvec, **kwds) - # else: + h = pltfun(*args1, levels=clvec, **kwds) + else: + h = pltfun(*args1, **kwds) # [cs hcs] = contour3(f.x{:},f.f,CL,sym); if plotflag in (1, 6): @@ -569,31 +600,16 @@ def plot2d(axis, wdata, plotflag, *args, **kwds): warnings.warn( 'Only the first 12 levels will be listed in table.') + isPL = PL is not None clvals = PL[:ncl] if isPL else clvec[:ncl] unused_axcl = cltext(clvals, percent=isPL) - elif any(plotflag == [7, 9]): - axis.clabel(h) else: axis.clabel(h) - elif plotflag == 2: - h = axis.mesh(*args1, **kwds) - elif plotflag == 3: - # shading interp % flat, faceted % surfc - h = axis.surf(*args1, **kwds) - elif plotflag == 4: - h = axis.waterfall(*args1, **kwds) - elif plotflag == 5: - h = axis.pcolor(*args1, **kwds) # %shading interp % flat, faceted - elif plotflag == 10: - h = axis.contourf(*args1, **kwds) - axis.clabel(h) - plt.colorbar(h) else: - raise ValueError('unknown option for plotflag') - # if any(plotflag==(2:5)) - # shading(shad); - # end - # pass + h = pltfun(*args1, **kwds) + if plotflag == 10: + axis.clabel(h) + plt.colorbar(h) def test_plotdata(): @@ -613,12 +629,7 @@ def test_plotdata(): d.show('hold') -def test_docstrings(): - import doctest - print('Testing docstrings in %s' % __file__) - doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE) - - if __name__ == '__main__': - test_docstrings() + from wafo.testing import test_docstrings + test_docstrings(__file__) # test_plotdata()