You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pywafo/wafo/containers.py

625 lines
21 KiB
Python

from __future__ import absolute_import
12 years ago
import warnings
9 years ago
from wafo.graphutil import cltext
from wafo.plotbackend import plotbackend as plt
12 years ago
from time import gmtime, strftime
import numpy as np
from scipy.integrate.quadrature import cumtrapz # @UnresolvedImport
12 years ago
from scipy import interpolate
from scipy import integrate
# Public names exported by ``from wafo.containers import *``.
__all__ = ["PlotData", "AxisLabels"]
12 years ago
def empty_copy(obj):
    '''
    Return a new instance of type(obj) created without calling its __init__.

    A throw-away subclass with a no-argument __init__ is instantiated and
    then re-labelled as the original class, so no instance attributes are
    set by the original constructor.
    '''
    cls = obj.__class__

    class _Blank(cls):

        def __init__(self):
            # Deliberately bypass cls.__init__ and its required arguments.
            pass

    instance = _Blank()
    instance.__class__ = cls
    return instance
12 years ago
def now():
    '''
    Return the current UTC date and time as a string,
    formatted like 'Mon, 01 Jan 2024 12:00:00'.
    '''
    timestamp_format = "%a, %d %b %Y %H:%M:%S"
    utc_struct = gmtime()
    return strftime(timestamp_format, utc_struct)
