From fdaf6701de851a2f859b9dc159fe83bac61dc2fa Mon Sep 17 00:00:00 2001 From: Per A Brodtkorb Date: Wed, 22 Feb 2017 12:00:40 +0100 Subject: [PATCH] Refactored to simplify --- wafo/integrate.py | 45 ++++++++++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/wafo/integrate.py b/wafo/integrate.py index 2cf76d8..9257b65 100644 --- a/wafo/integrate.py +++ b/wafo/integrate.py @@ -973,17 +973,15 @@ class _Gaussq(object): return jacob @staticmethod - def _warn(k, a_shape): + def _warn_msg(k, a_shape): n = len(k) if n > 1: if n == np.prod(a_shape): - tmptxt = 'All integrals did not converge' + msg = 'All integrals did not converge' else: - tmptxt = '%d integrals did not converge' % (n, ) - tmptxt = tmptxt + '--singularities likely!' - else: - tmptxt = 'Integral did not converge--singularity likely!' - warnings.warn(tmptxt) + msg = '%d integrals did not converge' % (n, ) + return msg + '--singularities likely!' + return 'Integral did not converge--singularity likely!' @staticmethod def _initialize(wfun, a, b, args): @@ -996,6 +994,16 @@ class _Gaussq(object): a_out = zeros((a_out.size, 1)) return a_out, b_out, args, a_shape + + def _revert_nans_with_old(self, val, val_old): + if any(np.isnan(val)): + val[np.isnan(val)] = val_old[np.isnan(val)] + + + def _update_error(self, i, abserr, val, val_old, k): + if i > 1: + abserr[k] = abs(val_old[k] - val[k]) # absolute tolerance + def __call__(self, fun, a, b, releps=1e-3, abseps=1e-3, alpha=0, beta=0, wfun=1, trace=False, args=(), max_iter=11): self.trace = trace @@ -1015,8 +1023,8 @@ class _Gaussq(object): dtype = np.result_type(fun((a_0+b_0)*0.5, *args)) n_k = np.prod(a_shape) # # of integrals we have to compute k = np.arange(n_k) - opts = (n_k, dtype) - val, val_old, abserr = zeros(*opts), ones(*opts), zeros(*opts) + opt = (n_k, dtype) + val, val_old, abserr = zeros(*opt), np.nan*ones(*opt), 1e100*ones(*opt) nodes_and_weights = self._nodes_and_weights for i in range(max_iter): x_n, weights = nodes_and_weights(num_nodes, wfun, alpha, beta) @@ -1026,18 +1034,17 @@ class _Gaussq(object): y = fun(x, *params) self._plot_trace(x, y) val[k] = np.sum(weights * y, axis=1) * d_x[k] # do the integration - if any(np.isnan(val)): - val[np.isnan(val)] = val_old[np.isnan(val)] - if 1 < i: - abserr[k] = abs(val_old[k] - val[k]) # absolute tolerance - k, = np.where(abserr > np.maximum(abs(releps * val), abseps)) - n_k = len(k) # of integrals we have to compute again - if n_k == 0: - break + self._revert_nans_with_old(val, val_old) + self._update_error(i, abserr, val, val_old, k) + + k, = np.where(abserr > np.maximum(abs(releps * val), abseps)) + converged = len(k) == 0 + if converged: + break val_old[k] = val[k] num_nodes *= 2 # double the # of basepoints and weights - else: - self._warn(k, a_shape) + + _assert_warn(converged, self._warn_msg(k, a_shape)) # make sure int is the same size as the integration limits val.shape = a_shape