From 9205e1b19dcb1a3c06a70e19be8f802eb7e42803 Mon Sep 17 00:00:00 2001 From: Per A Brodtkorb Date: Tue, 3 Jan 2017 14:58:54 +0100 Subject: [PATCH] Refactored to simplify --- wafo/stats/estimation.py | 163 +++++++++++++++++++++------------------ 1 file changed, 87 insertions(+), 76 deletions(-) diff --git a/wafo/stats/estimation.py b/wafo/stats/estimation.py index 9a4f5e9..5d6a76a 100644 --- a/wafo/stats/estimation.py +++ b/wafo/stats/estimation.py @@ -410,7 +410,7 @@ class Profile(object): self._set_indexes(fit_dist, i) method = fit_dist.method.lower() - self._set_plot_labels(fit_dist, method) + self._set_plot_labels(method) Lmax = self._loglike_max(fit_dist, method) self.Lmax = Lmax @@ -419,13 +419,16 @@ class Profile(object): self._set_profile() - def _set_plot_labels(self, fit_dist, method): + def _set_plot_labels(self, method, title='', xlabel=''): + if not title: + title = '{:s} params'.format(self.fit_dist.dist.name) + if not xlabel: + xlabel = 'phat[{}]'.format(np.ravel(self.i_fixed)[0]) percent = 100 * (1.0 - self.alpha) - self.title = '{:g}% CI for {:s} params'.format(percent, - fit_dist.dist.name) + self.title = '{:g}% CI for {:s}'.format(percent, title) like_txt = 'likelihood' if method == 'ml' else 'product spacing' self.ylabel = 'Profile log' + like_txt - self.xlabel = 'phat[{}]'.format(np.ravel(self.i_fixed)[0]) + self.xlabel = xlabel @staticmethod def _loglike_max(fit_dist, method): @@ -551,7 +554,17 @@ class Profile(object): warnings.warn(str(err)) def _get_variance(self): - return self.fit_dist.par_cov[self.i_fixed, :][:, self.i_fixed] + invfun = getattr(self, '_myinvfun', None) + if invfun is not None: + i_notfixed = self.i_notfixed + pcov = self.fit_dist.par_cov[i_notfixed, :][:, i_notfixed] + gradfun = nd.Gradient(invfun) + phatv = self._par + drl = gradfun(phatv[i_notfixed]) + pvar = np.sum(np.dot(drl, pcov) * drl) + return pvar + pvar = self.fit_dist.par_cov[self.i_fixed, :][:, self.i_fixed] + return pvar def _approx_p_min_max(self, p_opt): pvar = self._get_variance() @@ -565,9 +578,9 @@ class Profile(object): p_low, p_up = self._approx_p_min_max(p_opt) pmin, pmax = self.pmin, self.pmax if pmin is None: - pmin = self._search_pminmax(phatfree0, p_low, p_opt, 'min') + pmin = self._search_p_min_max(phatfree0, p_low, p_opt, 'min') if pmax is None: - pmax = self._search_pminmax(phatfree0, p_up, p_opt, 'max') + pmax = self._search_p_min_max(phatfree0, p_up, p_opt, 'max') return pmin, pmax def _adaptive_pvec(self, p_opt, pmin, pmax): @@ -589,30 +602,41 @@ class Profile(object): return self._adaptive_pvec(p_opt, pmin, pmax) return np.linspace(self.pmin, self.pmax, self.n) - def _search_pminmax(self, phatfree0, p_minmax0, p_opt, direction): + def _update_p_opt(self, p_minmax_opt, dp, Lmax, p_minmax, j): + # print((dp, p_minmax, p_minmax_opt, Lmax)) + converged = False + if np.isnan(Lmax): + dp *= 0.33 + elif Lmax < self.alpha_cross_level - self.alpha_Lrange * 5 * (j + 1): + p_minmax_opt = p_minmax + dp *= 0.33 + elif Lmax < self.alpha_cross_level: + p_minmax_opt = p_minmax + converged = True + else: + dp *= 1.67 + return p_minmax_opt, dp, converged + + def _search_p_min_max(self, phatfree0, p_minmax0, p_opt, direction): phatfree = phatfree0.copy() - sign = -1 if direction == 'min' else 1 + sign = dict(min=-1, max=1)[direction] dp = np.maximum(sign*(p_minmax0 - p_opt) / 40, 0.01) * 10 Lmax, phatfree = self._profile_optimum(phatfree, p_opt) p_minmax_opt = p_minmax0 - for j in range(51): + j = 0 + converged = False + # for j in range(51): + while j < 51 and not converged: + j += 1 p_minmax = p_opt + sign * dp Lmax, phatfree = self._profile_optimum(phatfree, p_minmax) - # print((dp, p_minmax, p_minmax_opt, Lmax)) - if np.isnan(Lmax): - dp *= 0.33 - elif Lmax < self.alpha_cross_level - self.alpha_Lrange*5*(j+1): - p_minmax_opt = p_minmax - dp *= 0.33 - elif Lmax < self.alpha_cross_level: - p_minmax_opt = p_minmax - break - else: - dp *= 1.67 - else: - msg = 'Exceeded max iterations. (p_{0}0={1}, p_{0}={2}, p={3})' - warnings.warn(msg.format(direction, p_minmax0, p_minmax_opt, - p_opt)) + p_minmax_opt, dp, converged = self._update_p_opt(p_minmax_opt, dp, + Lmax, p_minmax, j) + _assert_warn(j < 50, 'Exceeded max iterations. ' + '(p_{0}0={1}, p_{0}={2}, p={3})'.format(direction, + p_minmax0, + p_minmax_opt, + p_opt)) # print('search_pmin iterations={}'.format(j)) return p_minmax_opt @@ -689,8 +713,24 @@ class Profile(object): def plot_all_profiles(phats, plot=None): - if plot is not None: - plt = plot + def _remove_title_or_ylabel(plt, n, j): + if j != 0: + plt.title('') + if j != n // 2: + plt.ylabel('') + + def _profile(phats, k): + profile_phat_k = Profile(phats, i=k) + m = 0 + while hasattr(profile_phat_k, 'best_par') and m < 7: + phats.fit(*profile_phat_k.best_par) + profile_phat_k = Profile(phats, i=k) + m += 1 + + return profile_phat_k + + if plot is None: + plot = plt if phats.par_fix: indices = phats.i_notfixed @@ -699,20 +739,12 @@ def plot_all_profiles(phats, plot=None): n = len(indices) for j, k in enumerate(indices): plt.subplot(n, 1, j+1) - profile_phat_k = Profile(phats, i=k) - m = 0 - while hasattr(profile_phat_k, 'best_par') and m < 7: - phats.fit(*profile_phat_k.best_par) - profile_phat_k = Profile(phats, i=k) - m += 1 + profile_phat_k = _profile(phats, k) profile_phat_k.plot() - if j != 0: - plt.title('') - if j != n//2: - plt.ylabel('') - plt.subplots_adjust(hspace=0.5) + _remove_title_or_ylabel(plt, n, j) + plot.subplots_adjust(hspace=0.5) par_txt = ('{:1.2g}, '*len(phats.par))[:-2].format(*phats.par) - plt.suptitle('phat = [{}] (fit metod: {})'.format(par_txt, phats.method)) + plot.suptitle('phat = [{}] (fit metod: {})'.format(par_txt, phats.method)) return phats @@ -809,21 +841,10 @@ class ProfileQuantile(Profile): prb = exp(self.log_sf) return self.fit_dist.dist.isf(prb, *mphat) - def _get_variance(self): - i_notfixed = self.i_notfixed - phatv = self._par - gradfun = nd.Gradient(self._myinvfun) - drl = gradfun(phatv[self.i_notfixed]) - pcov = self.fit_dist.par_cov[i_notfixed, :][:, i_notfixed] - pvar = np.sum(np.dot(drl, pcov) * drl) - return pvar - - def _set_plot_labels(self, fit_dist, method): - super(ProfileQuantile, self)._set_plot_labels(fit_dist, method) - percent = 100 * (1.0 - self.alpha) - self.title = '{:g}% CI for {:s} quantile'.format(percent, - fit_dist.dist.name) - self.xlabel = 'x' + def _set_plot_labels(self, method): + title = '{:s} quantile'.format(self.fit_dist.dist.name) + super(ProfileQuantile, self)._set_plot_labels(method, title, + xlabel='x') class ProfileProbability(Profile): @@ -911,27 +932,17 @@ class ProfileProbability(Profile): fix_par = self.link(self.x, fixed_log_sf, par, self.i_fixed) return fix_par - def _myprbfun(self, phatnotfixed): + def _myinvfun(self, phatnotfixed): + """_myprbfun""" mphat = self._par.copy() mphat[self.i_notfixed] = phatnotfixed logsf = self.fit_dist.dist.logsf(self.x, *mphat) return np.where(np.isfinite(logsf), logsf, np.nan) - def _get_variance(self): - i_notfixed = self.i_notfixed - phatv = self._par - gradfun = nd.Gradient(self._myprbfun) - drl = gradfun(phatv[self.i_notfixed]) - pcov = self.fit_dist.par_cov[i_notfixed, :][:, i_notfixed] - pvar = np.sum(np.dot(drl, pcov) * drl) - return pvar - - def _set_plot_labels(self, fit_dist, method): - super(ProfileProbability, self)._set_plot_labels(fit_dist, method) - percent = 100 * (1.0 - self.alpha) - self.title = '{:g}% CI for {:s} probability'.format(percent, - fit_dist.dist.name) - self.xlabel = 'log(sf)' + def _set_plot_labels(self, method): + title = '{:s} probability'.format(self.fit_dist.dist.name) + xlabel = 'log(sf)' + super(ProfileProbability, self)._set_plot_labels(method, title, xlabel) class FitDistribution(rv_frozen): @@ -1463,7 +1474,7 @@ class FitDistribution(rv_frozen): phatvstr) subtxt = ('Fit method: {0:s}, Fit p-value: {1:2.2f} {2:s}, ' + 'phat=[{3:s}], {4:s}') - par_txt = '{:1.2g}, ' * len(self.par)[:-2].format(*self.par) + par_txt = ('{:1.2g}, ' * len(self.par))[:-2].format(*self.par) try: LL_txt = 'Lps_max={:2.2g}, Ll_max={:2.2g}'.format(self.LPSmax, self.LLmax) @@ -1691,8 +1702,8 @@ def test1(): phat = FitDistribution(dist, R, floc=0.5, method='ml') phats = FitDistribution(dist, R, floc=0.5, method='mps') # import matplotlib.pyplot as plt - # plt.figure(0) - # plot_all_profiles(phat, plot=plt) + plt.figure(0) + plot_all_profiles(phat, plot=plt) plt.figure(1) phats.plotfitsummary() @@ -1727,5 +1738,5 @@ def test1(): if __name__ == '__main__': - test1() - # test_doctstrings() + # test1() + test_doctstrings()