Refactored to simplify

master
Per A Brodtkorb 8 years ago
parent 5adcc7446d
commit fdaf6701de

@ -973,17 +973,15 @@ class _Gaussq(object):
return jacob return jacob
@staticmethod @staticmethod
def _warn(k, a_shape): def _warn_msg(k, a_shape):
n = len(k) n = len(k)
if n > 1: if n > 1:
if n == np.prod(a_shape): if n == np.prod(a_shape):
tmptxt = 'All integrals did not converge' msg = 'All integrals did not converge'
else: else:
tmptxt = '%d integrals did not converge' % (n, ) msg = '%d integrals did not converge' % (n, )
tmptxt = tmptxt + '--singularities likely!' return msg + '--singularities likely!'
else: return 'Integral did not converge--singularity likely!'
tmptxt = 'Integral did not converge--singularity likely!'
warnings.warn(tmptxt)
@staticmethod @staticmethod
def _initialize(wfun, a, b, args): def _initialize(wfun, a, b, args):
@ -996,6 +994,16 @@ class _Gaussq(object):
a_out = zeros((a_out.size, 1)) a_out = zeros((a_out.size, 1))
return a_out, b_out, args, a_shape return a_out, b_out, args, a_shape
def _revert_nans_with_old(self, val, val_old):
if any(np.isnan(val)):
val[np.isnan(val)] = val_old[np.isnan(val)]
def _update_error(self, i, abserr, val, val_old, k):
if i > 1:
abserr[k] = abs(val_old[k] - val[k]) # absolute tolerance
def __call__(self, fun, a, b, releps=1e-3, abseps=1e-3, alpha=0, beta=0, def __call__(self, fun, a, b, releps=1e-3, abseps=1e-3, alpha=0, beta=0,
wfun=1, trace=False, args=(), max_iter=11): wfun=1, trace=False, args=(), max_iter=11):
self.trace = trace self.trace = trace
@ -1015,8 +1023,8 @@ class _Gaussq(object):
dtype = np.result_type(fun((a_0+b_0)*0.5, *args)) dtype = np.result_type(fun((a_0+b_0)*0.5, *args))
n_k = np.prod(a_shape) # # of integrals we have to compute n_k = np.prod(a_shape) # # of integrals we have to compute
k = np.arange(n_k) k = np.arange(n_k)
opts = (n_k, dtype) opt = (n_k, dtype)
val, val_old, abserr = zeros(*opts), ones(*opts), zeros(*opts) val, val_old, abserr = zeros(*opt), np.nan*ones(*opt), 1e100*ones(*opt)
nodes_and_weights = self._nodes_and_weights nodes_and_weights = self._nodes_and_weights
for i in range(max_iter): for i in range(max_iter):
x_n, weights = nodes_and_weights(num_nodes, wfun, alpha, beta) x_n, weights = nodes_and_weights(num_nodes, wfun, alpha, beta)
@ -1026,18 +1034,17 @@ class _Gaussq(object):
y = fun(x, *params) y = fun(x, *params)
self._plot_trace(x, y) self._plot_trace(x, y)
val[k] = np.sum(weights * y, axis=1) * d_x[k] # do the integration val[k] = np.sum(weights * y, axis=1) * d_x[k] # do the integration
if any(np.isnan(val)): self._revert_nans_with_old(val, val_old)
val[np.isnan(val)] = val_old[np.isnan(val)] self._update_error(i, abserr, val, val_old, k)
if 1 < i:
abserr[k] = abs(val_old[k] - val[k]) # absolute tolerance
k, = np.where(abserr > np.maximum(abs(releps * val), abseps)) k, = np.where(abserr > np.maximum(abs(releps * val), abseps))
n_k = len(k) # of integrals we have to compute again converged = len(k) == 0
if n_k == 0: if converged:
break break
val_old[k] = val[k] val_old[k] = val[k]
num_nodes *= 2 # double the # of basepoints and weights num_nodes *= 2 # double the # of basepoints and weights
else:
self._warn(k, a_shape) _assert_warn(converged, self._warn_msg(k, a_shape))
# make sure int is the same size as the integration limits # make sure int is the same size as the integration limits
val.shape = a_shape val.shape = a_shape

Loading…
Cancel
Save