12 years ago
class PlotData(object):

    '''
    Container class for data with interpolation and plotting methods

    Member variables
    ----------------
    data : array_like
    args : vector for 1D, list of vectors for 2D, 3D, ...
    labels : AxisLabels
    children : list of PlotData objects
    plot_args_children : list of arguments to the children plots
    plot_kwds_children : dict of keyword arguments to the children plots
    plot_args : list of arguments to the main plot
    plot_kwds : dict of keyword arguments to the main plot

    Member methods
    --------------
    copy : return a copy of object
    eval_points : interpolate data at given points and return the result
    plot : plot data on given axis and the object handles

    Example
    -------
    >>> import numpy as np
    >>> x = np.linspace(0, np.pi, 9)

    # Plot the data; calling the object is an alias for plot
    >>> d2 = PlotData(np.sin(x), x, xlab='x', ylab='sin', title='sinus',
    ...               plot_args=['r.'])
    >>> h = d2.plot()
    >>> h1 = d2()

    # Plot with confidence interval
    >>> ci = PlotData(np.vstack([np.sin(x)*0.9, np.sin(x)*1.2]).T, x,
    ...               plot_args=[':r'])
    >>> d3 = PlotData(np.sin(x), x, children=[ci])
    >>> h = d3.plot()  # plot data, CI red dotted line
    >>> h = d3.plot(plot_args_children=['b--'])  # CI with blue dashed line

    '''

    def __init__(self, data=None, args=None, *args2, **kwds):
        # NOTE: extra positional arguments (args2) are accepted but ignored.
        self.data = data
        self.args = args
        self.date = now()
        self.plotter = kwds.pop('plotter', None)
        self.children = kwds.pop('children', None)

        self.plot_args_children = kwds.pop('plot_args_children', [])
        self.plot_kwds_children = kwds.pop('plot_kwds_children', {})
        self.plot_args = kwds.pop('plot_args', [])
        self.plot_kwds = kwds.pop('plot_kwds', {})

        # The remaining keywords (title, xlab, ylab, zlab, ...) go to the
        # labels object.
        self.labels = AxisLabels(**kwds)
        if not self.plotter:
            self.setplotter(kwds.get('plotmethod'))

    def copy(self):
        '''Return a shallow copy (attribute values are shared).'''
        newcopy = empty_copy(self)
        newcopy.__dict__.update(self.__dict__)
        return newcopy

    def eval_points(self, *points, **kwds):
        '''
        Interpolate data at points

        Parameters
        ----------
        points : ndarray of float, shape (..., ndim)
            Points where to interpolate data at.
        method : {'linear', 'nearest', 'cubic'}
            Method of interpolation. One of
            - ``nearest``: return the value at the data point closest to
              the point of interpolation.

            - ``linear``: tesselate the input point set to n-dimensional
              simplices, and interpolate linearly on each simplex.

            - ``cubic`` (1-D): return the value determined from a cubic
              spline.
            - ``cubic`` (2-D): return the value determined from a
              piecewise cubic, continuously differentiable (C1), and
              approximately curvature-minimizing polynomial surface.
        fill_value : float, optional
            Value used to fill in for requested points outside of the
            convex hull of the input points. If not provided, then the
            default is ``nan``. This option has no effect for the
            'nearest' method.

        Examples
        --------
        >>> import numpy as np
        >>> x = np.arange(-2, 2, 0.4)
        >>> xi = np.arange(-2, 2, 0.1)
        >>> d = PlotData(np.sin(x), x, xlab='x', ylab='sin', title='sinus',
        ...              plot_args=['r.'])

        >>> di = PlotData(d.eval_points(xi), xi)
        >>> hi = di.plot()
        >>> h = d.plot()
        >>> dicdf = di.to_cdf()
        >>> h = dicdf.plot()

        See also
        --------
        scipy.interpolate.griddata
        '''
        options = dict(method='linear')
        options.update(**kwds)
        if isinstance(self.args, (list, tuple)):  # Multidimensional data
            ndim = len(self.args)
            if ndim < 2:
                msg = '''
                Unable to determine plotter-type, because len(self.args)<2.

                If the data is 1D, then self.args should be a vector!
                If the data is 2D, then length(self.args) should be 2.
                If the data is 3D, then length(self.args) should be 3.
                Unless you fix this, the interpolation will not work!'''
                warnings.warn(msg)
            else:
                # griddata wants the sample points either as an (n, ndim)
                # array or as a tuple of 1-D coordinate arrays, so flatten
                # the grid spanned by the axis vectors.
                # NOTE(review): assumes self.data is laid out like
                # np.meshgrid(*self.args) (default 'xy' indexing) - confirm.
                grid = np.meshgrid(*self.args)
                xi = tuple(g.ravel() for g in grid)
                return interpolate.griddata(xi, self.data.ravel(), points,
                                            **options)
        # One dimensional data
        return interpolate.griddata(self.args, self.data, points, **options)

    def to_cdf(self):
        '''
        Return the running trapezoidal integral of the data as a PlotData.

        The integral starts at 0 at the first abscissa, so the result has
        the same length as ``self.args``.

        Raises
        ------
        NotImplementedError
            If the data is multidimensional.
        '''
        if isinstance(self.args, (list, tuple)):  # Multidimensional data
            raise NotImplementedError('integration for ndim>1 not implemented')
        cdf = np.hstack((0, cumtrapz(self.data, self.args)))
        return PlotData(cdf, np.copy(self.args), xlab='x', ylab='F(x)')

    def integrate(self, a=None, b=None, **kwds):
        '''
        Integrate 1D data between a and b.

        Parameters
        ----------
        a, b : float, optional
            Integration limits; default to the first and last abscissa.
            The integrand at the limits is obtained with eval_points, so
            the limits need not coincide with a sample.
        method : str, optional
            Name of a ``scipy.integrate`` routine called as ``fun(y, x)``.
            Default 'trapz'; 'trapezoid' is used instead when the installed
            SciPy no longer provides 'trapz'.
        return_ci : bool, optional
            If True, the integrals of the children (e.g. confidence bands)
            are appended to the result.
        **kwds :
            Passed on to the integration routine.

        Returns
        -------
        res : float or ndarray

        Raises
        ------
        NotImplementedError
            If the data is multidimensional.

        Examples
        --------
        >>> x = np.linspace(0,5,60)
        >>> y = np.sin(x)
        >>> ci = PlotData(np.vstack((y*.9, y*1.1)).T, x)
        >>> d = PlotData(y, x, children=[ci])
        >>> d.integrate(0, np.pi/2, return_ci=True)
        array([0.99940055, 0.89946049, 1.0993406 ])
        >>> np.allclose(d.integrate(0, 5, return_ci=True),
        ...             d.integrate(return_ci=True))
        True
        '''
        method = kwds.pop('method', 'trapz')
        try:
            fun = getattr(integrate, method)
        except AttributeError:
            # scipy.integrate.trapz was renamed trapezoid and dropped from
            # recent SciPy releases.
            if method != 'trapz':
                raise
            fun = integrate.trapezoid
        if isinstance(self.args, (list, tuple)):  # Multidimensional data
            raise NotImplementedError('integration for ndim>1 not implemented')
        # One dimensional data
        return_ci = kwds.pop('return_ci', False)
        x = self.args
        if a is None:
            a = x[0]
        if b is None:
            b = x[-1]
        # Samples strictly inside (a, b); the limits are added explicitly.
        ix = np.flatnonzero((a < x) & (x < b))
        xi = np.hstack((a, x.take(ix), b))
        if self.data.ndim > 1:
            # One row per data column so each column is integrated along
            # the abscissa.
            fi = np.vstack((self.eval_points(a),
                            self.data[ix, :],
                            self.eval_points(b))).T
        else:
            fi = np.hstack((self.eval_points(a), self.data.take(ix),
                            self.eval_points(b)))
        res = fun(fi, xi, **kwds)
        if return_ci:
            # children is None when no confidence data was attached.
            res_ci = [child.integrate(a, b, method=method)
                      for child in (self.children or [])]
            return np.hstack((res, np.ravel(res_ci)))
        return res

    def plot(self, *args, **kwds):
        '''
        Plot the data, and its children when no plotflag is active.

        Parameters
        ----------
        *args :
            Positional arguments for the main plot (default self.plot_args).
        axis : matplotlib axes, optional
            Axis to draw on (default: the current axis).
        plot_args_children : sequence, optional
            Positional arguments for the children plots
            (default self.plot_args_children).
        plot_kwds_children : dict, optional
            Updates self.plot_kwds_children for the children plots.
        plotflag : int, optional
            Passed to the plotter; children are only drawn when it is falsy.
        **kwds :
            Other keywords update self.plot_kwds for the main plot.

        Returns
        -------
        (main_handles, child_handles)
            child_handles is None when no child handle was produced.
        '''
        axis = kwds.pop('axis', None)
        if axis is None:
            axis = plt.gca()

        # Always remove the children options so they never reach the main
        # plotter as unknown keyword arguments.
        child_args = kwds.pop('plot_args_children', None)
        child_kwds_update = kwds.pop('plot_kwds_children', {})

        tmp = None
        default_plotflag = self.plot_kwds.get('plotflag')
        plotflag = kwds.get('plotflag', default_plotflag)
        if not plotflag and self.children is not None:
            # Axes.hold was removed in matplotlib 3.0, where overplotting is
            # always on; only call it where it still exists.
            if hasattr(axis, 'hold'):
                axis.hold('on')
            tmp = []
            if child_args is None:
                child_args = tuple(self.plot_args_children)
            child_kwds = dict(self.plot_kwds_children).copy()
            child_kwds.update(child_kwds_update)
            child_kwds['axis'] = axis
            for child in self.children:
                tmp1 = child.plot(*child_args, **child_kwds)
                if tmp1 is not None:
                    tmp.append(tmp1)
            if len(tmp) == 0:
                tmp = None
        main_args = args if len(args) else tuple(self.plot_args)
        main_kwds = dict(self.plot_kwds).copy()
        main_kwds.update(kwds)
        main_kwds['axis'] = axis
        tmp2 = self.plotter.plot(self, *main_args, **main_kwds)
        return tmp2, tmp

    def setplotter(self, plotmethod=None):
        '''
        Set plotter based on the data type:
            data_1d, data_2d, data_3d or data_nd
        '''
        if isinstance(self.args, (list, tuple)):  # Multidimensional data
            ndim = len(self.args)
            if ndim < 2:
                msg = '''
                Unable to determine plotter-type, because len(self.args)<2.

                If the data is 1D, then self.args should be a vector!
                If the data is 2D, then length(self.args) should be 2.
                If the data is 3D, then length(self.args) should be 3.
                Unless you fix this, the plot methods will not work!'''
                warnings.warn(msg)
            elif ndim == 2:
                self.plotter = Plotter_2d(plotmethod)
            else:
                warnings.warn('Plotter method not implemented for ndim>2')
        else:  # One dimensional data
            self.plotter = Plotter_1d(plotmethod)

    def show(self, *args, **kwds):
        '''Show the figure(s) through the plotter.'''
        self.plotter.show(*args, **kwds)

    __call__ = plot
    interpolate = eval_points
