You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
184 lines
6.4 KiB
Python
184 lines
6.4 KiB
Python
11 years ago
|
from __future__ import division, print_function, absolute_import
|
||
|
|
||
|
import numpy.testing as npt
|
||
|
import numpy as np
|
||
10 years ago
|
from scipy.lib.six import xrange
|
||
|
|
||
11 years ago
|
from wafo import stats
|
||
|
from wafo.stats.tests.common_tests import (check_normalization, check_moment,
|
||
|
check_mean_expect,
|
||
11 years ago
|
check_var_expect, check_skew_expect, check_kurt_expect,
|
||
|
check_entropy, check_private_entropy, check_edge_support,
|
||
|
check_named_args)
|
||
10 years ago
|
from wafo.stats._distr_params import distdiscrete
|
||
11 years ago
|
# Short alias for numpy's known-failure test decorator factory, used in the
# form ``knf(condition, msg)(func)``; presumably ``func`` is reported as a
# known failure when ``condition`` is true (numpy semantics -- confirm).
# NOTE(review): ``numpy.testing.dec`` is nose-era and is deprecated/removed
# in recent NumPy releases -- confirm against the supported NumPy version.
knf = npt.dec.knownfailureif
|
||
|
|
||
|
|
||
|
def test_discrete_basic():
    """Yield the basic consistency checks for every discrete distribution.

    First pass: for each ``(distname, arg)`` entry of ``distdiscrete`` a
    seeded sample of 2000 variates is drawn, and the cdf/ppf round-trip,
    pmf/cdf, sf/isf ("oth"), edge-support and chi-square goodness-of-fit
    checks are yielded.

    Second pass: each distribution name is visited only once (``distdiscrete``
    can repeat a name with different parameters), and the named-argument,
    docstring and entropy checks are yielded.
    """
    # significance level for the chi-square goodness-of-fit test
    alpha = 0.01
    locscale_defaults = (0,)
    # make sure arguments are within support; other distributions use k = 1
    spec_k = {'randint': 11, 'hypergeom': 4, 'bernoulli': 0}

    for distname, arg in distdiscrete:
        distfn = getattr(stats, distname)
        np.random.seed(9765456)
        rvs = distfn.rvs(*arg, size=2000)
        supp = np.unique(rvs)

        yield check_cdf_ppf, distfn, arg, supp, distname + ' cdf_ppf'
        yield check_pmf_cdf, distfn, arg, distname
        yield check_oth, distfn, arg, supp, distname + ' oth'
        yield check_edge_support, distfn, arg
        yield (check_discrete_chisquare, distfn, arg, rvs, alpha,
               distname + ' chisquare')

    seen = set()
    for distname, arg in distdiscrete:
        if distname in seen:
            continue
        seen.add(distname)

        distfn = getattr(stats, distname)
        meths = [distfn.pmf, distfn.logpmf, distfn.cdf, distfn.logcdf,
                 distfn.logsf]
        k = spec_k.get(distname, 1)
        yield check_named_args, distfn, k, arg, locscale_defaults, meths
        yield check_scale_docstring, distfn

        # Entropy
        yield check_entropy, distfn, arg, distname
        # the private-entropy check only applies when the distribution
        # class overrides rv_discrete._entropy
        if distfn.__class__._entropy != stats.rv_discrete._entropy:
            yield check_private_entropy, distfn, arg, stats.rv_discrete
|
||
|
|
||
|
|
||
|
def test_moments():
    """Yield moment-consistency checks for every discrete distribution.

    For each ``(distname, arg)`` entry of ``distdiscrete``, the mean,
    variance, skewness and kurtosis returned by ``distfn.stats`` are checked
    for consistency with the ``moment`` and ``expect`` based computations.
    The first two raw moments of the frozen distribution are also checked.
    """
    for distname, arg in distdiscrete:
        distfn = getattr(stats, distname)
        m, v, s, k = distfn.stats(*arg, moments='mvsk')
        yield check_normalization, distfn, arg, distname

        # compare `stats` and `moment` methods
        yield check_moment, distfn, arg, m, v, distname
        yield check_mean_expect, distfn, arg, m, distname
        yield check_var_expect, distfn, arg, m, v, distname
        yield check_skew_expect, distfn, arg, m, v, s, distname
        yield check_kurt_expect, distfn, arg, m, v, k, distname

        # frozen distr moments: E[X] = m and E[X**2] = v + m**2
        yield check_moment_frozen, distfn, arg, m, 1
        yield check_moment_frozen, distfn, arg, v + m * m, 2
|
||
11 years ago
|
|
||
|
|
||
|
def check_cdf_ppf(distfn, arg, supp, msg):
    """Check that ``ppf`` inverts the step-function ``cdf`` on ``supp``.

    The cdf is a step function, and ``ppf(q) = min{k : cdf(k) >= q}`` over
    integers ``k``. Hence, for every support point ``k``:

    * ``ppf(cdf(k)) == k``
    * ``ppf(cdf(k) - 1e-8) == k``
    * ``ppf(cdf(k) + 1e-8) == k + distfn.inc`` (only for ``k < distfn.b``)

    Parameters
    ----------
    distfn : distribution object
        Must expose ``cdf``, ``ppf`` and the attributes ``b`` and ``inc``.
    arg : sequence
        Shape parameters passed to ``cdf`` and ``ppf``.
    supp : ndarray
        Support points to check.
    msg : str
        Prefix for the assertion failure messages.
    """
    cdf_supp = distfn.cdf(supp, *arg)
    npt.assert_array_equal(distfn.ppf(cdf_supp, *arg),
                           supp, msg + '-roundtrip')
    # -1e-8 could cause an error if pmf(k) < 1e-8
    npt.assert_array_equal(distfn.ppf(cdf_supp - 1e-8, *arg),
                           supp, msg + '-roundtrip-below')
    # +1e-8 likewise assumes pmf(k + inc) >= 1e-8; the upper bound b has
    # no successor, so it is excluded
    supp1 = supp[supp < distfn.b]
    npt.assert_array_equal(distfn.ppf(distfn.cdf(supp1, *arg) + 1e-8, *arg),
                           supp1 + distfn.inc, msg + '-ppf-cdf-next')
|
||
|
|
||
|
|
||
|
def check_pmf_cdf(distfn, arg, distname):
    """Check that cumulative sums of ``pmf`` match increments of ``cdf``.

    Ten consecutive integers are used, starting one below
    ``ppf(0.01)``. Both sequences are compared relative to their first
    element, so ``cdf(k) - cdf(k0)`` must equal ``sum(pmf(k0+1 .. k))``.

    Parameters
    ----------
    distfn : distribution object
        Must expose ``ppf``, ``cdf`` and ``pmf``.
    arg : sequence
        Shape parameters of the distribution.
    distname : str
        Distribution name. It selects a looser tolerance for ``'skellam'``
        and labels failures.
    """
    # builtin int: the ``np.int`` alias is removed in NumPy >= 1.24
    startind = int(distfn.ppf(0.01, *arg) - 1)
    index = list(range(startind, startind + 10))
    cdfs = distfn.cdf(index, *arg)
    pmfs_cum = distfn.pmf(index, *arg).cumsum()

    atol, rtol = 1e-10, 1e-10
    if distname == 'skellam':    # ncx2 accuracy
        atol, rtol = 1e-5, 1e-5
    npt.assert_allclose(cdfs - cdfs[0], pmfs_cum - pmfs_cum[0],
                        atol=atol, rtol=rtol, err_msg=distname)
|
||
|
|
||
|
|
||
|
def check_moment_frozen(distfn, arg, m, k):
    """Assert that the frozen distribution ``distfn(*arg)`` has moment ``m``.

    The k-th moment reported by the frozen object's ``moment(k)`` is
    compared with ``m`` using ``rtol = atol = 1e-10``.
    """
    frozen = distfn(*arg)
    observed = frozen.moment(k)
    npt.assert_allclose(observed, m, rtol=1e-10, atol=1e-10)
|
||
|
|
||
|
|
||
|
def check_oth(distfn, arg, supp, msg):
    """Cross-check ``sf``/``isf`` against ``cdf``/``ppf`` and the median.

    Checks performed:

    * ``sf(k) == 1 - cdf(k)`` on ``supp``;
    * ``isf(q) == ppf(1 - q)`` on 20 points in ``[0.01, 0.99]``;
    * with ``med = isf(0.5)``: ``sf(med - 1) > 0.5`` and
      ``cdf(med + 1) > 0.5``.

    Parameters
    ----------
    distfn : distribution object
        Must expose ``sf``, ``cdf``, ``isf`` and ``ppf``.
    arg : sequence
        Shape parameters of the distribution.
    supp : array_like
        Points at which ``sf`` and ``cdf`` are compared.
    msg : str
        Label included in every assertion failure message.
    """
    npt.assert_allclose(distfn.sf(supp, *arg), 1. - distfn.cdf(supp, *arg),
                        atol=1e-10, rtol=1e-10, err_msg=msg + ' sf')

    q = np.linspace(0.01, 0.99, 20)
    npt.assert_allclose(distfn.isf(q, *arg), distfn.ppf(1. - q, *arg),
                        atol=1e-10, rtol=1e-10, err_msg=msg + ' isf')

    median_sf = distfn.isf(0.5, *arg)
    npt.assert_(distfn.sf(median_sf - 1, *arg) > 0.5,
                msg + ' sf below isf(0.5)')
    npt.assert_(distfn.cdf(median_sf + 1, *arg) > 0.5,
                msg + ' cdf above isf(0.5)')
|
||
|
|
||
|
|
||
|
def check_discrete_chisquare(distfn, arg, rvs, alpha, msg):
    """Perform chisquare test for random sample of a discrete distribution

    The support, truncated to ``[-1000, 1000]``, is partitioned into
    consecutive intervals. Each interval holds probability mass of at
    least ``1/nsupp`` (``nsupp = 20``); the scan stops once the cdf exceeds
    ``1 - 1/nsupp``, and any remaining upper tail up to ``distfn.b`` forms
    a final interval. The sample frequencies in these intervals are then
    compared with the expected counts using ``stats.chisquare``.

    Parameters
    ----------
    distfn : distribution object
        Discrete distribution exposing ``cdf`` and the support bounds
        ``a`` and ``b``.
    arg : sequence
        Parameters of the distribution.
    rvs : array_like
        Random sample to test.
    alpha : float
        Significance level, threshold for the p-value.
    msg : str
        Label used in the failure message.

    Raises
    ------
    AssertionError
        If the chi-square p-value is not greater than ``alpha``.
    """
    n = len(rvs)
    nsupp = 20
    wsupp = 1.0 / nsupp

    # construct intervals with minimum mass 1/nsupp
    # intervals are left-half-open as in a cdf difference
    lo = max(distfn.a, -1000)
    hi = min(distfn.b, 1000)
    last = 0
    distsupp = [lo]
    distmass = []
    for ii in range(lo, hi + 1):
        current = distfn.cdf(ii, *arg)
        if current - last >= wsupp - 1e-14:
            distsupp.append(ii)
            distmass.append(current - last)
            last = current
            if current > (1 - wsupp):
                break
    # lump whatever mass remains above the last edge into one interval
    if distsupp[-1] < distfn.b:
        distsupp.append(distfn.b)
        distmass.append(1 - last)
    distsupp = np.array(distsupp)
    distmass = np.array(distmass)

    # convert intervals to right-half-open as required by histogram
    histsupp = distsupp + 1e-8
    histsupp[0] = distfn.a

    # find sample frequencies and perform chisquare test
    freq, _ = np.histogram(rvs, histsupp)
    chis, pval = stats.chisquare(np.array(freq), n * distmass)

    npt.assert_(pval > alpha, 'chisquare - test for %s'
                ' at arg = %s with pval = %s' % (msg, str(arg), str(pval)))
|
||
11 years ago
|
|
||
|
|
||
|
def check_scale_docstring(distfn):
    """Assert that the distribution's docstring does not mention ``scale``.

    A missing docstring (``None``) is accepted without checking.
    """
    doc = distfn.__doc__
    # Docstrings can be stripped if interpreter is run with -OO
    if doc is None:
        return
    npt.assert_('scale' not in doc)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    # NOTE(review): ``npt.run_module_suite`` is presumably nose-based and is
    # deprecated in recent NumPy -- confirm it is still available.
    npt.run_module_suite()
|