Fixed some bugs in kdetools.py + added more tests in test/test_kdetools.py

Updated meshgrid
master
Per.Andreas.Brodtkorb 13 years ago
parent df9e95f0b5
commit bac74a93b5

@ -1115,7 +1115,7 @@ class Kernel(object):
covA = scipy.cov(A)
return a * linalg.sqrtm(covA) * n * (-1. / (d + 4))
return a * linalg.sqrtm(covA).real * n ** (-1. / (d + 4))
def hste(self, data, h0=None, inc=128, maxit=100, releps=0.01, abseps=0.0):
'''HSTE 2-Stage Solve the Equation estimate of smoothing parameter.
@ -1335,7 +1335,7 @@ class Kernel(object):
ix = np.arange(1, inc - 1)
z = ((f[ix + 1] - 2 * f[ix] + f[ix - 1]) / delta ** 2) ** 2
psi4 = delta * z.sum()
h1 = (STEconstant / psi4) ** (1 / 5);
h1 = (STEconstant / psi4) ** (1. / 5);
if count >= maxit:
warnings.warn('The obtained value did not converge.')
@ -1538,9 +1538,9 @@ class Kernel(object):
h = np.zeros(d)
for dim in range(d):
s = sigmaA[dim]
datan = A[dim] / s
ax = ax1[dim] / s
bx = bx1[dim] / s
datan = A[dim] #/ s
ax = ax1[dim] #/ s
bx = bx1[dim] #/ s
xa = np.linspace(ax, bx, inc)
xn = np.linspace(0, bx - ax, inc)
@ -1551,19 +1551,20 @@ class Kernel(object):
rd2 = L + 2
# Eq. 3.7 in Wand and Jones (1995)
PSI_r = (-1) ** (rd2) * np.prod(np.r_[rd2 + 1:r]) / (sqrt(pi) * (2 * s) ** (r + 1));
PSI_r = (-1) ** (rd2) * np.prod(np.r_[rd2 + 1:r+1]) / (sqrt(pi) * (2 * s) ** (r + 1));
#PSI_r = (-1) ** (rd2) * np.prod(np.r_[rd2 + 1:r]) / (sqrt(pi) * (2.0) ** (r + 1));
PSI = PSI_r
if L > 0:
# High order derivatives of the Gaussian kernel
Kd = kernel2.deriv4_6_8_10(0, numout=L)
# L-stage iterations to estimate PSI_4
for ix in range(L - 1, -1, -1):
gi = (-2 * Kd[ix] / (mu2 * PSI * n)) ** (1. / (2 * ix + 5))
for ix in range(L, 0, -1):
gi = (-2 * Kd[ix-1] / (mu2 * PSI * n)) ** (1. / (2 * ix + 5))
# Obtain the kernel weights.
KW0 = kernel2.deriv4_6_8_10(xn / gi, numout=ix + 1)
if ix > 0:
KW0 = kernel2.deriv4_6_8_10(xn / gi, numout=ix)
if ix > 1:
KW0 = KW0[-1]
kw = np.r_[KW0, 0, KW0[inc - 1:0:-1]] # Apply 'fftshift' to kw.
@ -1573,7 +1574,7 @@ class Kernel(object):
PSI = np.sum(c * z[:inc]) / (n ** 2 * gi ** (2 * ix + 3))
#end
#end
h[dim] = s * (STEconstant / PSI) ** (1. / 5)
h[dim] = (STEconstant / PSI) ** (1. / 5)
return h
@ -2161,13 +2162,13 @@ def gridcount(data, X):
Parameters
----------
data = column vectors with D-dimensional data, size D x Nd
X = row vectors defining discretization, size D x N
data = column vectors with D-dimensional data, shape D x Nd
X = row vectors defining discretization, shape D x N
Must include the range of the data.
Returns
-------
c = gridcount, size N x N x ... x N
c = gridcount, shape N x N x ... x N
GRIDCOUNT obtains the grid counts using linear binning.
There are 2 strategies: simple- or linear- binning.
@ -2248,7 +2249,7 @@ def gridcount(data, X):
accum(c_[b1 + 1, b2] , abs(np.prod(stk([X[0, b1], X[1, b2 + 1]]) - dat, axis=0)), size=[inc, inc]) +
accum(c_[b1 , b2 + 1], abs(np.prod(stk([X[0, b1 + 1], X[1, b2]]) - dat, axis=0)), size=[inc, inc]) +
accum(c_[b1 + 1, b2 + 1], abs(np.prod(stk([X[0, b1], X[1, b2]]) - dat, axis=0)), size=[inc, inc])) / w
c = c.T # make sure c is stored in the same way as meshgrid
else: # % d>2
Nc = csiz.prod()
@ -2273,14 +2274,10 @@ def gridcount(data, X):
c = np.reshape(c / w, csiz, order='C')
# TODO: check that the flipping of axis is correct
T = range(d); T[-2], T[-1] = T[-1], T[-2]
c = c.transpose(*T)
if d == 2: # make sure c is stored in the same way as meshgrid
c = c.T
elif d == 3:
c = c.transpose(1, 0, 2)
T = range(d)
T[-2], T[-1] = T[-1], T[-2]
c = c.transpose(*T) # make sure c is stored in the same way as meshgrid
return c
@ -2346,11 +2343,160 @@ def kde_demo2():
pylab.plot(x, st.rayleigh.pdf(x, scale=1), ':')
pylab.figure(0)
def test_gridcount():
import numpy as np
#import wafo.kdetools as wk
from matplotlib import pyplot as plb
data = data_rayleigh()
N = len(data)
x = np.linspace(0,max(data)+1,50)
dx = x[1]-x[0]
c = gridcount(data,x)
ctr = np.array([ 0, 4, 10, 14, 15, 23, 16, 18, 21, 19, 37, 32, 24, 29, 29,
24, 29, 26, 14, 13, 23, 9, 13, 11, 7, 12, 5, 2, 2, 6,
2, 2, 5, 0, 0, 0, 1, 2, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0])
print(np.abs(c-ctr)<1e-13)
pdf = c/dx/N
h = plb.plot(x,c,'.') # 1D histogram
plb.show()
pass
data1 = data.reshape((2,-1))
c2 = gridcount(data1, np.vstack((x,x)))
c2t = np.array([ 0, 0.635018844262034, 1.170430267508894, 0.480210926714613,
1.256122839305450, 2.050244222017545, 1.250782602003382,
1.253065702416950, 1.295571917793612, 1.978725535031301,
0.829707237562507, 1.842636337195244, 2.767829900577593,
1.449607074995753, 2.759640664415913, 1.634036650764552,
1.990159690320205, 1.201953891214720, 1.277182991907633,
1.293002868977407, 1.157268941919115, 1.275443411253732,
1.132629693243210, 1.418284942741350, 0.770571572744340,
0.475119743103730, 0.982244208825375, 0.681834272971076,
0.359044910769366, 0.582570635672141, 0.658412627992049,
0.784814887272479, 0.448937228228751, 0.314220262358783,
0, 0, 0, 0.404906838651598, 0.113164094088927,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0])
print(np.abs((c2t-c2.max(axis=-1)))<1e-13)
data4 = data.reshape((4,-1))
x = np.linspace(0,max(data)+1,11)
c4 = gridcount(data4, np.vstack((x,x,x,x)))
print(np.abs((c2t-c2.max(axis=-1)))<1e-13)
def data_rayleigh():
return np.array([1.412423619721313, 0.936610012402041, 3.408880790209544, 0.712493911648517, 1.453856820100018,
1.362971623321745, 0.989738148038997, 0.553839936552347, 0.225638048436888, 1.045606709473107,
0.637908826214993, 1.608606426103143, 0.961884939327567, 1.919572795331000, 1.627957520304931,
1.301044712134641, 0.623895791202139, 2.512180741739924, 0.785268132885580, 2.273629106639021,
0.711768125619732, 0.967169434614618, 0.427942932230904, 1.429667825110794, 0.631194936581454,
0.303636149404372, 1.602725333691658, 0.923957338868325, 1.470119382037774, 0.984169729516054,
1.405066725863521, 0.209225286647146, 2.197407087587261, 1.795680986321718, 1.655186235334962,
1.831529484073858, 0.983242434909240, 1.385965094654130, 1.309069260384021, 1.228928476737294,
0.802097056076411, 0.756115979892675, 1.096194354290486, 0.718886376439466, 1.806619521908829,
2.924438974501607, 0.246782936313644, 1.238666429277650, 0.426858243844038, 1.799972319758650,
3.007697177898959, 0.372270006672035, 2.367882325903836, 0.191545163286195, 1.517565471255659,
1.750004351582044, 1.236013671840509, 2.081323476045300, 2.141346897323470, 1.402378494050162,
0.544698152936965, 0.700923468199988, 0.634137874072855, 0.292299453493133, 0.475611960045215,
1.384390337219331, 2.369715926664043, 0.935586970891954, 1.028299144484800, 2.883486469293792
, 2.412885676436705, 1.502666625449783, 2.982736936434333, 1.706454468843206, 0.906120073100622
, 1.473661960328491, 0.748241351472675, 0.836991325956595, 1.509961488710520, 1.225113736935942
, 1.029890543888216, 1.358608202305835, 1.666359355892623, 1.323592437751299, 1.266885170390769
, 1.323660367761004, 1.197556616382116, 0.415219867081348, 1.594635770596585, 3.335047448446035
, 0.935717067162817, 2.664366406420023, 0.922317019697774, 2.086307246777435, 1.101280854500658
, 1.032916883571698, 0.700796651725546, 0.518227310036530, 0.859641628285530, 1.609352902696174
, 1.747723418451391, 1.538490395884064, 0.140361038832643, 1.925029474333574, 2.260668891490430
, 1.716877040260210, 0.295284687152802, 0.974796888317386, 1.561117460932286, 1.617115585994090
, 0.712684884618426, 0.791728102952554, 1.495252766892452, 1.139399282670031, 1.398150348314015
, 0.734533909397005, 0.624418865181972, 1.881415056762913, 1.706681395455110, 2.334683483141081
, 0.477838065222462, 0.634304509316731, 0.456849600683082, 1.160070279761997, 0.655340613711381
, 2.127121229851198, 0.456835914801069, 0.300568039387414, 1.276598603254562, 1.720804090031422
, 0.864730384700170, 0.628029981916123, 0.909872945858993, 0.686886746420088, 1.194705989905012
, 2.176393257858438, 1.408082540391850, 0.462744617618753, 0.995689247143699, 0.335155890849689
, 1.179590017201302, 1.063149176870603, 3.468688654992744, 1.827129780001552, 1.153130138387220
, 2.120636338813882, 0.544011313217379, 0.994288065215423, 2.290060076679768, 2.233778068583924
, 1.581312813059112, 1.387961284638806, 0.917070930336856, 2.344909067035151, 0.516281292935132
, 1.619570115238485, 1.442087344999343, 1.892443909224431, 1.007935276931834, 1.664682222719219
, 1.899024552783311, 0.882368714153905, 1.267711468232034, 2.781870230167854, 1.262515173989300
, 0.895667370955997, 1.390103843633942, 0.945814188732813, 1.680879252209405, 1.033698343725955
, 1.164434112863078, 1.540520689869044, 2.684068016815589, 0.891215308218909, 1.907325227589101
, 1.639214228101874, 2.483108383603044, 0.254728352176505, 0.939581631904974, 1.474208721908681
, 0.813131900087889, 0.723688300231953, 1.575927348326343, 1.399779277625481, 1.336475769769517
, 1.469760951955162, 0.312162051579979, 0.926191271942077, 1.095698311512132, 0.742466620037192
, 1.584565588783017, 1.969369796694313, 0.813142402654688, 1.620637451940408, 0.544472183788396
, 1.903841273483371, 0.546256895921489, 1.332096299659611, 0.954938400592347, 1.813185033558344
, 1.183839081745172, 1.159783992966029, 2.047421367071099, 0.933411156868096, 1.092543634708061
, 2.573430838154017, 0.294001371116989, 2.687145854798348, 0.647676314841560, 1.483222246093897
, 1.328873011650546, 1.499517291077073, 0.946451616282504, 1.391629977859238, 1.825818674800223
, 0.197207089634922, 0.418570979518484, 2.713292260256486, 1.451678603677107, 0.725222188153537
, 1.016524657331659, 0.510160866644535, 0.790663553688482, 0.772267750711634, 0.897737257071539
, 0.574718435129065, 0.924902659911130, 0.509352052679121, 2.076287755824404, 0.445024255426400
, 2.306443399859831, 1.009151694589026, 0.311646355326560, 0.915552448311802, 1.631979165302650
, 1.779435892929737, 1.254791667465325, 1.522546690241251, 2.117005924369452, 0.335708348510442
, 0.850786945794020, 0.307546485903476, 0.659553530770440, 1.595968673282009, 1.599529339207843
, 2.050409047591333, 1.321597656988126, 0.382901575350795, 2.263023675024229, 1.795160219589414
, 0.820728808594631, 1.252616635345433, 2.893059873111469, 1.585968547208113, 1.911105489168721
, 1.065697540675240, 1.127880912464618, 1.282656038601722, 0.791791712034066, 1.662754292624110
, 1.184021211521453, 1.442739185251488, 0.857673288506446, 0.546518081971571, 1.136176847824479
, 0.948827835556975, 1.761649333500106, 1.740961388239338, 1.486044626143792, 0.535345914616625
, 0.208765940502775, 1.281107790531077, 0.845985407399993, 2.367961441281100, 2.813630157287030
, 0.821877833204895, 1.796411857645166, 2.128114489536385, 1.349167308872121, 2.075721258630550
, 2.399008601572707, 1.262250789152573, 1.614544130176768, 1.311344244094387, 0.228900207318000
, 1.087703540854728, 1.441743192607425, 1.213654375953261, 0.965104247192400, 2.352343682973261
, 1.881070184767099, 1.944757925743782, 0.965470015113788, 1.341190290874416, 2.029803572337272
, 2.328337398097465, 1.485947310986503, 0.680661741126981, 1.456629522069083, 0.386727549117631
, 1.021861509017076, 1.482839980680464, 2.329786461679046, 1.825236378759161, 1.151270272972182
, 1.681465022236889, 1.038893153052472, 2.671305569135296, 2.973463508311512, 1.998091967015353
, 0.992439538152367, 1.101359057223470, 0.752694797719731, 1.751820513743222, 2.070842495255286
, 2.213621940109904, 1.278350678290866, 1.351639733749908, 0.567799782374724, 2.144632385787214
, 1.094123263719430, 0.678615107641789, 2.144341891738539, 1.695846624058156, 2.069396249839028
, 0.819027610733285, 1.495651321040951, 1.477482666605742, 0.511932330475827, 1.022837224533765
, 0.802470556959117, 1.588170058614226, 0.816352471969601, 2.128510415901388, 1.871914791729839
, 0.994323676062132, 1.173849936976207, 1.540652455108271, 1.896308447022061, 1.371611808573705
, 1.307706279079749, 0.888355489837264, 1.104161992788381, 1.581802123863791, 2.077336259709684
, 1.597514520759674, 0.193846187739953, 1.498827901810269, 1.074392126178632, 1.073250683084153
, 0.498942436443271, 1.836126539886937, 0.886372885469560, 0.751884958648598, 0.916116650002177
, 0.970681891155015, 1.257679318479529, 1.284798886225563, 1.003879276488743, 1.007685729946785
, 1.203631029712442, 1.463948632472297, 2.455282398854625, 1.600867640016765, 1.010899145846306
, 1.888399192628552, 0.537142702822369, 0.353191429514348, 0.419544177537439, 0.598339442960937
, 0.885310772269136, 0.847519694333472, 0.153295465546788, 1.246051759006313, 0.447732587957780
, 0.562898114036050, 1.412332385111654, 1.980540530235424, 2.704891701651084, 1.300708887507808
, 3.394236570275002, 1.269967710402906, 1.203787442037781, 0.896098313870595, 1.060303799139334
, 1.163522680114773, 0.383891805234107, 2.091377862339729, 0.365559694796422, 1.070541000579430
, 1.872070722040661, 1.001756457029345, 1.378809939003001, 1.847850278543804, 2.085003935284227
, 2.313122510412947, 0.650676881494584, 0.773551613369587, 2.136102299351586, 1.341515248421647
, 1.183940022628347, 1.377562113620296, 1.850185830133746, 1.232112165168803, 0.671923793165544
, 1.099946548218587, 1.056844894152012, 2.601133375396755, 1.391207328945862, 1.541896787253508
, 1.595966007631807, 0.923057590473980, 1.206415179152940, 1.275536301443908, 0.583420447186398
, 1.285040337652167, 1.540648406694559, 1.054438062631050, 1.902387509769504, 1.621409166908371
, 0.944812793164613, 1.100477476680040, 0.988327442233132, 1.728654388101105, 1.628053244977060
, 1.060760561571943, 1.538416018178277, 2.410108392389236, 1.751316245100324, 1.563790463015108
, 0.481219389518454, 0.994165631555275, 1.337016990968870, 1.109088579526755, 0.321407029232422
, 0.720641073906049, 1.895735773634961, 0.177585824024661, 1.996485240483058, 0.403199585960614
, 1.487121772300537, 1.177769008306152, 0.701273995641151, 1.302101876486422, 0.510537251157601
, 1.491444215081535, 1.352963516576160, 0.339422616073620, 0.340565840833962, 0.575265488888648
, 0.199078454324122, 1.068868838035460, 1.889502831203267, 1.386174255623796, 1.211807597487022
, 1.997063801362690, 0.453444401722453, 2.184735356478338, 0.478137766710008, 1.206426055203951
, 0.555876664495711, 1.280274233919441, 0.095813804344955, 1.706079097312628, 1.943477111398666
, 2.230140630510882, 2.946309044620703, 1.186142019401047, 0.795814141941795, 0.460857387230226
, 1.190772316835832, 1.327362504940310, 1.696595922853605, 0.416190042989537, 1.472083830192951
, 1.206395605479538, 0.612524363189761, 2.362058183247366, 1.336246455616561, 1.077916969428414
, 2.385755851351826, 1.460727990062456, 1.096704997935700, 1.913474394478998, 1.233385699260248,
1.270577147048640, 1.509727846778659, 0.956645085964223, 0.739599713571419, 1.315249583679571,
2.008261585625269, 1.021943728886631, 0.488828195617451, 1.083244894832682, 0.844912313732214,
1.013054512108690, 1.893114294699785, 1.016751451332806, 0.994570044372612, 0.945503828258995])
def test_docstrings():
import doctest
doctest.testmod()
if __name__ == '__main__':
test_docstrings()
#test_docstrings()
test_gridcount()

@ -1,5 +1,5 @@
import numpy as np
def meshgrid(*xi,**kwargs):
def meshgrid(*xi, **kwargs):
"""
Return coordinate matrices from one or more coordinate vectors.
@ -17,7 +17,10 @@ def meshgrid(*xi,**kwargs):
If True a sparse grid is returned in order to conserve memory.
copy : True (default) or False (optional)
If False a view into the original arrays are returned in order to
conserve memory
conserve memory. Please note that sparse=False, copy=False will likely
return non-contiguous arrays. Furthermore, more than one element of a
broadcasted array may refer to a single memory location. If you
need to write to the arrays, make copies first.
Returns
-------
@ -27,6 +30,23 @@ def meshgrid(*xi,**kwargs):
or ``(N2, N1, N3,...Nn)`` shaped arrays if indexing='xy'
with the elements of `xi` repeated to fill the matrix along
the first dimension for `x1`, the second for `x2` and so on.
Notes
-----
This function supports both indexing conventions through the indexing keyword
argument. Giving the string 'ij' returns a meshgrid with matrix indexing,
while 'xy' returns a meshgrid with Cartesian indexing. The difference is
illustrated by the following code snippet:
xv, yv = meshgrid(x, y, sparse=False, indexing='ij')
for i in range(nx):
for j in range(ny):
# treat xv[i,j], yv[i,j]
xv, yv = meshgrid(x, y, sparse=False, indexing='xy')
for i in range(nx):
for j in range(ny):
# treat xv[j,i], yv[j,i]
See Also
--------
@ -37,96 +57,77 @@ def meshgrid(*xi,**kwargs):
Examples
--------
>>> x = np.linspace(0,1,3) # coordinates along x axis
>>> y = np.linspace(0,1,2) # coordinates along y axis
>>> xv, yv = meshgrid(x,y) # extend x and y for a 2D xy grid
>>> nx, ny = (3, 2)
>>> x = np.linspace(0, 1, nx)
>>> y = np.linspace(0, 1, ny)
>>> xv, yv = meshgrid(x, y)
>>> xv
array([[ 0. , 0.5, 1. ],
[ 0. , 0.5, 1. ]])
>>> yv
array([[ 0., 0., 0.],
[ 1., 1., 1.]])
>>> xv, yv = meshgrid(x,y, sparse=True) # make sparse output arrays
>>> xv, yv = meshgrid(x, y, sparse=True) # make sparse output arrays
>>> xv
array([[ 0. , 0.5, 1. ]])
>>> yv
array([[ 0.],
[ 1.]])
>>> meshgrid(x,y,sparse=True,indexing='ij') # change to matrix indexing
[array([[ 0. ],
[ 0.5],
[ 1. ]]), array([[ 0., 1.]])]
>>> meshgrid(x,y,indexing='ij')
[array([[ 0. , 0. ],
[ 0.5, 0.5],
[ 1. , 1. ]]),
array([[ 0., 1.],
[ 0., 1.],
[ 0., 1.]])]
>>> meshgrid(0,1,5) # just a 3D point
[array([[[0]]]), array([[[1]]]), array([[[5]]])]
>>> map(np.squeeze,meshgrid(0,1,5)) # just a 3D point
[array(0), array(1), array(5)]
>>> meshgrid(3)
array([3])
>>> meshgrid(y) # 1D grid; y is just returned
array([ 0., 1.])
`meshgrid` is very useful to evaluate functions on a grid.
>>> x = np.arange(-5, 5, 0.1)
>>> y = np.arange(-5, 5, 0.1)
>>> xx, yy = meshgrid(x, y, sparse=True)
>>> z = np.sin(xx**2+yy**2)/(xx**2+yy**2)
>>> import matplotlib.pyplot as plt
>>> h = plt.contourf(x,y,z)
"""
copy = kwargs.get('copy',True)
copy_ = kwargs.get('copy', True)
args = np.atleast_1d(*xi)
if not isinstance(args, list):
if args.size>0:
return args.copy() if copy else args
else:
raise TypeError('meshgrid() take 1 or more arguments (0 given)')
sparse = kwargs.get('sparse',False)
indexing = kwargs.get('indexing','xy') # 'ij'
ndim = len(args)
if not isinstance(args, list) or ndim<2:
raise TypeError('meshgrid() takes 2 or more arguments (%d given)' % int(ndim>0))
sparse = kwargs.get('sparse', False)
indexing = kwargs.get('indexing', 'xy')
ndim = len(args)
s0 = (1,)*ndim
output = [x.reshape(s0[:i]+(-1,)+s0[i+1::]) for i, x in enumerate(args)]
output = [x.reshape(s0[:i] + (-1,) + s0[i + 1::]) for i, x in enumerate(args)]
shape = [x.size for x in output]
if indexing == 'xy':
# switch first and second axis
output[0].shape = (1,-1) + (1,)*(ndim-2)
output[1].shape = (-1, 1) + (1,)*(ndim-2)
shape[0],shape[1] = shape[1],shape[0]
output[0].shape = (1, -1) + (1,)*(ndim - 2)
output[1].shape = (-1, 1) + (1,)*(ndim - 2)
shape[0], shape[1] = shape[1], shape[0]
if sparse:
if copy:
if copy_:
return [x.copy() for x in output]
else:
return output
else:
# Return the full N-D matrix (not only the 1-D vector)
if copy:
mult_fact = np.ones(shape,dtype=int)
return [x*mult_fact for x in output]
if copy_:
mult_fact = np.ones(shape, dtype=int)
return [x * mult_fact for x in output]
else:
return np.broadcast_arrays(*output)
def ndgrid(*args,**kwargs):
def ndgrid(*args, **kwargs):
"""
Same as calling meshgrid with indexing='ij' (see meshgrid for
documentation).
"""
kwargs['indexing'] = 'ij'
return meshgrid(*args,**kwargs)
return meshgrid(*args, **kwargs)
if __name__=='__main__':
if __name__ == '__main__':
import doctest
doctest.testmod()

@ -1139,7 +1139,6 @@ class RegLogit(object):
options = options struct defining the calculation
.alpha : confidence coefficient (default 0.05)
.size : size if binomial family (default 1).
'''
[mx, nx] = self.X.shape