12 years ago
class AxisLabels(object):

    '''
    Title and axis labels of a plot.

    Parameters
    ----------
    title, xlab, ylab, zlab : str
        Axis title and x-, y- and z-axis labels. Empty strings are skipped
        when labelling.
    **kwds :
        Ignored, so callers can pass a larger keyword set unchanged.
    '''

    def __init__(self, title='', xlab='', ylab='', zlab='', **kwds):
        self.title = title
        self.xlab = xlab
        self.ylab = ylab
        self.zlab = zlab

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        txt = 'AxisLabels(title={}, xlab={}, ylab={}, zlab={})'
        return txt.format(self.title, self.xlab, self.ylab, self.zlab)

    def copy(self):
        '''Return a shallow copy of the labels.'''
        newcopy = empty_copy(self)
        newcopy.__dict__.update(self.__dict__)
        return newcopy

    def labelfig(self, axis=None):
        '''
        Put the title and axis labels on an axis.

        A different, non-empty existing title is kept and the new title is
        appended on a new line. The z-label is only set when the axis has a
        ``set_zlabel`` method (2D axes have none).

        Parameters
        ----------
        axis : matplotlib axes, optional
            Axis to label (default: the current axis).

        Returns
        -------
        list of the values returned by the setter calls, or None if
        labelling failed (a warning is issued in that case).
        '''
        if axis is None:
            axis = plt.gca()

        setters = (('set_title', self.title),
                   ('set_xlabel', self.xlab),
                   ('set_ylabel', self.ylab),
                   ('set_zlabel', self.zlab))
        try:
            h = []
            for fun, txt in setters:
                if not txt:
                    continue
                setter = getattr(axis, fun, None)
                if setter is None:
                    # e.g. a 2D axis has no z-axis to label.
                    continue
                if fun == 'set_title':
                    title0 = axis.get_title()
                    # Keep an existing, different title above the new one.
                    if title0 and (title0.lower().strip() !=
                                   txt.lower().strip()):
                        txt = title0 + '\n' + txt
                h.append(setter(txt))
            return h
        except Exception as exc:  # labelling is best effort
            warnings.warn('Could not label axis: {}'.format(exc))
            return None
