Added eval_points and integrate to the WafoData class

master
per.andreas.brodtkorb 13 years ago
parent 52fbfae045
commit fbadd0b3bb

@ -6,7 +6,7 @@ from numpy import pi, sqrt, ones, zeros #@UnresolvedImport
from scipy import integrate as intg
import scipy.special.orthogonal as ort
from scipy import special as sp
import pylab as plb
from wafo.plotbackend import plotbackend as plt
from scipy.integrate import simps, trapz
from wafo.misc import is_numlike
from wafo.demos import humps
@ -187,7 +187,7 @@ def clencurt(fun, a, b, n0=5, trace=False, *args):
f = np.flipud(fun[:, 1::])
if trace:
plb.plot(x, f, '+')
plt.plot(x, f, '+')
# using a Gauss-Lobatto variant, i.e., first and last
# term f(a) and f(b) is multiplied with 0.5
@ -1087,13 +1087,13 @@ def gaussq(fun, a, b, reltol=1e-3, abstol=1e-3, alpha=0, beta=0, wfun=1,
x_trace.append(x.ravel())
y_trace.append(y.ravel())
hfig = plb.plot(x, y, 'r.')
hfig = plt.plot(x, y, 'r.')
#hold on
#drawnow,shg
#if trace>1:
# pause
plb.setp(hfig, 'color', 'b')
plt.setp(hfig, 'color', 'b')
abserr[k] = abs(val_old[k] - val[k]) #absolute tolerance
@ -1122,8 +1122,8 @@ def gaussq(fun, a, b, reltol=1e-3, abstol=1e-3, alpha=0, beta=0, wfun=1,
abserr.shape = a_shape
if trace > 0:
plb.clf()
plb.plot(np.hstack(x_trace), np.hstack(y_trace), '+')
plt.clf()
plt.plot(np.hstack(x_trace), np.hstack(y_trace), '+')
return val, abserr
def richardson(Q, k):
@ -1430,10 +1430,10 @@ def qdemo(f, a, b):
print(''.join(fi % t for fi, t in zip(formats, tmp)))
plb.loglog(neval, np.vstack((et, es, eb, ec, ec2, eg)).T)
plb.xlabel('number of function evaluations')
plb.ylabel('error')
plb.legend(('Trapezoid', 'Simpsons', 'Booles', 'Clenshaw', 'Chebychev', 'Gauss-L'))
plt.loglog(neval, np.vstack((et, es, eb, ec, ec2, eg)).T)
plt.xlabel('number of function evaluations')
plt.ylabel('error')
plt.legend(('Trapezoid', 'Simpsons', 'Booles', 'Clenshaw', 'Chebychev', 'Gauss-L'))
#ec3'

@ -4,6 +4,8 @@ from plotbackend import plotbackend
from time import gmtime, strftime
import numpy as np
from scipy.integrate.quadrature import cumtrapz #@UnresolvedImport
from scipy.interpolate import griddata
from scipy import integrate
__all__ = ['WafoData', 'AxisLabels']
@ -105,6 +107,59 @@ class WafoData(object):
tmp2 = self.plotter.plot(self, *main_args, **main_kwds)
return tmp2, tmp
def eval_points(self, *args, **kwds):
'''
>>> x = np.linspace(0,5,20)
>>> d = WafoData(np.sin(x),x)
>>> xi = np.linspace(0,5,60)
>>> di = WafoData(d.eval_points(xi, method='cubic'),xi)
>>> d.plot('.')
>>> di.plot()
'''
if isinstance(self.args, (list, tuple)): # Multidimensional data
ndim = len(self.args)
if ndim < 2:
msg = '''Unable to determine plotter-type, because len(self.args)<2.
If the data is 1D, then self.args should be a vector!
If the data is 2D, then length(self.args) should be 2.
If the data is 3D, then length(self.args) should be 3.
Unless you fix this, the plot methods will not work!'''
warnings.warn(msg)
else:
return griddata(self.args, self.data.ravel(), *args,**kwds)
else: #One dimensional data
return griddata((self.args,), self.data, *args,**kwds)
def integrate(self, a, b, **kwds):
'''
>>> x = np.linspace(0,5,60)
>>> d = WafoData(np.sin(x), x)
>>> d.integrate(0,np.pi/2)
'''
method = kwds.pop('method','trapz')
fun = getattr(integrate, method)
if isinstance(self.args, (list, tuple)): # Multidimensional data
ndim = len(self.args)
if ndim < 2:
msg = '''Unable to determine plotter-type, because len(self.args)<2.
If the data is 1D, then self.args should be a vector!
If the data is 2D, then length(self.args) should be 2.
If the data is 3D, then length(self.args) should be 3.
Unless you fix this, the plot methods will not work!'''
warnings.warn(msg)
else:
return griddata(self.args, self.data.ravel(), **kwds)
else: #One dimensional data
x = self.args
ix = np.flatnonzero((a<x) & (x<b) )
xi = np.hstack((a, x.take(ix), b))
fi = np.hstack((self.eval_points(a),self.data.take(ix),self.eval_points(b)))
return fun(fi, xi, **kwds)
def show(self):
self.plotter.show()
@ -431,7 +486,19 @@ def plot2d(axis, wdata, plotflag, *args, **kwds):
#end
# pass
def test_eval_points():
plotbackend.ioff()
x = np.linspace(0,5,21)
d = WafoData(np.sin(x),x)
xi = np.linspace(0,5,61)
di = WafoData(d.eval_points(xi,method='cubic'),xi)
d.plot('.')
di.plot()
di.show()
def test_integrate():
x = np.linspace(0,5,60)
d = WafoData(np.sin(x), x)
print(d.integrate(0,np.pi/2,method='simps'))
def test_docstrings():
import doctest
doctest.testmod()
@ -440,5 +507,7 @@ def main():
pass
if __name__ == '__main__':
test_docstrings()
test_integrate()
#test_eval_points()
#test_docstrings()
#main()

Loading…
Cancel
Save