diff --git a/wafo/containers.py b/wafo/containers.py index 7e22ca5..f85fa78 100644 --- a/wafo/containers.py +++ b/wafo/containers.py @@ -1,7 +1,7 @@ from __future__ import absolute_import import warnings from wafo.graphutil import cltext -from wafo.plotbackend import plotbackend +from wafo.plotbackend import plotbackend as plt from time import gmtime, strftime import numpy as np from scipy.integrate.quadrature import cumtrapz # @UnresolvedImport @@ -20,14 +20,6 @@ def empty_copy(obj): return newcopy -def _set_seed(iseed): - if iseed is not None: - try: - np.random.set_state(iseed) - except: - np.random.seed(iseed) - - def now(): ''' Return current date and time as a string @@ -60,20 +52,22 @@ class PlotData(object): Example ------- >>> import numpy as np - >>> x = np.arange(-2, 2, 0.2) + >>> x = np.linspace(0, np.pi, 9) # Plot 2 objects in one call - >>> d2 = PlotData(np.sin(x), x, xlab='x', ylab='sin', title='sinus') + >>> d2 = PlotData(np.sin(x), x, xlab='x', ylab='sin', title='sinus', + ... plot_args=['r.']) - h = d2.plot() - h1 = d2() + >>> h = d2.plot() + >>> h1 = d2() - Plot with confidence interval + # Plot with confidence interval >>> d3 = PlotData(np.sin(x), x) - >>> d3.children = [PlotData(np.vstack([np.sin(x)*0.9, np.sin(x)*1.2]).T,x)] - >>> d3.plot_args_children=[':r'] + >>> d3.children = [PlotData(np.vstack([np.sin(x)*0.9, + ... np.sin(x)*1.2]).T, x)] + >>> d3.plot_args_children = [':r'] - h = d3.plot() + >>> h = d3.plot() ''' @@ -82,7 +76,7 @@ class PlotData(object): self.args = args self.date = now() self.plotter = kwds.pop('plotter', None) - self.children = None + self.children = kwds.pop('children', None) self.plot_args_children = kwds.pop('plot_args_children', []) self.plot_kwds_children = kwds.pop('plot_kwds_children', {}) self.plot_args = kwds.pop('plot_args', []) @@ -133,8 +127,10 @@ class PlotData(object): ... plot_args=['r.']) >>> di = PlotData(d.eval_points(xi), xi) - hi = di.plot() - h = d.plot() + >>> hi = di.plot() + >>> h = d.plot() + >>> di.to_cdf() + See also -------- @@ -154,16 +150,15 @@ class PlotData(object): warnings.warn(msg) else: xi = np.meshgrid(*self.args) - return interpolate.griddata( - xi, self.data.ravel(), points, **options) - else: # One dimensional data - return interpolate.griddata( - self.args, self.data, points, **options) + return interpolate.griddata(xi, self.data.ravel(), points, + **options) + # One dimensional data + return interpolate.griddata(self.args, self.data, points, **options) def to_cdf(self): if isinstance(self.args, (list, tuple)): # Multidimensional data raise NotImplementedError('integration for ndim>1 not implemented') - cdf = np.hstack((0, integrate.cumtrapz(self.data, self.args))) + cdf = np.hstack((0, cumtrapz(self.data, self.args))) return PlotData(cdf, np.copy(self.args), xlab='x', ylab='F(x)') def integrate(self, a=None, b=None, **kwds): @@ -182,40 +177,27 @@ class PlotData(object): fun = getattr(integrate, method) if isinstance(self.args, (list, tuple)): # Multidimensional data raise NotImplementedError('integration for ndim>1 not implemented') - # ndim = len(self.args) - # if ndim < 2: -# msg = '''Unable to determine plotter-type, because -# len(self.args)<2. -# If the data is 1D, then self.args should be a vector! -# If the data is 2D, then length(self.args) should be 2. -# If the data is 3D, then length(self.args) should be 3. -# Unless you fix this, the plot methods will not work!''' -# warnings.warn(msg) -# else: -# return interpolate.griddata(self.args, self.data.ravel(), **kwds) - else: # One dimensional data - return_ci = kwds.pop('return_ci', False) - x = self.args - if a is None: - a = x[0] - if b is None: - b = x[-1] - ix = np.flatnonzero((a < x) & (x < b)) - xi = np.hstack((a, x.take(ix), b)) - fi = np.hstack( - (self.eval_points(a), - self.data.take(ix), - self.eval_points(b))) - res = fun(fi, xi, **kwds) - if return_ci: - return np.hstack( - (res, fun(self.dataCI[ix, :].T, xi[1:-1], **kwds))) - return res + # One dimensional data + return_ci = kwds.pop('return_ci', False) + x = self.args + if a is None: + a = x[0] + if b is None: + b = x[-1] + ix = np.flatnonzero((a < x) & (x < b)) + xi = np.hstack((a, x.take(ix), b)) + fi = np.hstack((self.eval_points(a), self.data.take(ix), + self.eval_points(b))) + res = fun(fi, xi, **kwds) + if return_ci: + return np.hstack( + (res, fun(self.dataCI[ix, :].T, xi[1:-1], **kwds))) + return res def plot(self, *args, **kwds): axis = kwds.pop('axis', None) if axis is None: - axis = plotbackend.gca() + axis = plt.gca() tmp = None default_plotflag = self.plot_kwds.get('plotflag', None) plotflag = kwds.get('plotflag', default_plotflag) @@ -294,7 +276,7 @@ class AxisLabels: def labelfig(self, axis=None): if axis is None: - axis = plotbackend.gca() + axis = plt.gca() try: h = [] for fun, txt in zip( @@ -335,19 +317,19 @@ class Plotter_1d(object): if plotmethod is None: plotmethod = 'plot' self.plotmethod = plotmethod - self.plotbackend = plotbackend + # self.plotbackend = plotbackend # try: # self.plotfun = getattr(plotbackend, plotmethod) # except: # pass def show(self, *args, **kwds): - plotbackend.show(*args, **kwds) + plt.show(*args, **kwds) def plot(self, wdata, *args, **kwds): axis = kwds.pop('axis', None) if axis is None: - axis = plotbackend.gca() + axis = plt.gca() plotflag = kwds.pop('plotflag', False) if plotflag: h1 = self._plot(axis, plotflag, wdata, *args, **kwds) @@ -377,7 +359,7 @@ class Plotter_1d(object): def plot1d(axis, args, data, dataCI, plotflag, *varargin, **kwds): plottype = np.mod(plotflag, 10) - if plottype == 0: # % No plotting + if plottype == 0: # No plotting return [] elif plottype == 1: H = axis.plot(args, data, *varargin, **kwds) @@ -386,18 +368,10 @@ def plot1d(axis, args, data, dataCI, plotflag, *varargin, **kwds): elif plottype == 3: H = axis.stem(args, data, *varargin, **kwds) elif plottype == 4: - H = axis.errorbar( - args, - data, - yerr=[ - dataCI[ - :, - 0] - data, - dataCI[ - :, - 1] - data], - *varargin, - **kwds) + H = axis.errorbar(args, data, + yerr=[dataCI[:, 0] - data, + dataCI[:, 1] - data], + *varargin, **kwds) elif plottype == 5: H = axis.bar(args, data, *varargin, **kwds) elif plottype == 6: @@ -408,10 +382,8 @@ def plot1d(axis, args, data, dataCI, plotflag, *varargin, **kwds): H = axis.fill_between(args, data, *varargin, **kwds) elif plottype == 7: H = axis.plot(args, data, *varargin, **kwds) - H = axis.fill_between( - args, dataCI[ - :, 0], dataCI[ - :, 1], alpha=0.2, color='r') + H = axis.fill_between(args, dataCI[:, 0], dataCI[:, 1], + alpha=0.2, color='r') scale = plotscale(plotflag) logXscale = 'x' in scale @@ -467,16 +439,32 @@ def plotscale(plotflag): if (mod(floor(scaleId/10),10)>0) : Log scale on y-axis. if (mod(floor(scaleId/100),10)>0) : Log scale on z-axis. - scale = string defining plotscale valid options are: + scale = string defining plotscale valid options are: 'linear', 'xlog', 'ylog', 'xylog', 'zlog', 'xzlog', 'yzlog', 'xyzlog' - Example - plotscale(100) % xlog - plotscale(200) % xlog - plotscale(1000) % ylog - - See also plotscale + Examples + -------- + >>> for i in range(7): + ... plotscale(i*100) + 'linear' + 'xlog' + 'ylog' + 'xylog' + 'zlog' + 'xzlog' + 'yzlog' + + >>> plotscale(100) + 'xlog' + >>> plotscale(1000) + 'ylog' + >>> plotscale(10000) + 'zlog' + + See also + --------- + plotscale ''' scaleId = plotflag // 100 if scaleId > 7: @@ -485,15 +473,8 @@ def plotscale(plotflag): logZscaleId = (np.mod(scaleId // 100, 10) > 0) * 4 scaleId = logYscaleId + logXscaleId + logZscaleId - scales = [ - 'linear', - 'xlog', - 'ylog', - 'xylog', - 'zlog', - 'xzlog', - 'yzlog', - 'xyzlog'] + scales = ['linear', 'xlog', 'ylog', 'xylog', 'zlog', 'xzlog', + 'yzlog', 'xyzlog'] return scales[scaleId] @@ -593,7 +574,7 @@ def plot2d(axis, wdata, plotflag, *args, **kwds): elif plotflag == 10: h = axis.contourf(*args1, **kwds) axis.clabel(h) - plotbackend.colorbar(h) + plt.colorbar(h) else: raise ValueError('unknown option for plotflag') # if any(plotflag==(2:5)) @@ -603,16 +584,19 @@ def plot2d(axis, wdata, plotflag, *args, **kwds): def test_plotdata(): - plotbackend.ioff() - x = np.arange(-2, 2, 0.4) - xi = np.arange(-2, 2, 0.1) + plt.ioff() + x = np.linspace(0, np.pi, 9) + xi = np.linspace(0, np.pi, 4*9) - d = PlotData(np.sin(x), x, xlab='x', ylab='sin', title='sinus', + d = PlotData(np.sin(x)/2, x, xlab='x', ylab='sin', title='sinus', plot_args=['r.']) di = PlotData(d.eval_points(xi, method='cubic'), xi) unused_hi = di.plot() unused_h = d.plot() - d.show() + f = di.to_cdf() + for i in range(4): + _ = f.plot(plotflag=i) + d.show('hold') def test_docstrings(): @@ -621,10 +605,6 @@ def test_docstrings(): doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE) -def main(): - pass - if __name__ == '__main__': - test_docstrings() - # test_plotdata() - # main() + #test_docstrings() + test_plotdata()