Improved accuracy and avoids some runtime warnings.

master
Per A Brodtkorb 9 years ago
parent 0750bdd94b
commit 59186edfd6

@ -32,12 +32,12 @@ except:
from scipy.stats._tukeylambda_stats import (tukeylambda_variance as _tlvar, from scipy.stats._tukeylambda_stats import (tukeylambda_variance as _tlvar,
tukeylambda_kurtosis as _tlkurt) tukeylambda_kurtosis as _tlkurt)
from ._distn_infrastructure import ( from wafo.stats._distn_infrastructure import (
rv_continuous, valarray, _skew, _kurtosis, _lazywhere, rv_continuous, valarray, _skew, _kurtosis, _lazywhere,
_ncx2_log_pdf, _ncx2_pdf, _ncx2_cdf, get_distribution_names, _ncx2_log_pdf, _ncx2_pdf, _ncx2_cdf, get_distribution_names,
) )
from ._constants import _XMIN, _EULER, _ZETA3, _XMAX, _LOGXMAX, _EPS from wafo.stats._constants import _XMIN, _EULER, _ZETA3, _XMAX, _LOGXMAX, _EPS
## Kolmogorov-Smirnov one-sided and two-sided test statistics ## Kolmogorov-Smirnov one-sided and two-sided test statistics
@ -2018,12 +2018,12 @@ class genextreme_gen(rv_continuous):
return exp(self._logpdf(x, c)) return exp(self._logpdf(x, c))
def _logpdf(self, x, c): def _logpdf(self, x, c):
cx = c*x cx = _lazywhere((c != 0)*(x == x), (x, c), lambda x, c: c*x, 0.0)
logex2 = where((c == 0)*(x == x), 0.0, special.log1p(-cx)) logex2 = special.log1p(-cx)
logpex2 = self._loglogcdf(x, c) logpex2 = self._loglogcdf(x, c)
# logpex2 = where((c == 0)*(x == x), -x, logex2/c)
pex2 = exp(logpex2) pex2 = exp(logpex2)
# Handle special cases # Handle special cases
putmask(logpex2, (c == 0) & (x == -inf), 0.0)
logpdf = where((cx == 1) | (cx == -inf), -inf, -pex2+logpex2-logex2) logpdf = where((cx == 1) | (cx == -inf), -inf, -pex2+logpex2-logex2)
putmask(logpdf, (c == 1) & (x == 1), 0.0) putmask(logpdf, (c == 1) & (x == 1), 0.0)
return logpdf return logpdf
@ -2032,9 +2032,8 @@ class genextreme_gen(rv_continuous):
return exp(self._logcdf(x, c)) return exp(self._logcdf(x, c))
def _loglogcdf(self, x, c): def _loglogcdf(self, x, c):
return _lazywhere((c == 0)*(x == x), (x, c), return _lazywhere((c != 0) & (x == x), (x, c),
f=lambda x, c: -x, lambda x, c: special.log1p(-c * x)/c, -x)
f2=lambda x, c: special.log1p(-c*x)/c)
def _logcdf(self, x, c): def _logcdf(self, x, c):
return -exp(self._loglogcdf(x, c)) return -exp(self._loglogcdf(x, c))
@ -2049,10 +2048,8 @@ class genextreme_gen(rv_continuous):
def _isf(self, q, c): def _isf(self, q, c):
x = -log(-special.log1p(-q)) x = -log(-special.log1p(-q))
result = _lazywhere((c == 0)*(x == x), (x, c), return _lazywhere((x == x) & (c != 0), (x, c),
f=lambda x, c: x, lambda x, c: -expm1(-c * x) / c, x)
f2=lambda x, c: -special.expm1(-c*x)/c)
return result
def _stats(self, c): def _stats(self, c):
g = lambda n: gam(n*c+1) g = lambda n: gam(n*c+1)
@ -2560,6 +2557,9 @@ class gumbel_l_gen(rv_continuous):
def _logpdf(self, x): def _logpdf(self, x):
return x - exp(x) return x - exp(x)
def _logsf(self, x):
return -exp(x)
def _sf(self, x): def _sf(self, x):
return exp(-exp(x)) return exp(-exp(x))
@ -3203,6 +3203,9 @@ class logistic_gen(rv_continuous):
def _ppf(self, q): def _ppf(self, q):
return -log1p(-q) + log(q) return -log1p(-q) + log(q)
def _isf(self, q):
return log1p(-q) - log(q)
def _stats(self): def _stats(self):
return 0, pi*pi/3.0, 0, 6.0/5.0 return 0, pi*pi/3.0, 0, 6.0/5.0
@ -4262,7 +4265,8 @@ class rayleigh_gen(rv_continuous):
def _logpdf(self, r): def _logpdf(self, r):
rr2 = r * r / 2.0 rr2 = r * r / 2.0
return where(rr2 == inf, - rr2, log(r) - rr2) return _lazywhere(rr2 != inf, (r, rr2), lambda r, rr2: log(r) - rr2,
-rr2)
def _cdf(self, r): def _cdf(self, r):
return -special.expm1(-0.5 * r**2) return -special.expm1(-0.5 * r**2)
@ -4273,6 +4277,9 @@ class rayleigh_gen(rv_continuous):
def _sf(self, r): def _sf(self, r):
return exp(-0.5 * r**2) return exp(-0.5 * r**2)
def _logsf(self, r):
return -0.5 * r**2
def _isf(self, q): def _isf(self, q):
return sqrt(-2 * log(q)) return sqrt(-2 * log(q))
@ -5158,3 +5165,9 @@ pairs = list(globals().items())
_distn_names, _distn_gen_names = get_distribution_names(pairs, rv_continuous) _distn_names, _distn_gen_names = get_distribution_names(pairs, rv_continuous)
__all__ = _distn_names + _distn_gen_names __all__ = _distn_names + _distn_gen_names
if __name__=='__main__':
v = genextreme.logpdf(np.inf, 0)
v2 = genextreme.logpdf(-np.inf, 0)
v2 = genextreme.logpdf(-100, 0)

Loading…
Cancel
Save