@ -7,7 +7,7 @@ Created on 20. nov. 2010
import numpy as np
from numpy import array
import wafo.kdetools as wk
import pylab as plb
#import pylab as plb
def test0_KDE1D():
'''
@ -146,6 +146,68 @@ def test2a_KDE1D():
h1 = plb.plot(x, f) # 1D probability density plot
'''
def test_KDE2D():
'''
N = 20
data = np.random.rayleigh(1, size=(2, N))
>>> data = array([[ 0.38103275, 0.35083136, 0.90024207, 1.88230239, 0.96815399,
... 0.57392873, 1.63367908, 1.20944125, 2.03887811, 0.81789145,
... 0.69302049, 1.40856592, 0.92156032, 2.14791432, 2.04373821,
... 0.69800708, 0.58428735, 1.59128776, 2.05771405, 0.87021964],
... [ 1.44080694, 0.39973751, 1.331243 , 2.48895822, 1.18894158,
... 1.40526085, 1.01967897, 0.81196474, 1.37978932, 2.03334689,
... 0.870329 , 1.25106862, 0.5346619 , 0.47541236, 1.51930093,
... 0.58861519, 1.19780448, 0.81548296, 1.56859488, 1.60653533]])
>>> x = np.linspace(0, max(data.ravel()) + 1, 3)
>>> kde = wk.KDE(data, hs=0.5, alpha=0.5)
>>> kde0 = wk.KDE(data, hs=0.5, alpha=0.0, inc=16)
>>> kde0.eval_grid(x, x)
>>> kde0.eval_grid_fast(x, x)
'''
def test_smooth_params():
'''
>>> data = np.array([[ 0.932896 , 0.89522635, 0.80636346, 1.32283371, 0.27125435,
... 1.91666304, 2.30736635, 1.13662384, 1.73071287, 1.06061127,
... 0.99598512, 2.16396591, 1.23458213, 1.12406686, 1.16930431,
... 0.73700592, 1.21135139, 0.46671506, 1.3530304 , 0.91419104],
... [ 0.62759088, 0.23988169, 2.04909823, 0.93766571, 1.19343762,
... 1.94954931, 0.84687514, 0.49284897, 1.05066204, 1.89088505,
... 0.840738 , 1.02901457, 1.0758625 , 1.76357967, 0.45792897,
... 1.54488066, 0.17644313, 1.6798871 , 0.72583514, 2.22087245],
... [ 1.69496432, 0.81791905, 0.82534709, 0.71642389, 0.89294732,
... 1.66888649, 0.69036947, 0.99961448, 0.30657267, 0.98798713,
... 0.83298728, 1.83334948, 1.90144186, 1.25781913, 0.07122458,
... 2.42340852, 2.41342037, 0.87233305, 1.17537114, 1.69505988]])
>>> gauss = wk.Kernel('gaussian')
>>> gauss.hns(data)
array([ 0.18154437, 0.36207987, 0.37396219])
>>> gauss.hos(data)
array([ 0.195209 , 0.3893332 , 0.40210988])
>>> gauss.hmns(data)
array([[ 3.25196193e-01, -2.68892467e-02, 3.18932448e-04],
[ -2.68892467e-02, 3.91283306e-01, 2.38654678e-02],
[ 3.18932448e-04, 2.38654678e-02, 4.05123874e-01]])
>>> gauss.hscv(data)
array([ 0.16858959, 0.33034332, 0.3046287 ])
>>> gauss.hstt(data)
array([ 0.18196282, 0.51090571, 0.1111913 ])
>>> gauss.hste(data)
array([ 0.1683984 , 0.29693232, 0.17974833])
>>> gauss.hldpi(data)
array([ 0.17426948, 0.33672307, 0.31240374])
'''
def test_gridcount_1D():
'''
N = 20
@ -198,6 +260,41 @@ def test_gridcount_2D():
h1 = plb.plot(x, c / dx / N) # 1D probability density plot
t = np.trapz(c / dx / N, x)
print(t)
'''
def test_gridcount_3D():
'''
N = 20
data = np.random.rayleigh(1, size=(3, N))
>>> data = np.array([[ 0.932896 , 0.89522635, 0.80636346, 1.32283371, 0.27125435,
... 1.91666304, 2.30736635, 1.13662384, 1.73071287, 1.06061127,
... 0.99598512, 2.16396591, 1.23458213, 1.12406686, 1.16930431,
... 0.73700592, 1.21135139, 0.46671506, 1.3530304 , 0.91419104],
... [ 0.62759088, 0.23988169, 2.04909823, 0.93766571, 1.19343762,
... 1.94954931, 0.84687514, 0.49284897, 1.05066204, 1.89088505,
... 0.840738 , 1.02901457, 1.0758625 , 1.76357967, 0.45792897,
... 1.54488066, 0.17644313, 1.6798871 , 0.72583514, 2.22087245],
... [ 1.69496432, 0.81791905, 0.82534709, 0.71642389, 0.89294732,
... 1.66888649, 0.69036947, 0.99961448, 0.30657267, 0.98798713,
... 0.83298728, 1.83334948, 1.90144186, 1.25781913, 0.07122458,
... 2.42340852, 2.41342037, 0.87233305, 1.17537114, 1.69505988]])
>>> x = np.linspace(0, max(data.ravel()) + 1, 3)
>>> dx = x[1] - x[0]
>>> X = np.vstack((x, x, x))
>>> c = wk.gridcount(data, X)
>>> c
array([[[ 8.74229894e-01, 1.44969128e+00, 7.49265424e-02],
[ 1.94778915e+00, 2.28951650e+00, 8.53886762e-02],
[ 1.08429416e-01, 1.10905565e-01, 4.16196568e-04]],
<BLANKLINE>
[[ 1.27910940e+00, 2.58396370e+00, 2.18142488e-01],
[ 2.59536282e+00, 4.49653348e+00, 3.73415131e-01],
[ 1.69571495e-01, 3.18733817e-01, 1.62218824e-02]],
<BLANKLINE>
[[ 1.42033973e-01, 2.45459949e-01, 0.00000000e+00],
[ 3.28213680e-01, 2.73167915e-01, 0.00000000e+00],
[ 7.48896775e-03, 1.12880816e-02, 0.00000000e+00]]])
'''
def test_gridcount_4D():
'''
@ -261,6 +358,10 @@ def test_gridcount_4D():
t = np.trapz(x, c / dx / N)
print(t)
'''
if __name__ == '__main__':
def test_docstrings():
import doctest
doctest.testmod()
doctest.testmod()
if __name__ == '__main__':
test_docstrings()
Loading…
Cancel
Save