Refactored _reduce_func

master
Per A Brodtkorb 8 years ago
parent 9205e1b19d
commit 267c3cb62d

@ -723,6 +723,7 @@ def plot_all_profiles(phats, plot=None):
profile_phat_k = Profile(phats, i=k) profile_phat_k = Profile(phats, i=k)
m = 0 m = 0
while hasattr(profile_phat_k, 'best_par') and m < 7: while hasattr(profile_phat_k, 'best_par') and m < 7:
# iterate to find optimum phat!
phats.fit(*profile_phat_k.best_par) phats.fit(*profile_phat_k.best_par)
profile_phat_k = Profile(phats, i=k) profile_phat_k = Profile(phats, i=k)
m += 1 m += 1
@ -841,10 +842,9 @@ class ProfileQuantile(Profile):
prb = exp(self.log_sf) prb = exp(self.log_sf)
return self.fit_dist.dist.isf(prb, *mphat) return self.fit_dist.dist.isf(prb, *mphat)
def _set_plot_labels(self, method): def _set_plot_labels(self, method, title='', xlabel='x'):
title = '{:s} quantile'.format(self.fit_dist.dist.name) title = '{:s} quantile'.format(self.fit_dist.dist.name)
super(ProfileQuantile, self)._set_plot_labels(method, title, super(ProfileQuantile, self)._set_plot_labels(method, title, xlabel)
xlabel='x')
class ProfileProbability(Profile): class ProfileProbability(Profile):
@ -939,7 +939,7 @@ class ProfileProbability(Profile):
logsf = self.fit_dist.dist.logsf(self.x, *mphat) logsf = self.fit_dist.dist.logsf(self.x, *mphat)
return np.where(np.isfinite(logsf), logsf, np.nan) return np.where(np.isfinite(logsf), logsf, np.nan)
def _set_plot_labels(self, method): def _set_plot_labels(self, method, title='', xlabel=''):
title = '{:s} probability'.format(self.fit_dist.dist.name) title = '{:s} probability'.format(self.fit_dist.dist.name)
xlabel = 'log(sf)' xlabel = 'log(sf)'
super(ProfileProbability, self)._set_plot_labels(method, title, xlabel) super(ProfileProbability, self)._set_plot_labels(method, title, xlabel)
@ -1154,7 +1154,7 @@ class FitDistribution(rv_frozen):
t.append('%s = %s\n' % (par, str(getattr(self, par)))) t.append('%s = %s\n' % (par, str(getattr(self, par))))
return ''.join(t) return ''.join(t)
def _reduce_func(self, args, kwds): def _convert_fshapes2fnum(self, kwds):
# First of all, convert fshapes params to fnum: eg for stats.beta, # First of all, convert fshapes params to fnum: eg for stats.beta,
# shapes='a, b'. To fix `a`, can specify either `f1` or `fa`. # shapes='a, b'. To fix `a`, can specify either `f1` or `fa`.
# Convert the latter into the former. # Convert the latter into the former.
@ -1169,10 +1169,13 @@ class FitDistribution(rv_frozen):
raise ValueError("Duplicate entry for %s." % key) raise ValueError("Duplicate entry for %s." % key)
else: else:
kwds[key] = val kwds[key] = val
return kwds
def _unpack_args_kwds(self, args, kwds):
kwds = self._convert_fshapes2fnum(kwds)
args = list(args) args = list(args)
nargs = len(args)
fixedn = [] fixedn = []
names = ['f%d' % n for n in range(nargs - 2)] + ['floc', 'fscale'] names = ['f%d' % n for n in range(len(args) - 2)] + ['floc', 'fscale']
x0 = [] x0 = []
for n, key in enumerate(names): for n, key in enumerate(names):
if key in kwds: if key in kwds:
@ -1180,7 +1183,12 @@ class FitDistribution(rv_frozen):
args[n] = kwds.pop(key) args[n] = kwds.pop(key)
else: else:
x0.append(args[n]) x0.append(args[n])
return x0, args, fixedn
def _reduce_func(self, args, kwds):
x0, args, fixedn = self._unpack_args_kwds(args, kwds)
nargs = len(args)
fitfun = self._fitfun fitfun = self._fitfun
if len(fixedn) == 0: if len(fixedn) == 0:

Loading…
Cancel
Save