12 years ago
class Plotter_1d(object):

    """
    Plotter for 1D PlotData objects.

    Parameters
    ----------
    plotmethod : string
        defining type of plot. Options are:
        bar : bar plot with rectangles
        barh : horizontal bar plot with rectangles
        loglog : plot with log scaling on the *x* and *y* axis
        semilogx : plot with log scaling on the *x* axis
        semilogy : plot with log scaling on the *y* axis
        plot : Plot lines and/or markers (default)
        stem : Stem plot
        step : stair-step plot
        scatter : scatter plot
    """

    def __init__(self, plotmethod='plot'):
        self.plotfun = None
        if plotmethod is None:
            plotmethod = 'plot'
        self.plotmethod = plotmethod

    def show(self, *args, **kwds):
        '''Show figures via the plot backend.'''
        plt.show(*args, **kwds)

    def plot(self, wdata, *args, **kwds):
        '''
        Plot wdata on an axis and label it.

        Without a truthy ``plotflag`` keyword the axis method named by
        self.plotmethod is called with (args..., data..., *args); otherwise
        the plotflag-driven _plot is used.

        Returns
        -------
        (h1, h2) : handles of the plot call and of the labelling.
        '''
        axis = kwds.pop('axis', None)
        if axis is None:
            axis = plt.gca()

        plotflag = kwds.pop('plotflag', False)
        if plotflag:
            h1 = self._plot(axis, plotflag, wdata, *args, **kwds)
        else:
            if isinstance(wdata.data, (list, tuple)):
                vals = tuple(wdata.data)
            else:
                vals = (wdata.data,)
            if isinstance(wdata.args, (list, tuple)):
                args1 = tuple(wdata.args) + vals + args
            else:
                args1 = (wdata.args,) + vals + args
            plotfun = getattr(axis, self.plotmethod)
            h1 = plotfun(*args1, **kwds)
        h2 = wdata.labels.labelfig(axis)
        return h1, h2

    def _plot(self, axis, plotflag, wdata, *args, **kwds):
        '''Plot transformed data (and its CI) according to plotflag.'''
        x = wdata.args
        data = transformdata_1d(x, wdata.data, plotflag)
        dataCI = getattr(wdata, 'dataCI', ())
        # dataCI may be an ndarray, whose truth value is ambiguous, so test
        # for emptiness by size instead of "if dataCI:".
        if dataCI is not None and np.size(dataCI) > 0:
            dataCI = transformdata_1d(x, dataCI, plotflag)
        h1 = plot1d(axis, x, data, dataCI, plotflag, *args, **kwds)
        return h1

    __call__ = plot
