From 0c2f6621e4eb105f5c9bb43cf4804d7802c6d1a8 Mon Sep 17 00:00:00 2001 From: "per.andreas.brodtkorb" Date: Thu, 5 Jan 2012 06:33:00 +0000 Subject: [PATCH] Small change: now possible to submit figure and axis to plot on --- pywafo/src/wafo/graphutil.py | 108 +++++++++++++++++---------- pywafo/src/wafo/wafodata.py | 137 +++++++++++++++++++++-------------- 2 files changed, 154 insertions(+), 91 deletions(-) diff --git a/pywafo/src/wafo/graphutil.py b/pywafo/src/wafo/graphutil.py index d1f86ae..76b963f 100644 --- a/pywafo/src/wafo/graphutil.py +++ b/pywafo/src/wafo/graphutil.py @@ -20,30 +20,40 @@ def _matchfun(x, gidtxt): return x.get_gid() == gidtxt return False -def delete_object(gidtxt, cf=None, ca=None, verbose=False): +def delete_text_object(gidtxt, figure=None, axis=None, verbose=False): ''' - Delete all objects matching the gidtxt if it exists + Delete all text objects matching the gidtxt if it exists + + Parameters + ---------- + gidtxt : string + + figure, axis : objects + current figure and current axis, respectively. + verbose : bool + If true print warnings when trying to delete non-existent objects ''' - if cf is None: - cf = plotbackend.gcf() - if ca is None: - ca = plotbackend.gca() + if figure is None: + figure = plotbackend.gcf() + if axis is None: + axis = figure.gca() lmatchfun = lambda x : _matchfun(x, gidtxt) - objs = plotbackend.findobj(cf, lmatchfun) - if len(objs): - for i in objs: - try: - ca.texts.remove(i) - except: - if verbose: - warnings.warn('Tried to delete a non-existing %s from axes' % gidtxt) - try: - cf.texts.remove(i) - except: - if verbose: - warnings.warn('Tried to delete a non-existing %s from figure' % gidtxt) + objs = axis.findobj(lmatchfun) + for obj in objs: + try: + axis.texts.remove(obj) + except: + if verbose: + warnings.warn('Tried to delete a non-existing %s from axis' % gidtxt) + objs = figure.findobj(lmatchfun) + for obj in objs: + try: + figure.texts.remove(obj) + except: + if verbose: + warnings.warn('Tried to delete a non-existing %s from figure' % gidtxt) -def cltext(levels, percent=False, n=4, xs=0.036, ys=0.94, zs=0): +def cltext(levels, percent=False, n=4, xs=0.036, ys=0.94, zs=0, figure=None, axis=None): ''' Places contour level text in the current window @@ -58,6 +68,10 @@ def cltext(levels, percent=False, n=4, xs=0.036, ys=0.94, zs=0): contour line encloses n : integer maximum N digits of precision (default 4) + figure, axis : objects + current figure and current axis, respectively. + default figure = plotbackend.gcf(), + axis = plotbackend.gca() Returns ------- @@ -89,10 +103,15 @@ def cltext(levels, percent=False, n=4, xs=0.036, ys=0.94, zs=0): >>> plt.show() ''' # TODO : Make it work like legend does (but without the box): include position options etc... + if figure is None: + figure = plotbackend.gcf() + if axis is None: + axis = figure.gca() + clevels = np.atleast_1d(levels) - cax = plotbackend.gca() - axpos = cax.get_position() + + axpos = axis.get_position() xint = axpos.intervalx yint = axpos.intervaly @@ -100,7 +119,7 @@ def cltext(levels, percent=False, n=4, xs=0.036, ys=0.94, zs=0): yss = yint[0] + ys * (yint[1] - yint[0]) # delete cltext object if it exists - delete_object(_CLTEXT_GID, ca=cax) + delete_text_object(_CLTEXT_GID, axis=axis) charHeight = 1.0 / 33.0 delta_y = charHeight @@ -115,15 +134,16 @@ def cltext(levels, percent=False, n=4, xs=0.036, ys=0.94, zs=0): cltxt = ''.join([format_ % level for level in clevels.tolist()]) titleProp = dict(gid=_CLTEXT_GID, horizontalalignment='left', - verticalalignment='center', fontweight='bold', axes=cax) # - ha1 = plotbackend.figtext(xss, yss, titletxt, **titleProp) + verticalalignment='center', fontweight='bold', axes=axis) # + + ha1 = figure.text(xss, yss, titletxt, **titleProp) yss -= delta_y; txtProp = dict(gid=_CLTEXT_GID, horizontalalignment='left', - verticalalignment='top', axes=cax) + verticalalignment='top', axes=axis) - ha2 = plotbackend.figtext(xss, yss, cltxt, **txtProp) - + ha2 = figure.text(xss, yss, cltxt, **txtProp) + plotbackend.draw_if_interactive() return ha1, ha2 def tallibing(x, y, n, **kwds): @@ -149,8 +169,8 @@ def tallibing(x, y, n, **kwds): >>> import wafo.graphutil as wg >>> import wafo.demos as wd >>> [x,y,z] = wd.peaks(n=20) - >>> wg.epcolor(x,y,z) - >>> wg.tallibing(x,y,z) + >>> h0 = wg.epcolor(x,y,z) + >>> h1 = wg.tallibing(x,y,z) pcolor(x,y,z); shading interp; @@ -158,28 +178,32 @@ def tallibing(x, y, n, **kwds): -------- text ''' + + axis = kwds.pop('axis',None) + if axis is None: + axis = plotbackend.gca() + x, y, n = np.atleast_1d(x, y, n) if mlab.isvector(x) or mlab.isvector(y): x, y = np.meshgrid(x,y) - cax = plotbackend.gca() - x = x.ravel() y = y.ravel() n = n.ravel() n = np.round(n) # delete tallibing object if it exists - delete_object(_TALLIBING_GID, ca=cax) + delete_text_object(_TALLIBING_GID, axis=axis) txtProp = dict(gid=_TALLIBING_GID, size=8, color='w', horizontalalignment='center', - verticalalignment='center', fontweight='demi', axes=cax) + verticalalignment='center', fontweight='demi', axes=axis) txtProp.update(**kwds) h = [] for xi,yi, ni in zip(x,y,n): if ni: - h.append(plotbackend.text(xi, yi, str(ni), **txtProp)) + h.append(axis.text(xi, yi, str(ni), **txtProp)) + plotbackend.draw_if_interactive() return h def epcolor(*args, **kwds): @@ -203,15 +227,20 @@ def epcolor(*args, **kwds): >>> import wafo.demos as wd >>> import wafo.graphutil as wg >>> x, y, z = wd.peaks(n=20) - >>> wg.epcolor(x,y,z) + >>> h = wg.epcolor(x,y,z) See also -------- pylab.pcolor ''' + axis = kwds.pop('axis',None) + if axis is None: + axis = plotbackend.gca() midbin = kwds.pop('midbin', True) if not midbin: - return plotbackend.pcolor(*args,**kwds) + ret = axis.pcolor(*args,**kwds) + plotbackend.draw_if_interactive() + return ret nargin = len(args) data = np.atleast_2d(args[-1]).copy() @@ -230,7 +259,10 @@ def epcolor(*args, **kwds): xx = _findbins(x) yy = _findbins(y) - return plotbackend.pcolor(xx, yy, data, **kwds) + ret = axis.pcolor(xx, yy, data, **kwds) + plotbackend.draw_if_interactive() + return ret + def _findbins(x): ''' Return points half way between all values of X _and_ outside the diff --git a/pywafo/src/wafo/wafodata.py b/pywafo/src/wafo/wafodata.py index 3977972..548a634 100644 --- a/pywafo/src/wafo/wafodata.py +++ b/pywafo/src/wafo/wafodata.py @@ -80,16 +80,20 @@ class WafoData(object): self.setplotter(kwds.get('plotmethod', None)) def plot(self, *args, **kwds): + axis = kwds.pop('axis',None) + if axis is None: + axis = plotbackend.gca() tmp = None plotflag = kwds.get('plotflag', None) if not plotflag and self.children != None: plotbackend.hold('on') tmp = [] - child_args = args if len(args) else tuple(self.plot_args_children) + child_args = kwds.pop('plot_args_children', tuple(self.plot_args_children)) child_kwds = dict(self.plot_kwds_children).copy() - child_kwds.update(**kwds) + child_kwds.update(kwds.pop('plot_kwds_children', {})) + child_kwds['axis'] = axis for child in self.children: - tmp1 = child.plot(*child_args, **kwds) + tmp1 = child.plot(*child_args, **child_kwds) if tmp1 != None: tmp.append(tmp1) if len(tmp) == 0: @@ -97,6 +101,7 @@ class WafoData(object): main_args = args if len(args) else tuple(self.plot_args) main_kwds = dict(self.plot_kwds).copy() main_kwds.update(kwds) + main_kwds['axis'] = axis tmp2 = self.plotter.plot(self, *main_args, **main_kwds) return tmp2, tmp @@ -145,11 +150,13 @@ class AxisLabels: newcopy.__dict__.update(self.__dict__) return newcopy - def labelfig(self): + def labelfig(self, axis=None): + if axis is None: + axis = plotbackend.gca() try: - h1 = plotbackend.title(self.title) - h2 = plotbackend.xlabel(self.xlab) - h3 = plotbackend.ylabel(self.ylab) + h1 = axis.set_title(self.title) + h2 = axis.set_xlabel(self.xlab) + h3 = axis.set_ylabel(self.ylab) #h4 = plotbackend.zlabel(self.zlab) return h1, h2, h3 except: @@ -176,76 +183,85 @@ class Plotter_1d(object): self.plotfun = None if plotmethod is None: plotmethod = 'plot' + self.plotmethod = plotmethod self.plotbackend = plotbackend - try: - self.plotfun = getattr(plotbackend, plotmethod) - except: - pass +# try: +# self.plotfun = getattr(plotbackend, plotmethod) +# except: +# pass def show(self): plotbackend.show() def plot(self, wdata, *args, **kwds): + axis = kwds.pop('axis',None) + if axis is None: + axis = plotbackend.gca() plotflag = kwds.pop('plotflag', False) if plotflag: - h1 = self._plot(plotflag, wdata, *args, **kwds) + h1 = self._plot(axis, plotflag, wdata, *args, **kwds) else: + if isinstance(wdata.data, (list, tuple)): + vals = tuple(wdata.data) + else: + vals = (wdata.data,) if isinstance(wdata.args, (list, tuple)): - args1 = tuple((wdata.args)) + (wdata.data,) + args + args1 = tuple((wdata.args)) + vals + args else: - args1 = tuple((wdata.args,)) + (wdata.data,) + args - h1 = self.plotfun(*args1, **kwds) - h2 = wdata.labels.labelfig() + args1 = tuple((wdata.args,)) + vals + args + plotfun = getattr(axis, self.plotmethod) + h1 = plotfun(*args1, **kwds) + h2 = wdata.labels.labelfig(axis) return h1, h2 - def _plot(self, plotflag, wdata, *args, **kwds): + def _plot(self, axis, plotflag, wdata, *args, **kwds): x = wdata.args data = transformdata(x, wdata.data, plotflag) dataCI = getattr(wdata, 'dataCI', ()) - h1 = plot1d(x, data, dataCI, plotflag, *args, **kwds) + h1 = plot1d(axis, x, data, dataCI, plotflag, *args, **kwds) return h1 -def plot1d(args, data, dataCI, plotflag, *varargin, **kwds): +def plot1d(axis, args, data, dataCI, plotflag, *varargin, **kwds): plottype = np.mod(plotflag, 10) if plottype == 0: # % No plotting return [] elif plottype == 1: - H = plotbackend.plot(args, data, *varargin, **kwds) + H = axis.plot(args, data, *varargin, **kwds) elif plottype == 2: - H = plotbackend.step(args, data, *varargin, **kwds) + H = axis.step(args, data, *varargin, **kwds) elif plottype == 3: - H = plotbackend.stem(args, data, *varargin, **kwds) + H = axis.stem(args, data, *varargin, **kwds) elif plottype == 4: - H = plotbackend.errorbar(args, data, dataCI[:,0] - data, dataCI[:,1] - data, *varargin, **kwds) + H = axis.errorbar(args, data, yerr=[dataCI[:,0] - data, dataCI[:,1] - data], *varargin, **kwds) elif plottype == 5: - H = plotbackend.bar(args, data, *varargin, **kwds) + H = axis.bar(args, data, *varargin, **kwds) elif plottype == 6: level = 0 if np.isfinite(level): - H = plotbackend.fill_between(args, data, level, *varargin, **kwds); + H = axis.fill_between(args, data, level, *varargin, **kwds); else: - H = plotbackend.fill_between(args, data, *varargin, **kwds); + H = axis.fill_between(args, data, *varargin, **kwds); elif plottype==7: - H = plotbackend.plot(args, data, *varargin, **kwds) - H = plotbackend.fill_between(args, dataCI[:,0], dataCI[:,1], alpha=0.2, color='r'); + H = axis.plot(args, data, *varargin, **kwds) + H = axis.fill_between(args, dataCI[:,0], dataCI[:,1], alpha=0.2, color='r'); scale = plotscale(plotflag) logXscale = 'x' in scale logYscale = 'y' in scale logZscale = 'z' in scale - ax = plotbackend.gca() + if logXscale: - plotbackend.setp(ax, xscale='log') + axis.set(xscale='log') if logYscale: - plotbackend.setp(ax, yscale='log') + axis.set(yscale='log') if logZscale: - plotbackend.setp(ax, zscale='log') + axis.set(zscale='log') transFlag = np.mod(plotflag // 10, 10) logScale = logXscale or logYscale or logZscale if logScale or (transFlag == 5 and not logScale): - ax = list(plotbackend.axis()) + ax = list(axis.axis()) fmax1 = data.max() if transFlag == 5 and not logScale: ax[3] = 11 * np.log10(fmax1) @@ -254,11 +270,11 @@ def plot1d(args, data, dataCI, plotflag, *varargin, **kwds): ax[3] = 1.15 * fmax1; ax[2] = ax[3] * 1e-4; - plotbackend.axis(ax) + axis.axis(ax) if np.any(dataCI) and plottype < 3: - plotbackend.hold('on') - plot1d(args, dataCI, (), plotflag, 'r--'); + axis.hold(True) + plot1d(axis, args, dataCI, (), plotflag, 'r--'); return H def plotscale(plotflag): @@ -287,12 +303,27 @@ def plotscale(plotflag): 'linear', 'xlog', 'ylog', 'xylog', 'zlog', 'xzlog', 'yzlog', 'xyzlog' - Example - plotscale(100) % xlog - plotscale(200) % xlog - plotscale(1000) % ylog + Example + >>> for id in range(100,701,100): + ... plotscale(id) + 'xlog' + 'ylog' + 'xylog' + 'zlog' + 'xzlog' + 'yzlog' + 'xyzlog' + + >>> plotscale(200) + 'ylog' + >>> plotscale(300) + 'xylog' + >>> plotscale(300) + 'xylog' - See also plotscale + See also + -------- + transformdata ''' scaleId = plotflag // 100 if scaleId > 7: @@ -341,11 +372,11 @@ class Plotter_2d(Plotter_1d): plotmethod = 'contour' super(Plotter_2d, self).__init__(plotmethod) - def _plot(self, plotflag, wdata, *args, **kwds): - h1 = plot2d(wdata, plotflag, *args, **kwds) + def _plot(self, axis, plotflag, wdata, *args, **kwds): + h1 = plot2d(axis, wdata, plotflag, *args, **kwds) return h1 -def plot2d(wdata, plotflag, *args, **kwds): +def plot2d(axis, wdata, plotflag, *args, **kwds): f = wdata if isinstance(wdata.args, (list, tuple)): args1 = tuple((wdata.args)) + (wdata.data,) + args @@ -365,7 +396,7 @@ def plot2d(wdata, plotflag, *args, **kwds): clvec = np.sort(CL) if plotflag in [1, 8, 9]: - h = plotbackend.contour(*args1, levels=CL, **kwds); + h = axis.contour(*args1, levels=CL, **kwds); #else: # [cs hcs] = contour3(f.x{:},f.f,CL,sym); @@ -378,20 +409,20 @@ def plot2d(wdata, plotflag, *args, **kwds): clvals = PL[:ncl] if isPL else clvec[:ncl] unused_axcl = cltext(clvals, percent=isPL) # print contour level text elif any(plotflag == [7, 9]): - plotbackend.clabel(h) + axis.clabel(h) else: - plotbackend.clabel(h) + axis.clabel(h) elif plotflag == 2: - h = plotbackend.mesh(*args1, **kwds) + h = axis.mesh(*args1, **kwds) elif plotflag == 3: - h = plotbackend.surf(*args1, **kwds) #shading interp % flat, faceted % surfc + h = axis.surf(*args1, **kwds) #shading interp % flat, faceted % surfc elif plotflag == 4: - h = plotbackend.waterfall(*args1, **kwds) + h = axis.waterfall(*args1, **kwds) elif plotflag == 5: - h = plotbackend.pcolor(*args1, **kwds) #%shading interp % flat, faceted + h = axis.pcolor(*args1, **kwds) #%shading interp % flat, faceted elif plotflag == 10: - h = plotbackend.contourf(*args1, **kwds) - plotbackend.clabel(h) + h = axis.contourf(*args1, **kwds) + axis.clabel(h) plotbackend.colorbar(h) else: raise ValueError('unknown option for plotflag')