From 267c3cb62d2b6b93a1cd24ae069c5baac9fff23f Mon Sep 17 00:00:00 2001 From: Per A Brodtkorb Date: Tue, 3 Jan 2017 22:16:00 +0100 Subject: [PATCH] Refactored _reduce_func --- wafo/stats/estimation.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/wafo/stats/estimation.py b/wafo/stats/estimation.py index 5d6a76a..6ce8e2b 100644 --- a/wafo/stats/estimation.py +++ b/wafo/stats/estimation.py @@ -723,6 +723,7 @@ def plot_all_profiles(phats, plot=None): profile_phat_k = Profile(phats, i=k) m = 0 while hasattr(profile_phat_k, 'best_par') and m < 7: + # iterate to find optimum phat! phats.fit(*profile_phat_k.best_par) profile_phat_k = Profile(phats, i=k) m += 1 @@ -841,10 +842,9 @@ class ProfileQuantile(Profile): prb = exp(self.log_sf) return self.fit_dist.dist.isf(prb, *mphat) - def _set_plot_labels(self, method): + def _set_plot_labels(self, method, title='', xlabel='x'): title = '{:s} quantile'.format(self.fit_dist.dist.name) - super(ProfileQuantile, self)._set_plot_labels(method, title, - xlabel='x') + super(ProfileQuantile, self)._set_plot_labels(method, title, xlabel) class ProfileProbability(Profile): @@ -939,7 +939,7 @@ class ProfileProbability(Profile): logsf = self.fit_dist.dist.logsf(self.x, *mphat) return np.where(np.isfinite(logsf), logsf, np.nan) - def _set_plot_labels(self, method): + def _set_plot_labels(self, method, title='', xlabel=''): title = '{:s} probability'.format(self.fit_dist.dist.name) xlabel = 'log(sf)' super(ProfileProbability, self)._set_plot_labels(method, title, xlabel) @@ -1154,7 +1154,7 @@ class FitDistribution(rv_frozen): t.append('%s = %s\n' % (par, str(getattr(self, par)))) return ''.join(t) - def _reduce_func(self, args, kwds): + def _convert_fshapes2fnum(self, kwds): # First of all, convert fshapes params to fnum: eg for stats.beta, # shapes='a, b'. To fix `a`, can specify either `f1` or `fa`. # Convert the latter into the former. @@ -1169,10 +1169,13 @@ class FitDistribution(rv_frozen): raise ValueError("Duplicate entry for %s." % key) else: kwds[key] = val + return kwds + + def _unpack_args_kwds(self, args, kwds): + kwds = self._convert_fshapes2fnum(kwds) args = list(args) - nargs = len(args) fixedn = [] - names = ['f%d' % n for n in range(nargs - 2)] + ['floc', 'fscale'] + names = ['f%d' % n for n in range(len(args) - 2)] + ['floc', 'fscale'] x0 = [] for n, key in enumerate(names): if key in kwds: @@ -1180,7 +1183,12 @@ class FitDistribution(rv_frozen): args[n] = kwds.pop(key) else: x0.append(args[n]) + return x0, args, fixedn + def _reduce_func(self, args, kwds): + x0, args, fixedn = self._unpack_args_kwds(args, kwds) + + nargs = len(args) fitfun = self._fitfun if len(fixedn) == 0: @@ -1194,7 +1202,7 @@ class FitDistribution(rv_frozen): def restore(args, theta): # Replace with theta for all numbers not in fixedn # This allows the non-fixed values to vary, but - # we still call self.nnlf with all parameters. + # we still call self.nnlf with all parameters. i = 0 for n in range(nargs): if n not in fixedn: @@ -1358,7 +1366,7 @@ class FitDistribution(rv_frozen): H = np.asmatrix(self._hessian(self._fitfun, self.par, self.data)) # H = -nd.Hessian(lambda par: self._fitfun(par, self.data), - # method='forward')(self.par) + # method='forward')(self.par) self.H = H try: if somefixed: