Removed obsolete code and added more tests.

master
Per A Brodtkorb 8 years ago
parent d9e2349248
commit 7657b53b3d

@ -32,9 +32,9 @@ def kde_demo1():
data = np.random.normal(loc=0, scale=1.0, size=7)
kernel = Kernel('gauss')
hs = kernel.hns(data)
hVec = [hs / 2, hs, 2 * hs]
h_vec = [hs / 2, hs, 2 * hs]
for ix, h in enumerate(hVec):
for ix, h in enumerate(h_vec):
plt.figure(ix)
kde = KDE(data, hs=h, kernel=kernel)
f2 = kde(x, output='plot', title='h_s = {0:2.2f}'.format(float(h)),
@ -176,7 +176,8 @@ def kde_demo5(N=500):
def kreg_demo1(hs=None, fast=False, fun='hisj'):
""""""
"""Compare KRegression to KernelReg from statsmodels.nonparametric
"""
N = 100
# ei = np.random.normal(loc=0, scale=0.075, size=(N,))
ei = np.array([
@ -236,6 +237,9 @@ def kreg_demo1(hs=None, fast=False, fun='hisj'):
def _get_data(n=100, symmetric=False, loc1=1.1, scale1=0.6, scale2=1.0):
"""
Return test data for binomial regression demo.
"""
st = scipy.stats
dist = st.norm
@ -262,6 +266,9 @@ def _get_data(n=100, symmetric=False, loc1=1.1, scale1=0.6, scale2=1.0):
def check_bkregression():
"""
Check binomial regression
"""
plt.ion()
k = 0
for _i, n in enumerate([50, 100, 300, 600]):

@ -412,27 +412,6 @@ class TKDE(_KDE):
Check the KDE for spurious spikes''')
return pdf
def eval_grid_fast2(self, *args, **kwds):
"""Evaluate the estimated pdf on a grid.
Parameters
----------
arg_0,arg_1,... arg_d-1 : vectors
Alternatively, if no vectors is passed in then
arg_i = gauss2dat(linspace(dat2gauss(self.xmin[i]),
dat2gauss(self.xmax[i]), self.inc))
output : string optional
'value' if value output
'data' if object output
Returns
-------
values : array-like
The values evaluated at meshgrid(*args).
"""
return self.eval_grid_fun(self._eval_grid_fast, *args, **kwds)
def _interpolate(self, points, f, *args, **kwds):
ipoints = meshgrid(*args) # if self.d > 1 else args
for i in range(self.d):

@ -103,15 +103,12 @@ class TestKde(unittest.TestCase):
assert_allclose(f, [1.03982714, 0.45839018, 0.39514782, 0.32860602,
0.26433318, 0.20717946, 0.15907684, 0.1201074,
0.08941027, 0.06574882])
f = kde.eval_grid(x)
assert_allclose(f, [1.03982714, 0.45839018, 0.39514782, 0.32860602,
0.26433318, 0.20717946, 0.15907684, 0.1201074,
0.08941027, 0.06574882])
assert_allclose(np.trapz(f, x), 0.94787730659349068)
f = kde.eval_grid_fast(x)
assert_allclose(f, [1.0401892415290148, 0.45838973393693677,
0.39514689240671547, 0.32860531818532457,
0.2643330110605783, 0.20717975528556506,
0.15907696844388747, 0.12010770443337843,
0.08941129458260941, 0.06574899139165799])
f = kde.eval_grid_fast2(x)
assert_allclose(f, [1.0401892415290148, 0.45838973393693677,
0.39514689240671547, 0.32860531818532457,
0.2643330110605783, 0.20717975528556506,
@ -340,7 +337,7 @@ class TestRegression(unittest.TestCase):
bkreg = wk.BKRegression(x, y, a=0.05, b=0.05)
fbest = bkreg.prb_search_best(hsfun='hste', alpha=0.05, color='g')
print(fbest.data[::10].tolist())
# print(fbest.data[::10].tolist())
assert_allclose(fbest.data[::10],
[1.80899736e-15, 0, 6.48351162e-16, 6.61404311e-15,
@ -355,6 +352,27 @@ class TestRegression(unittest.TestCase):
2.68633448e-04, 1.68974054e-05, 5.73040143e-07,
1.11994760e-08, 1.36708818e-10, 1.09965904e-12,
5.43806309e-15, 0.0, 0, 0], atol=1e-10)
bkreg = wk.BKRegression(x, y, method='wilson')
fbest = bkreg.prb_search_best(hsfun='hste', alpha=0.05, color='g')
assert_allclose(fbest.data[::10],
[3.2321397702105376e-15, 4.745626420805898e-17,
6.406118940191104e-16, 5.648884668051452e-16,
3.499875381296387e-16, 1.0090442883241678e-13,
4.264723863193633e-11, 9.29288388831705e-09,
9.610074789043923e-07, 4.086642453634508e-05,
0.0008305202502773989, 0.00909121197102206,
0.05490768364395013, 0.1876637145781381,
0.4483015169104682, 0.8666709816557657,
0.9916656713022183, 0.9996648903706271,
0.999990921956741, 0.9999909219567404,
0.999664890370625, 0.9916656713022127,
0.8666709816557588, 0.4483015169104501,
0.18766371457812697, 0.054907683643947366,
0.009091211971022042, 0.0008305202502770367,
4.086642453593762e-05, 9.610074786590158e-07,
9.292883469982049e-09, 4.264660017463372e-11,
1.005284921271869e-13, -0.0, -0.0, -0.0, -0.0, -0.0],
atol=1e-10)
if __name__ == "__main__":
# import sys;sys.argv = ['', 'Test.testName']

Loading…
Cancel
Save