diff --git a/wafo/stats/_distn_infrastructure.py b/wafo/stats/_distn_infrastructure.py index 9677194..0842a9d 100644 --- a/wafo/stats/_distn_infrastructure.py +++ b/wafo/stats/_distn_infrastructure.py @@ -561,7 +561,7 @@ def fit2(self, data, *args, **kwds): `data` is sorted using this function, so if `copydata`==False the data in your namespace will be sorted as well. ''' - return FitDistribution(self, data, *args, **kwds) + return FitDistribution(self, data, args, **kwds) rv_generic.freeze = freeze diff --git a/wafo/stats/estimation.py b/wafo/stats/estimation.py index 3a481ad..5455ff3 100644 --- a/wafo/stats/estimation.py +++ b/wafo/stats/estimation.py @@ -13,7 +13,7 @@ import warnings from wafo.plotbackend import plotbackend from wafo.misc import ecross, findcross, argsreduce from wafo.stats._constants import _EPS, _XMAX -from wafo.stats._distn_infrastructure import rv_frozen +from wafo.stats._distn_infrastructure import rv_frozen, rv_continuous from scipy._lib.six import string_types import numdifftools as nd # @UnresolvedImport from scipy import special @@ -21,7 +21,7 @@ from scipy.linalg import pinv2 from scipy import optimize import numpy as np -from numpy import (alltrue, arange, ravel, zeros, log, sqrt, exp, +from numpy import (alltrue, arange, zeros, log, sqrt, exp, atleast_1d, any, asarray, nan, pi, isfinite) from numpy import flatnonzero as nonzero @@ -550,7 +550,8 @@ class FitDistribution(rv_frozen): >>> sf_ci = Lsf.get_bounds(alpha=0.2) ''' - def __init__(self, dist, data, *args, **kwds): + def __init__(self, dist, data, args=(), method='ML', alpha=0.05, + par_fix=None, search=True, copydata=True, **kwds): extradoc = ''' plotfitsummary() Plot various diagnostic plots to asses quality of fit. @@ -603,25 +604,23 @@ class FitDistribution(rv_frozen): self.dist = dist numargs = dist.numargs - self.method = self.alpha = self.par_fix = self.search = None - self.copydata = None - m_variables = ['method', 'alpha', 'par_fix', 'search', 'copydata'] - m_defaults = ['ml', 0.05, None, True, True] - for (name, val) in zip(m_variables, m_defaults): - setattr(self, name, kwds.pop(name, val)) + self.method = method + self.alpha = alpha + self.par_fix = par_fix + self.search = search + self.copydata = copydata if self.method.lower()[:].startswith('mps'): self._fitfun = self._nlogps else: self._fitfun = self._nnlf - self.data = ravel(data) + self.data = np.ravel(data) if self.copydata: self.data = self.data.copy() self.data.sort() - - par, fixedn = self._fit(*args, **kwds) - # super(FitDistribution, self).__init__(dist, *par) + par, fixedn = self._fit(*args, **kwds.copy()) + super(FitDistribution, self).__init__(dist, *par) self.par = arr(par) somefixed = len(fixedn) > 0 if somefixed: @@ -658,8 +657,9 @@ class FitDistribution(rv_frozen): # First of all, convert fshapes params to fnum: eg for stats.beta, # shapes='a, b'. To fix `a`, can specify either `f1` or `fa`. # Convert the latter into the former. - if self.shapes: - shapes = self.shapes.replace(',', ' ').split() + shapes = self.dist.shapes + if shapes: + shapes = shapes.replace(',', ' ').split() for j, s in enumerate(shapes): val = kwds.pop('f' + s, None) or kwds.pop('fix_' + s, None) if val is not None: @@ -784,16 +784,16 @@ class FitDistribution(rv_frozen): product of spacings.", IMS Lecture Notes Monograph Series 2006, Vol. 52, pp. 272-283 """ - n = 2 if self._rv_continous else 1 + n = 2 if isinstance(self.dist, rv_continuous) else 1 try: loc = theta[-n] scale = theta[-1] args = tuple(theta[:-n]) except IndexError: raise ValueError("Not enough input arguments.") - if not self._rv_continous: + if not isinstance(self.dist, rv_continuous): scale = 1 - if not self._argcheck(*args) or scale <= 0: + if not self.dist._argcheck(*args) or scale <= 0: return np.inf dist = self.dist x = asarray((x - loc) / scale) @@ -867,7 +867,7 @@ class FitDistribution(rv_frozen): # by now kwds must be empty, since everybody took what they needed if kwds: raise TypeError("Unknown arguments: %s." % kwds) - vals = optimizer(func, x0, args=(ravel(data),), disp=0) + vals = optimizer(func, x0, args=(np.ravel(data),), disp=0) vals = tuple(vals) else: vals = tuple(x0)