Small refactoring FitDistribution

master
Per A Brodtkorb 9 years ago
parent df78c0e728
commit 8fe60a970f

@ -561,7 +561,7 @@ def fit2(self, data, *args, **kwds):
`data` is sorted using this function, so if `copydata`==False the data `data` is sorted using this function, so if `copydata`==False the data
in your namespace will be sorted as well. in your namespace will be sorted as well.
''' '''
return FitDistribution(self, data, *args, **kwds) return FitDistribution(self, data, args, **kwds)
rv_generic.freeze = freeze rv_generic.freeze = freeze

@ -13,7 +13,7 @@ import warnings
from wafo.plotbackend import plotbackend from wafo.plotbackend import plotbackend
from wafo.misc import ecross, findcross, argsreduce from wafo.misc import ecross, findcross, argsreduce
from wafo.stats._constants import _EPS, _XMAX from wafo.stats._constants import _EPS, _XMAX
from wafo.stats._distn_infrastructure import rv_frozen from wafo.stats._distn_infrastructure import rv_frozen, rv_continuous
from scipy._lib.six import string_types from scipy._lib.six import string_types
import numdifftools as nd # @UnresolvedImport import numdifftools as nd # @UnresolvedImport
from scipy import special from scipy import special
@ -21,7 +21,7 @@ from scipy.linalg import pinv2
from scipy import optimize from scipy import optimize
import numpy as np import numpy as np
from numpy import (alltrue, arange, ravel, zeros, log, sqrt, exp, from numpy import (alltrue, arange, zeros, log, sqrt, exp,
atleast_1d, any, asarray, nan, pi, isfinite) atleast_1d, any, asarray, nan, pi, isfinite)
from numpy import flatnonzero as nonzero from numpy import flatnonzero as nonzero
@ -550,7 +550,8 @@ class FitDistribution(rv_frozen):
>>> sf_ci = Lsf.get_bounds(alpha=0.2) >>> sf_ci = Lsf.get_bounds(alpha=0.2)
''' '''
def __init__(self, dist, data, *args, **kwds): def __init__(self, dist, data, args=(), method='ML', alpha=0.05,
par_fix=None, search=True, copydata=True, **kwds):
extradoc = ''' extradoc = '''
plotfitsummary() plotfitsummary()
Plot various diagnostic plots to asses quality of fit. Plot various diagnostic plots to asses quality of fit.
@ -603,25 +604,23 @@ class FitDistribution(rv_frozen):
self.dist = dist self.dist = dist
numargs = dist.numargs numargs = dist.numargs
self.method = self.alpha = self.par_fix = self.search = None self.method = method
self.copydata = None self.alpha = alpha
m_variables = ['method', 'alpha', 'par_fix', 'search', 'copydata'] self.par_fix = par_fix
m_defaults = ['ml', 0.05, None, True, True] self.search = search
for (name, val) in zip(m_variables, m_defaults): self.copydata = copydata
setattr(self, name, kwds.pop(name, val))
if self.method.lower()[:].startswith('mps'): if self.method.lower()[:].startswith('mps'):
self._fitfun = self._nlogps self._fitfun = self._nlogps
else: else:
self._fitfun = self._nnlf self._fitfun = self._nnlf
self.data = ravel(data) self.data = np.ravel(data)
if self.copydata: if self.copydata:
self.data = self.data.copy() self.data = self.data.copy()
self.data.sort() self.data.sort()
par, fixedn = self._fit(*args, **kwds.copy())
par, fixedn = self._fit(*args, **kwds) super(FitDistribution, self).__init__(dist, *par)
# super(FitDistribution, self).__init__(dist, *par)
self.par = arr(par) self.par = arr(par)
somefixed = len(fixedn) > 0 somefixed = len(fixedn) > 0
if somefixed: if somefixed:
@ -658,8 +657,9 @@ class FitDistribution(rv_frozen):
# First of all, convert fshapes params to fnum: eg for stats.beta, # First of all, convert fshapes params to fnum: eg for stats.beta,
# shapes='a, b'. To fix `a`, can specify either `f1` or `fa`. # shapes='a, b'. To fix `a`, can specify either `f1` or `fa`.
# Convert the latter into the former. # Convert the latter into the former.
if self.shapes: shapes = self.dist.shapes
shapes = self.shapes.replace(',', ' ').split() if shapes:
shapes = shapes.replace(',', ' ').split()
for j, s in enumerate(shapes): for j, s in enumerate(shapes):
val = kwds.pop('f' + s, None) or kwds.pop('fix_' + s, None) val = kwds.pop('f' + s, None) or kwds.pop('fix_' + s, None)
if val is not None: if val is not None:
@ -784,16 +784,16 @@ class FitDistribution(rv_frozen):
product of spacings.", product of spacings.",
IMS Lecture Notes Monograph Series 2006, Vol. 52, pp. 272-283 IMS Lecture Notes Monograph Series 2006, Vol. 52, pp. 272-283
""" """
n = 2 if self._rv_continous else 1 n = 2 if isinstance(self.dist, rv_continuous) else 1
try: try:
loc = theta[-n] loc = theta[-n]
scale = theta[-1] scale = theta[-1]
args = tuple(theta[:-n]) args = tuple(theta[:-n])
except IndexError: except IndexError:
raise ValueError("Not enough input arguments.") raise ValueError("Not enough input arguments.")
if not self._rv_continous: if not isinstance(self.dist, rv_continuous):
scale = 1 scale = 1
if not self._argcheck(*args) or scale <= 0: if not self.dist._argcheck(*args) or scale <= 0:
return np.inf return np.inf
dist = self.dist dist = self.dist
x = asarray((x - loc) / scale) x = asarray((x - loc) / scale)
@ -867,7 +867,7 @@ class FitDistribution(rv_frozen):
# by now kwds must be empty, since everybody took what they needed # by now kwds must be empty, since everybody took what they needed
if kwds: if kwds:
raise TypeError("Unknown arguments: %s." % kwds) raise TypeError("Unknown arguments: %s." % kwds)
vals = optimizer(func, x0, args=(ravel(data),), disp=0) vals = optimizer(func, x0, args=(np.ravel(data),), disp=0)
vals = tuple(vals) vals = tuple(vals)
else: else:
vals = tuple(x0) vals = tuple(x0)

Loading…
Cancel
Save