8 years ago
def set_plot_scale(axis, f_max, plotflag):
    '''
    Apply the log scaling encoded in plotflag to axis and adjust its limits.

    The x/y/z scales named by plotscale(plotflag) are set to 'log'. When any
    axis is logarithmic, or the tens digit of plotflag is 5, the upper
    y-limit becomes 1.15 * f_max and the lower one 1e-4 times that.
    '''
    scale = plotscale(plotflag)
    use_log = {}
    for ax_name in 'xyz':
        use_log[ax_name] = ax_name in scale
        if use_log[ax_name]:
            axis.set(**{ax_name + 'scale': 'log'})
    any_log = use_log['x'] or use_log['y'] or use_log['z']
    transform_digit = np.mod(plotflag // 10, 10)
    if not (any_log or (transform_digit == 5 and not any_log)):
        return
    limits = list(axis.axis())
    if transform_digit == 8 and not any_log:
        # NOTE(review): unreachable - non-log axes only get here when the
        # transform digit is 5. Presumably intended for the 10*log10 (dB)
        # transform (digit 8); confirm before relying on it.
        limits[3] = 11 * np.log10(f_max)
        limits[2] = limits[3] - 40
    else:
        limits[3] = 1.15 * f_max
        limits[2] = limits[3] * 1e-4
    axis.axis(limits)
8 years ago
def plot1d(axis, args, data, dataCI, plotflag, *varargin, **kwds):
    '''
    Draw 1D data on axis according to the last digit of plotflag.

    plot type (plotflag % 10):
        0 no plot, 1 plot, 2 step, 3 stem, 4 errorbar, 5 bar,
        6 area between data and 0, 7 line plus shaded band between the two
        columns of dataCI.
    After drawing, set_plot_scale adjusts the axis scaling.

    Parameters
    ----------
    axis : matplotlib axes
    args : array_like
        Abscissa values.
    data : ndarray
        Ordinate values.
    dataCI : array_like
        Lower/upper bounds in columns 0 and 1 (used by types 1, 2, 4, 7);
        may be empty. For type 4 the lower bound is assumed <= data.
    plotflag : int
    *varargin, **kwds :
        Passed on to the matplotlib call. For errorbar only the first
        positional argument is used, as the ``fmt`` format string.

    Returns
    -------
    h : list (or matplotlib container) of the created artists.
    '''
    h = []
    plottype = np.mod(plotflag, 10)
    if plottype == 0:  # No plotting
        return h
    fun = {1: 'plot', 2: 'step', 3: 'stem', 5: 'bar'}.get(plottype)
    if fun is not None:
        plotfun = getattr(axis, fun)
        h.extend(plotfun(args, data, *varargin, **kwds))
        if np.any(dataCI) and plottype < 3:
            # Axes.hold was removed in matplotlib 3.0 (overplotting is
            # always on there).
            if hasattr(axis, 'hold'):
                axis.hold(True)
            h.extend(plotfun(args, dataCI, 'r--'))
    elif plottype == 4:
        err_kwds = dict(kwds)
        if varargin:
            # errorbar's third positional parameter is yerr, so a format
            # string has to be passed as fmt.
            err_kwds.setdefault('fmt', varargin[0])
        # errorbar expects non-negative distances below and above data.
        h = axis.errorbar(args, data,
                          yerr=[data - dataCI[:, 0],
                                dataCI[:, 1] - data],
                          **err_kwds)
    elif plottype == 6:
        h = axis.fill_between(args, data, 0, *varargin, **kwds)
    elif plottype == 7:
        h = axis.plot(args, data, *varargin, **kwds)
        # fill_between returns a single collection, not a sequence.
        h.append(axis.fill_between(args, dataCI[:, 0], dataCI[:, 1],
                                   alpha=0.2, color='r'))
    fmax1 = data.max()
    set_plot_scale(axis, fmax1, plotflag)
    return h
12 years ago
12 years ago
def plotscale(plotflag):
    '''
    Return the axis-scale name encoded in plotflag.

    CALL scale = plotscale(plotflag)

    plotflag = integer; the scale code is scaleId = floor(plotflag/100).

    If scaleId < 8 it is used directly:
        0 'linear' : Linear scale on all axes.
        1 'xlog'   : Log scale on x-axis.
        2 'ylog'   : Log scale on y-axis.
        3 'xylog'  : Log scale on xy-axis.
        4 'zlog'   : Log scale on z-axis.
        5 'xzlog'  : Log scale on xz-axis.
        6 'yzlog'  : Log scale on yz-axis.
        7 'xyzlog' : Log scale on xyz-axis.
    otherwise its decimal digits act as flags:
        if (mod(scaleId,10)>0)             : Log scale on x-axis.
        if (mod(floor(scaleId/10),10)>0)   : Log scale on y-axis.
        if (mod(floor(scaleId/100),10)>0)  : Log scale on z-axis.

    scale = one of 'linear', 'xlog', 'ylog', 'xylog', 'zlog', 'xzlog',
            'yzlog', 'xyzlog'

    Examples
    --------
    >>> for i in range(7):
    ...     plotscale(i*100)
    'linear'
    'xlog'
    'ylog'
    'xylog'
    'zlog'
    'xzlog'
    'yzlog'
    >>> plotscale(100)
    'xlog'
    >>> plotscale(1000)
    'ylog'
    >>> plotscale(10000)
    'zlog'
    >>> plotscale(1100)
    'xylog'
    >>> plotscale(11100)
    'xyzlog'

    See also
    --------
    set_plot_scale
    '''
    names = ('linear', 'xlog', 'ylog', 'xylog',
             'zlog', 'xzlog', 'yzlog', 'xyzlog')
    code = plotflag // 100
    if code <= 7:
        return names[code]
    # Units digit -> x (weight 1), tens -> y (weight 2), hundreds -> z (4).
    bits = 0
    for weight, place in ((1, 1), (2, 10), (4, 100)):
        if np.mod(code // place, 10) > 0:
            bits += weight
    return names[bits]
8 years ago
def plotflag2plottype_1d(plotflag):
    '''
    Return the 1D plot-type name encoded in the last digit of plotflag.

    Only digits 0-5 are named here; a larger digit raises IndexError.
    '''
    names = ('', 'plot', 'step', 'stem', 'errorbar', 'bar')
    digit = np.mod(plotflag, 10)
    return names[digit]
def plotflag2transform_id(plotflag):
    '''
    Return the name of the data transform encoded in plotflag.

    The tens digit of plotflag selects the transform. The names (and their
    order) match transform_id2plotflag2 and the transforms applied by
    transformdata_1d. Digit 9 is undefined and raises IndexError.
    '''
    transform_id = np.mod(plotflag // 10, 10)
    return ['f', '1-f',
            'cumtrapz(f)', '1-cumtrapz(f)',
            'log(f)', 'log(1-f)',
            'log(cumtrapz(f))', 'log(1-cumtrapz(f))',
            '10log10(f)'][transform_id]
def transform_id2plotflag2(transform_id):
    '''
    Return the plotflag contribution (tens digit * 10) for a transform name.

    '', 'None' and 'f' all map to 0. An unknown name raises KeyError.
    '''
    ordered = ('f', '1-f', 'cumtrapz(f)', '1-cumtrapz(f)',
               'log(f)', 'log(1-f)',
               'log(cumtrapz(f))', 'log(1-cumtrapz(f))',
               '10log10(f)')
    lookup = dict((name, 10 * pos) for pos, name in enumerate(ordered))
    lookup[''] = 0
    lookup['None'] = 0
    return lookup[transform_id]
def transformdata_1d(x, f, plotflag):
    '''
    Transform 1D data according to the tens digit of plotflag.

    transform id (tens digit):
        0 f, 1 1-f, 2 cumtrapz(f), 3 1-cumtrapz(f), 4 log(f), 5 log(1-f),
        6 log(cumtrapz(f)), 7 log(1-cumtrapz(f)), 8 10*log10(f)

    The cumulative integrals run along the first axis, so a 2-column array
    such as a confidence band is integrated column by column. They start
    at 0, so the output has the same shape as f and can be plotted against
    x.
    '''
    transform_id = np.mod(plotflag // 10, 10)

    def cum(f, x):
        # initial=0 keeps len(result) == len(x).
        return cumtrapz(f, x, axis=0, initial=0)

    transform = [lambda f, x: f,
                 lambda f, x: 1 - f,
                 lambda f, x: cum(f, x),
                 lambda f, x: 1 - cum(f, x),
                 lambda f, x: np.log(f),
                 lambda f, x: np.log1p(-f),
                 lambda f, x: np.log(cum(f, x)),
                 lambda f, x: np.log1p(-cum(f, x)),
                 lambda f, x: 10 * np.log10(f)
                 ][transform_id]
    return transform(f, x)
12 years ago
12 years ago
class Plotter_2d(Plotter_1d):

    """
    Plotter for 2D PlotData objects.

    Parameters
    ----------
    plotmethod : string
        Axis method used when no plotflag is given. Options are:
            contour (default)
            contourf
            mesh
            surf
    """

    def __init__(self, plotmethod='contour'):
        chosen = 'contour' if plotmethod is None else plotmethod
        super(Plotter_2d, self).__init__(chosen)

    def _plot(self, axis, plotflag, wdata, *args, **kwds):
        '''Delegate plotflag-driven 2D plotting to plot2d.'''
        return plot2d(axis, wdata, plotflag, *args, **kwds)
12 years ago
def plot2d(axis, wdata, plotflag, *args, **kwds):
    '''
    Draw 2D PlotData on axis according to plotflag.

    plotflag
        1     contour lines plus a table of the levels (cltext)
        6     like 1; meant as a 3D contour (MATLAB contour3)
        7     contour lines with inline labels; meant as a 3D contour
        8, 9  contour lines with inline labels
        2     axis.mesh, 3 axis.surf, 4 axis.waterfall, 5 axis.pcolor
        10    filled contours with labels and a colorbar

    Contour levels come from wdata.clevels when present and non-empty;
    wdata.plevels (if not None) gives their quantile values for the table.
    Otherwise eight levels between the data minimum and maximum are used.

    NOTE(review): flags 6 and 7 call axis.contour, which presumably draws
    3D contours when axis is a 3D axes. Flags 2-4 call methods that plain
    matplotlib axes do not have (mesh, surf, waterfall). Confirm the
    intended axis type.

    Returns
    -------
    h : handle returned by the drawing call.

    Raises
    ------
    ValueError
        For an unknown plotflag.
    '''
    f = wdata
    if isinstance(wdata.args, (list, tuple)):
        args1 = tuple(wdata.args) + (wdata.data,) + args
    else:
        args1 = (wdata.args,) + (wdata.data,) + args
    if plotflag in (1, 6, 7, 8, 9):
        isPL = False
        # check if contour levels is submitted
        if hasattr(f, 'clevels') and len(f.clevels) > 0:
            CL = f.clevels
            # plevels present: the levels also have quantile values
            isPL = hasattr(f, 'plevels') and f.plevels is not None
            if isPL:
                PL = f.plevels
        else:
            dmax = np.max(f.data)
            dmin = np.min(f.data)
            CL = dmax - (dmax - dmin) * \
                (1 - np.r_[0.01, 0.025, 0.05, 0.1, 0.2, 0.4, 0.5, 0.75])
        clvec = np.sort(CL)
        h = axis.contour(*args1, levels=clvec, **kwds)
        if plotflag in (1, 6):
            ncl = len(clvec)
            if ncl > 12:
                ncl = 12
                warnings.warn(
                    'Only the first 12 levels will be listed in table.')
            clvals = PL[:ncl] if isPL else clvec[:ncl]
            cltext(clvals, percent=isPL)
        else:  # plotflag in (7, 8, 9): label the contour lines inline
            axis.clabel(h)
    elif plotflag == 2:
        h = axis.mesh(*args1, **kwds)
    elif plotflag == 3:
        # shading interp % flat, faceted % surfc
        h = axis.surf(*args1, **kwds)
    elif plotflag == 4:
        h = axis.waterfall(*args1, **kwds)
    elif plotflag == 5:
        h = axis.pcolor(*args1, **kwds)  # %shading interp % flat, faceted
    elif plotflag == 10:
        h = axis.contourf(*args1, **kwds)
        axis.clabel(h)
        plt.colorbar(h)
    else:
        raise ValueError('unknown option for plotflag')
    return h
12 years ago
def test_plotdata():
    '''Visual smoke test: interpolate, integrate and plot a sine curve.'''
    plt.ioff()
    x = np.linspace(0, np.pi, 9)
    xi = np.linspace(0, np.pi, 4*9)

    d = PlotData(np.sin(x)/2, x, dataCI=[], xlab='x', ylab='sin',
                 title='sinus', plot_args=['r.'])
    di = PlotData(d.eval_points(xi, method='cubic'), xi)
    di.plot()
    d.plot()
    di.to_cdf()

    for i in range(4):
        di.plot(plotflag=i)
    # pyplot.show no longer accepts positional arguments (the old 'hold'
    # was taken as the positional `block` flag). With interactive mode off,
    # show() blocks until the figures are closed.
    d.show()
12 years ago
12 years ago
def test_docstrings():
    '''Run the doctests embedded in this module.'''
    import doctest
    message = 'Testing docstrings in %s' % (__file__,)
    print(message)
    flags = doctest.NORMALIZE_WHITESPACE
    doctest.testmod(optionflags=flags)
12 years ago
if __name__ == "__main__":
    # Run the doctests; enable the next line for the visual plotting check.
    test_docstrings()
    # test_plotdata()