From 5adcc7446d019a35e2376f1d637f96921a9a2beb Mon Sep 17 00:00:00 2001 From: Per A Brodtkorb Date: Wed, 22 Feb 2017 11:11:06 +0100 Subject: [PATCH] Refactored to reduce complexity --- wafo/integrate.py | 57 +++++++++++++++++++++++++++++------------------ 1 file changed, 35 insertions(+), 22 deletions(-) diff --git a/wafo/integrate.py b/wafo/integrate.py index 3539e22..2cf76d8 100644 --- a/wafo/integrate.py +++ b/wafo/integrate.py @@ -1271,8 +1271,28 @@ def _plot_error(neval, err_dic, plot_error): plt.legend() +def _print_headers(formats_h, headers, names): + print(''.join(fi % t for (fi, t) in zip(formats_h, + ['ftn'] + names))) + print(' '.join(headers)) + + +def _stack_values_and_errors(neval, vals_dic, err_dic, names): + data = [neval] + for name in names: + data.append(vals_dic[name]) + data.append(err_dic[name]) + + data = np.vstack(tuple(data)).T + return data + + +def _print_data(formats, data): + for row in data: + print(''.join(fi % t for (fi, t) in zip(formats, row.tolist()))) + + def _print_values_and_errors(neval, vals_dic, err_dic): - kmax = len(neval) names = sorted(vals_dic.keys()) num_cols = 2 formats = ['%4.0f, '] + ['%10.10f, '] * num_cols * 2 @@ -1281,19 +1301,10 @@ def _print_values_and_errors(neval, vals_dic, err_dic): formats_h[-1] = formats_h[-1].split(',')[0] headers = ['evals'] + ['%12s %12s' % ('approx', 'error')] * num_cols while len(names) > 0: - print(''.join(fi % t for (fi, t) in zip(formats_h, - ['ftn'] + names[:num_cols]))) - print(' '.join(headers)) - data = [neval] - for name in names[:num_cols]: - data.append(vals_dic[name]) - data.append(err_dic[name]) - - data = np.vstack(tuple(data)).T - for k in range(kmax): - tmp = data[k].tolist() - print(''.join(fi % t for (fi, t) in zip(formats, tmp))) - + names_c = names[:num_cols] + _print_headers(formats_h, headers, names_c) + data = _stack_values_and_errors(neval, vals_dic, err_dic, names_c) + _print_data(formats, data) names = names[num_cols:] @@ -1303,6 +1314,15 @@ def _display(neval, vals_dic, err_dic, plot_error): _plot_error(neval, err_dic, plot_error) +def chebychev(y, x, n=None): + if n is None: + n = len(y) + c_k = np.polynomial.chebyshev.chebfit(x, y, deg=min(n - 1, 36)) + c_ki = np.polynomial.chebyshev.chebint(c_k) + q = np.polynomial.chebyshev.chebval(x[-1], c_ki) + return q + + def qdemo(f, a, b, kmax=9, plot_error=False): """ Compares different quadrature rules. @@ -1371,7 +1391,7 @@ def qdemo(f, a, b, kmax=9, plot_error=False): err_dic = {} # try various approximations - methods = [trapz, simps, boole, ] + methods = [trapz, simps, boole, chebychev] for k in range(kmax): n = 2 ** (k + 1) + 1 @@ -1389,13 +1409,6 @@ def qdemo(f, a, b, kmax=9, plot_error=False): vals_dic.setdefault(name, []).append(q[0]) err_dic.setdefault(name, []).append(abs(q[0] - true_val)) - name = 'Chebychev' - c_k = np.polynomial.chebyshev.chebfit(x, y, deg=min(n-1, 36)) - c_ki = np.polynomial.chebyshev.chebint(c_k) - q = np.polynomial.chebyshev.chebval(x[-1], c_ki) - vals_dic.setdefault(name, []).append(q) - err_dic.setdefault(name, []).append(abs(q - true_val)) - name = 'Gauss-Legendre' # quadrature q = intg.fixed_quad(f, a, b, n=n)[0] vals_dic.setdefault(name, []).append(q)