Simplified PlotData in containers.py

master
pbrod 8 years ago
parent 2a492b4cf1
commit a669e556b7

@ -7,6 +7,7 @@ import numpy as np
from scipy.integrate.quadrature import cumtrapz # @UnresolvedImport from scipy.integrate.quadrature import cumtrapz # @UnresolvedImport
from scipy import interpolate from scipy import interpolate
from scipy import integrate from scipy import integrate
from _warnings import warn
__all__ = ['PlotData', 'AxisLabels'] __all__ = ['PlotData', 'AxisLabels']
@ -161,6 +162,24 @@ class PlotData(object):
cdf = np.hstack((0, cumtrapz(self.data, self.args))) cdf = np.hstack((0, cumtrapz(self.data, self.args)))
return PlotData(cdf, np.copy(self.args), xlab='x', ylab='F(x)') return PlotData(cdf, np.copy(self.args), xlab='x', ylab='F(x)')
def _get_fi_xi(self, a, b):
x = self.args
if a is None:
a = x[0]
if b is None:
b = x[-1]
ix = np.flatnonzero((a < x) & (x < b))
xi = np.hstack((a, x.take(ix), b))
if self.data.ndim > 1:
fi = np.vstack((self.eval_points(a),
self.data[ix, :],
self.eval_points(b))).T
else:
fi = np.hstack((self.eval_points(a), self.data.take(ix),
self.eval_points(b)))
return fi, xi
def integrate(self, a=None, b=None, **kwds): def integrate(self, a=None, b=None, **kwds):
''' '''
>>> x = np.linspace(0,5,60) >>> x = np.linspace(0,5,60)
@ -180,20 +199,7 @@ class PlotData(object):
raise NotImplementedError('integration for ndim>1 not implemented') raise NotImplementedError('integration for ndim>1 not implemented')
# One dimensional data # One dimensional data
return_ci = kwds.pop('return_ci', False) return_ci = kwds.pop('return_ci', False)
x = self.args fi, xi = self._get_fi_xi(a, b)
if a is None:
a = x[0]
if b is None:
b = x[-1]
ix = np.flatnonzero((a < x) & (x < b))
xi = np.hstack((a, x.take(ix), b))
if self.data.ndim > 1:
fi = np.vstack((self.eval_points(a),
self.data[ix, :],
self.eval_points(b))).T
else:
fi = np.hstack((self.eval_points(a), self.data.take(ix),
self.eval_points(b)))
res = fun(fi, xi, **kwds) res = fun(fi, xi, **kwds)
if return_ci: if return_ci:
res_ci = [child.integrate(a, b, method=method) res_ci = [child.integrate(a, b, method=method)
@ -201,13 +207,8 @@ class PlotData(object):
return np.hstack((res, np.ravel(res_ci))) return np.hstack((res, np.ravel(res_ci)))
return res return res
def plot(self, *args, **kwds): def _plot_children(self, axis, plotflag, kwds):
axis = kwds.pop('axis', None)
if axis is None:
axis = plt.gca()
tmp = None tmp = None
default_plotflag = self.plot_kwds.get('plotflag')
plotflag = kwds.get('plotflag', default_plotflag)
if not plotflag and self.children is not None: if not plotflag and self.children is not None:
axis.hold('on') axis.hold('on')
tmp = [] tmp = []
@ -220,8 +221,18 @@ class PlotData(object):
tmp1 = child.plot(*child_args, **child_kwds) tmp1 = child.plot(*child_args, **child_kwds)
if tmp1 is not None: if tmp1 is not None:
tmp.append(tmp1) tmp.append(tmp1)
if len(tmp) == 0: if tmp:
tmp = None return tmp
return None
def plot(self, *args, **kwargs):
kwds = kwargs.copy()
axis = kwds.pop('axis', None)
if axis is None:
axis = plt.gca()
default_plotflag = self.plot_kwds.get('plotflag')
plotflag = kwds.get('plotflag', default_plotflag)
tmp = self._plot_children(axis, plotflag, kwds)
main_args = args if len(args) else tuple(self.plot_args) main_args = args if len(args) else tuple(self.plot_args)
main_kwds = dict(self.plot_kwds).copy() main_kwds = dict(self.plot_kwds).copy()
main_kwds.update(kwds) main_kwds.update(kwds)
@ -279,23 +290,30 @@ class AxisLabels:
newcopy.__dict__.update(self.__dict__) newcopy.__dict__.update(self.__dict__)
return newcopy return newcopy
def _add_title_if_fun_is_set_title(self, txt, title0, fun):
if fun.startswith('set_title'):
if title0.lower().strip() != txt.lower().strip():
txt = title0 + '\n' + txt
return txt
def _labelfig(self, axis):
h = []
title0 = axis.get_title()
for fun, txt in zip(
('set_title', 'set_xlabel', 'set_ylabel', 'set_ylabel'),
(self.title, self.xlab, self.ylab, self.zlab)):
if txt:
txt = self._add_title_if_fun_is_set_title(txt, title0, fun)
h.append(getattr(axis, fun)(txt))
return h
def labelfig(self, axis=None): def labelfig(self, axis=None):
if axis is None: if axis is None:
axis = plt.gca() axis = plt.gca()
try: try:
h = [] return self._labelfig(axis)
for fun, txt in zip( except Exception as err:
('set_title', 'set_xlabel', 'set_ylabel', 'set_ylabel'), warnings.warn(str(err))
(self.title, self.xlab, self.ylab, self.zlab)):
if txt:
if fun.startswith('set_title'):
title0 = axis.get_title()
if title0.lower().strip() != txt.lower().strip():
txt = title0 + '\n' + txt
h.append(getattr(axis, fun)(txt))
return h
except:
pass
class Plotter_1d(object): class Plotter_1d(object):
@ -358,28 +376,29 @@ class Plotter_1d(object):
__call__ = plot __call__ = plot
def set_axis(axis, f_max, trans_flag, log_scale):
ax = list(axis.axis())
if trans_flag == 8 and not log_scale:
ax[3] = 11 * np.log10(f_max)
ax[2] = ax[3] - 40
else:
ax[3] = 1.15 * f_max
ax[2] = ax[3] * 1e-4
axis.axis(ax)
def set_plot_scale(axis, f_max, plotflag): def set_plot_scale(axis, f_max, plotflag):
scale = plotscale(plotflag) scale = plotscale(plotflag)
log_x_scale = 'x' in scale log_scale = False
log_y_scale = 'y' in scale for dim in ['x', 'y', 'z']:
log_z_scale = 'z' in scale if dim in scale:
if log_x_scale: log_scale = True
axis.set(xscale='log') opt = {'{}scale'.format(dim): 'log'}
if log_y_scale: axis.set(**opt)
axis.set(yscale='log')
if log_z_scale:
axis.set(zscale='log')
trans_flag = np.mod(plotflag // 10, 10) trans_flag = np.mod(plotflag // 10, 10)
log_scale = log_x_scale or log_y_scale or log_z_scale
if log_scale or (trans_flag == 5 and not log_scale): if log_scale or (trans_flag == 5 and not log_scale):
ax = list(axis.axis()) set_axis(axis, f_max, trans_flag, log_scale)
if trans_flag == 8 and not log_scale:
ax[3] = 11 * np.log10(f_max)
ax[2] = ax[3] - 40
else:
ax[3] = 1.15 * f_max
ax[2] = ax[3] * 1e-4
axis.axis(ax)
def plot1d(axis, args, data, dataCI, plotflag, *varargin, **kwds): def plot1d(axis, args, data, dataCI, plotflag, *varargin, **kwds):
@ -536,30 +555,42 @@ class Plotter_2d(Plotter_1d):
return h1 return h1
def _get_contour_levels(f):
isPL = False
PL = None
# check if contour levels is submitted
if hasattr(f, 'clevels') and len(f.clevels) > 0:
CL = f.clevels
isPL = hasattr(f, 'plevels') and f.plevels is not None
if isPL:
PL = f.plevels # levels defines quantile levels? 0=no 1=yes
else:
dmax = np.max(f.data)
dmin = np.min(f.data)
CL = dmax - (dmax - dmin) * \
(1 - np.r_[0.01, 0.025, 0.05, 0.1, 0.2, 0.4, 0.5, 0.75])
clvec = np.sort(CL)
return clvec, PL
def plot2d(axis, wdata, plotflag, *args, **kwds): def plot2d(axis, wdata, plotflag, *args, **kwds):
f = wdata f = wdata
if isinstance(wdata.args, (list, tuple)): if isinstance(wdata.args, (list, tuple)):
args1 = tuple((wdata.args)) + (wdata.data,) + args args1 = tuple((wdata.args)) + (wdata.data,) + args
else: else:
args1 = tuple((wdata.args,)) + (wdata.data,) + args args1 = tuple((wdata.args,)) + (wdata.data,) + args
pltfun = [None, axis.contour, axis.mesh, axis.surf, axis.waterfal,
axis.pcolor, axis.contour, axis.contour, axis.contour,
axis.contour, axis.contourf][plotflag]
if plotflag in (1, 6, 7, 8, 9): if plotflag in (1, 6, 7, 8, 9):
isPL = False clvec, PL = _get_contour_levels(f)
# check if contour levels is submitted
if hasattr(f, 'clevels') and len(f.clevels) > 0:
CL = f.clevels
isPL = hasattr(f, 'plevels') and f.plevels is not None
if isPL:
PL = f.plevels # levels defines quantile levels? 0=no 1=yes
else:
dmax = np.max(f.data)
dmin = np.min(f.data)
CL = dmax - (dmax - dmin) * \
(1 - np.r_[0.01, 0.025, 0.05, 0.1, 0.2, 0.4, 0.5, 0.75])
clvec = np.sort(CL)
if plotflag in [1, 8, 9]: if plotflag in [1, 8, 9]:
h = axis.contour(*args1, levels=clvec, **kwds) h = pltfun(*args1, levels=clvec, **kwds)
# else: else:
h = pltfun(*args1, **kwds)
# [cs hcs] = contour3(f.x{:},f.f,CL,sym); # [cs hcs] = contour3(f.x{:},f.f,CL,sym);
if plotflag in (1, 6): if plotflag in (1, 6):
@ -569,31 +600,16 @@ def plot2d(axis, wdata, plotflag, *args, **kwds):
warnings.warn( warnings.warn(
'Only the first 12 levels will be listed in table.') 'Only the first 12 levels will be listed in table.')
isPL = PL is not None
clvals = PL[:ncl] if isPL else clvec[:ncl] clvals = PL[:ncl] if isPL else clvec[:ncl]
unused_axcl = cltext(clvals, percent=isPL) unused_axcl = cltext(clvals, percent=isPL)
elif any(plotflag == [7, 9]):
axis.clabel(h)
else: else:
axis.clabel(h) axis.clabel(h)
elif plotflag == 2:
h = axis.mesh(*args1, **kwds)
elif plotflag == 3:
# shading interp % flat, faceted % surfc
h = axis.surf(*args1, **kwds)
elif plotflag == 4:
h = axis.waterfall(*args1, **kwds)
elif plotflag == 5:
h = axis.pcolor(*args1, **kwds) # %shading interp % flat, faceted
elif plotflag == 10:
h = axis.contourf(*args1, **kwds)
axis.clabel(h)
plt.colorbar(h)
else: else:
raise ValueError('unknown option for plotflag') h = pltfun(*args1, **kwds)
# if any(plotflag==(2:5)) if plotflag == 10:
# shading(shad); axis.clabel(h)
# end plt.colorbar(h)
# pass
def test_plotdata(): def test_plotdata():
@ -613,12 +629,7 @@ def test_plotdata():
d.show('hold') d.show('hold')
def test_docstrings():
import doctest
print('Testing docstrings in %s' % __file__)
doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
if __name__ == '__main__': if __name__ == '__main__':
test_docstrings() from wafo.testing import test_docstrings
test_docstrings(__file__)
# test_plotdata() # test_plotdata()

Loading…
Cancel
Save