Simplified PlotData in containers.py

master
pbrod 8 years ago
parent 2a492b4cf1
commit a669e556b7

@ -7,6 +7,7 @@ import numpy as np
from scipy.integrate.quadrature import cumtrapz # @UnresolvedImport
from scipy import interpolate
from scipy import integrate
from _warnings import warn
__all__ = ['PlotData', 'AxisLabels']
@ -161,6 +162,24 @@ class PlotData(object):
cdf = np.hstack((0, cumtrapz(self.data, self.args)))
return PlotData(cdf, np.copy(self.args), xlab='x', ylab='F(x)')
def _get_fi_xi(self, a, b):
x = self.args
if a is None:
a = x[0]
if b is None:
b = x[-1]
ix = np.flatnonzero((a < x) & (x < b))
xi = np.hstack((a, x.take(ix), b))
if self.data.ndim > 1:
fi = np.vstack((self.eval_points(a),
self.data[ix, :],
self.eval_points(b))).T
else:
fi = np.hstack((self.eval_points(a), self.data.take(ix),
self.eval_points(b)))
return fi, xi
def integrate(self, a=None, b=None, **kwds):
'''
>>> x = np.linspace(0,5,60)
@ -180,20 +199,7 @@ class PlotData(object):
raise NotImplementedError('integration for ndim>1 not implemented')
# One dimensional data
return_ci = kwds.pop('return_ci', False)
x = self.args
if a is None:
a = x[0]
if b is None:
b = x[-1]
ix = np.flatnonzero((a < x) & (x < b))
xi = np.hstack((a, x.take(ix), b))
if self.data.ndim > 1:
fi = np.vstack((self.eval_points(a),
self.data[ix, :],
self.eval_points(b))).T
else:
fi = np.hstack((self.eval_points(a), self.data.take(ix),
self.eval_points(b)))
fi, xi = self._get_fi_xi(a, b)
res = fun(fi, xi, **kwds)
if return_ci:
res_ci = [child.integrate(a, b, method=method)
@ -201,13 +207,8 @@ class PlotData(object):
return np.hstack((res, np.ravel(res_ci)))
return res
def plot(self, *args, **kwds):
axis = kwds.pop('axis', None)
if axis is None:
axis = plt.gca()
def _plot_children(self, axis, plotflag, kwds):
tmp = None
default_plotflag = self.plot_kwds.get('plotflag')
plotflag = kwds.get('plotflag', default_plotflag)
if not plotflag and self.children is not None:
axis.hold('on')
tmp = []
@ -220,8 +221,18 @@ class PlotData(object):
tmp1 = child.plot(*child_args, **child_kwds)
if tmp1 is not None:
tmp.append(tmp1)
if len(tmp) == 0:
tmp = None
if tmp:
return tmp
return None
def plot(self, *args, **kwargs):
kwds = kwargs.copy()
axis = kwds.pop('axis', None)
if axis is None:
axis = plt.gca()
default_plotflag = self.plot_kwds.get('plotflag')
plotflag = kwds.get('plotflag', default_plotflag)
tmp = self._plot_children(axis, plotflag, kwds)
main_args = args if len(args) else tuple(self.plot_args)
main_kwds = dict(self.plot_kwds).copy()
main_kwds.update(kwds)
@ -279,23 +290,30 @@ class AxisLabels:
newcopy.__dict__.update(self.__dict__)
return newcopy
def _add_title_if_fun_is_set_title(self, txt, title0, fun):
if fun.startswith('set_title'):
if title0.lower().strip() != txt.lower().strip():
txt = title0 + '\n' + txt
return txt
def _labelfig(self, axis):
h = []
title0 = axis.get_title()
for fun, txt in zip(
('set_title', 'set_xlabel', 'set_ylabel', 'set_ylabel'),
(self.title, self.xlab, self.ylab, self.zlab)):
if txt:
txt = self._add_title_if_fun_is_set_title(txt, title0, fun)
h.append(getattr(axis, fun)(txt))
return h
def labelfig(self, axis=None):
if axis is None:
axis = plt.gca()
try:
h = []
for fun, txt in zip(
('set_title', 'set_xlabel', 'set_ylabel', 'set_ylabel'),
(self.title, self.xlab, self.ylab, self.zlab)):
if txt:
if fun.startswith('set_title'):
title0 = axis.get_title()
if title0.lower().strip() != txt.lower().strip():
txt = title0 + '\n' + txt
h.append(getattr(axis, fun)(txt))
return h
except:
pass
return self._labelfig(axis)
except Exception as err:
warnings.warn(str(err))
class Plotter_1d(object):
@ -358,28 +376,29 @@ class Plotter_1d(object):
__call__ = plot
def set_axis(axis, f_max, trans_flag, log_scale):
ax = list(axis.axis())
if trans_flag == 8 and not log_scale:
ax[3] = 11 * np.log10(f_max)
ax[2] = ax[3] - 40
else:
ax[3] = 1.15 * f_max
ax[2] = ax[3] * 1e-4
axis.axis(ax)
def set_plot_scale(axis, f_max, plotflag):
scale = plotscale(plotflag)
log_x_scale = 'x' in scale
log_y_scale = 'y' in scale
log_z_scale = 'z' in scale
if log_x_scale:
axis.set(xscale='log')
if log_y_scale:
axis.set(yscale='log')
if log_z_scale:
axis.set(zscale='log')
log_scale = False
for dim in ['x', 'y', 'z']:
if dim in scale:
log_scale = True
opt = {'{}scale'.format(dim): 'log'}
axis.set(**opt)
trans_flag = np.mod(plotflag // 10, 10)
log_scale = log_x_scale or log_y_scale or log_z_scale
if log_scale or (trans_flag == 5 and not log_scale):
ax = list(axis.axis())
if trans_flag == 8 and not log_scale:
ax[3] = 11 * np.log10(f_max)
ax[2] = ax[3] - 40
else:
ax[3] = 1.15 * f_max
ax[2] = ax[3] * 1e-4
axis.axis(ax)
set_axis(axis, f_max, trans_flag, log_scale)
def plot1d(axis, args, data, dataCI, plotflag, *varargin, **kwds):
@ -536,30 +555,42 @@ class Plotter_2d(Plotter_1d):
return h1
def _get_contour_levels(f):
isPL = False
PL = None
# check if contour levels is submitted
if hasattr(f, 'clevels') and len(f.clevels) > 0:
CL = f.clevels
isPL = hasattr(f, 'plevels') and f.plevels is not None
if isPL:
PL = f.plevels # levels defines quantile levels? 0=no 1=yes
else:
dmax = np.max(f.data)
dmin = np.min(f.data)
CL = dmax - (dmax - dmin) * \
(1 - np.r_[0.01, 0.025, 0.05, 0.1, 0.2, 0.4, 0.5, 0.75])
clvec = np.sort(CL)
return clvec, PL
def plot2d(axis, wdata, plotflag, *args, **kwds):
f = wdata
if isinstance(wdata.args, (list, tuple)):
args1 = tuple((wdata.args)) + (wdata.data,) + args
else:
args1 = tuple((wdata.args,)) + (wdata.data,) + args
pltfun = [None, axis.contour, axis.mesh, axis.surf, axis.waterfal,
axis.pcolor, axis.contour, axis.contour, axis.contour,
axis.contour, axis.contourf][plotflag]
if plotflag in (1, 6, 7, 8, 9):
isPL = False
# check if contour levels is submitted
if hasattr(f, 'clevels') and len(f.clevels) > 0:
CL = f.clevels
isPL = hasattr(f, 'plevels') and f.plevels is not None
if isPL:
PL = f.plevels # levels defines quantile levels? 0=no 1=yes
else:
dmax = np.max(f.data)
dmin = np.min(f.data)
CL = dmax - (dmax - dmin) * \
(1 - np.r_[0.01, 0.025, 0.05, 0.1, 0.2, 0.4, 0.5, 0.75])
clvec = np.sort(CL)
clvec, PL = _get_contour_levels(f)
if plotflag in [1, 8, 9]:
h = axis.contour(*args1, levels=clvec, **kwds)
# else:
h = pltfun(*args1, levels=clvec, **kwds)
else:
h = pltfun(*args1, **kwds)
# [cs hcs] = contour3(f.x{:},f.f,CL,sym);
if plotflag in (1, 6):
@ -569,31 +600,16 @@ def plot2d(axis, wdata, plotflag, *args, **kwds):
warnings.warn(
'Only the first 12 levels will be listed in table.')
isPL = PL is not None
clvals = PL[:ncl] if isPL else clvec[:ncl]
unused_axcl = cltext(clvals, percent=isPL)
elif any(plotflag == [7, 9]):
axis.clabel(h)
else:
axis.clabel(h)
elif plotflag == 2:
h = axis.mesh(*args1, **kwds)
elif plotflag == 3:
# shading interp % flat, faceted % surfc
h = axis.surf(*args1, **kwds)
elif plotflag == 4:
h = axis.waterfall(*args1, **kwds)
elif plotflag == 5:
h = axis.pcolor(*args1, **kwds) # %shading interp % flat, faceted
elif plotflag == 10:
h = axis.contourf(*args1, **kwds)
axis.clabel(h)
plt.colorbar(h)
else:
raise ValueError('unknown option for plotflag')
# if any(plotflag==(2:5))
# shading(shad);
# end
# pass
h = pltfun(*args1, **kwds)
if plotflag == 10:
axis.clabel(h)
plt.colorbar(h)
def test_plotdata():
@ -613,12 +629,7 @@ def test_plotdata():
d.show('hold')
def test_docstrings():
import doctest
print('Testing docstrings in %s' % __file__)
doctest.testmod(optionflags=doctest.NORMALIZE_WHITESPACE)
if __name__ == '__main__':
test_docstrings()
from wafo.testing import test_docstrings
test_docstrings(__file__)
# test_plotdata()

Loading…
Cancel
Save