Fixed some bugs

master
Per.Andreas.Brodtkorb 13 years ago
parent bac74a93b5
commit 9dceca985e

@ -13,7 +13,7 @@ from __future__ import division
from itertools import product
from misc import tranproc #, trangood
from numpy import pi, sqrt, atleast_2d, exp, newaxis #@UnresolvedImport
from scipy import interpolate, linalg
from scipy import interpolate, linalg, sparse
from scipy.special import gamma
from wafo.misc import meshgrid
from wafo.wafodata import WafoData
@ -128,11 +128,11 @@ class _KDE(object):
if self.xmin is None:
self.xmin = amin - offset
else:
self.xmin = self.xmin * np.ones(self.d)
self.xmin = self.xmin * np.ones((self.d,1))
if self.xmax is None:
self.xmax = amax + offset
else:
self.xmax = self.xmax * np.ones(self.d)
self.xmax = self.xmax * np.ones((self.d,1))
def eval_grid_fast(self, *args, **kwds):
"""Evaluate the estimated pdf on a grid.
@ -187,14 +187,20 @@ class _KDE(object):
else:
titlestr = 'Kernel density estimate (%s)' % self.kernel.name
kwds2 = dict(title=titlestr)
kwds2['plot_kwds'] = dict(plotflag=1)
kwds2.update(**kwds)
if self.d == 1:
args = args[0]
elif self.d > 1:
wdata = WafoData(f, args, **kwds2)
if self.d > 1:
PL = np.r_[10:90:20, 95, 99, 99.9]
try:
ql = qlevels(f, p=PL)
kwds2.setdefault('levels', ql)
return WafoData(f, args, **kwds2)
wdata.clevels = ql
wdata.plevels = PL
except:
pass
return wdata
def _check_shape(self, points):
points = atleast_2d(points)
@ -329,17 +335,17 @@ class TKDE(_KDE):
tdataset = self._dat2gaus(self.dataset)
xmin = self.xmin
if xmin is not None:
xmin = self._dat2gaus(xmin)
xmin = self._dat2gaus(np.reshape(xmin,(-1,1)))
xmax = self.xmax
if xmax is not None:
xmax = self._dat2gaus(xmax)
xmax = self._dat2gaus(np.reshape(xmax,(-1,1)))
self.tkde = KDE(tdataset, self.hs, self.kernel, self.alpha, xmin, xmax,
self.inc)
def _check_xmin(self):
if self.L2 is not None:
amin = self.dataset.min(axis= -1)
L2 = np.atleast_1d(self.L2) * np.ones(self.d) # default no transformation
self.xmin = np.where(L2 != 1, np.maximum(self.xmin, amin / 100.0), self.xmin)
self.xmin = np.where(L2 != 1, np.maximum(self.xmin, amin / 100.0), self.xmin).reshape((-1,1))
def _dat2gaus(self, points):
if self.L2 is None:
@ -422,7 +428,7 @@ class TKDE(_KDE):
fi = pdf(*args)
self.args = args
#fi.shape = ipoints[0].shape
return fi
return fi*(fi>0)
return f
def _eval_grid(self, *args):
if self.L2 is None:
@ -602,7 +608,6 @@ class KDE(_KDE):
def _eval_grid_fast(self, *args):
# TODO: This does not work correctly yet! Check it.
X = np.vstack(args)
d, inc = X.shape
dx = X[:, 1] - X[:, 0]
@ -1631,6 +1636,95 @@ def mkernel(X, kernel):
fun = _MKERNEL_DICT[kernel[:4]]
return fun(np.atleast_2d(X))
def accumsum(accmap, a, size=None, dtype=None):
"""
A sum accumulation function
Parameters
----------
accmap : ndarray
This is the "accumulation map". It maps input (i.e. indices into
`a`) to their destination in the output array. The first `a.ndim`
dimensions of `accmap` must be the same as `a.shape`. That is,
`accmap.shape[:a.ndim]` must equal `a.shape`. For example, if `a`
has shape (15,4), then `accmap.shape[:2]` must equal (15,4). In this
case `accmap[i,j]` gives the index into the output array where
element (i,j) of `a` is to be accumulated. If the output is, say,
a 2D, then `accmap` must have shape (15,4,2). The value in the
last dimension give indices into the output array. If the output is
1D, then the shape of `accmap` can be either (15,4) or (15,4,1)
a : ndarray
The input data to be accumulated.
size : ndarray or None
The size of the output array. If None, the size will be determined
from `accmap`.
dtype : numpy data type, or None
The data type of the output array. If None, the data type of
`a` is used.
Returns
-------
out : ndarray
The accumulated results.
The shape of `out` is `size` if `size` is given. Otherwise the
shape is determined by the (lexicographically) largest indices of
the output found in `accmap`.
Examples
--------
>>> from numpy import array, prod
>>> a = array([[1,2,3],[4,-1,6],[-1,8,9]])
>>> a
array([[ 1, 2, 3],
[ 4, -1, 6],
[-1, 8, 9]])
>>> # Sum the diagonals.
>>> accmap = array([[0,1,2],[2,0,1],[1,2,0]])
>>> s = accum(accmap, a)
>>> s
array([ 9, 7, 15])
>>> # A 2D output, from sub-arrays with shapes and positions like this:
>>> # [ (2,2) (2,1)]
>>> # [ (1,2) (1,1)]
>>> accmap = array([
... [[0,0],[0,0],[0,1]],
... [[0,0],[0,0],[0,1]],
... [[1,0],[1,0],[1,1]]])
>>> # Accumulate using a product.
>>> accum(accmap, a, func=prod, dtype=float)
array([[ -8., 18.],
[ -8., 9.]])
>>> # Same accmap, but create an array of lists of values.
>>> accum(accmap, a, func=lambda x: x, dtype='O')
array([[[1, 2, 4, -1], [3, 6]],
[[-1, 8], [9]]], dtype=object)
"""
# Check for bad arguments and handle the defaults.
if accmap.shape[:a.ndim] != a.shape:
raise ValueError("The initial dimensions of accmap must be the same as a.shape")
if dtype is None:
dtype = a.dtype
adims = tuple(range(a.ndim))
if size is None:
size = 1 + np.squeeze(np.apply_over_axes(np.max, accmap, axes=adims))
size = np.atleast_1d(size)
if len(size)>1:
binx = accmap[:,0]
biny = accmap[:,1]
out = np.asarray(sparse.coo_matrix((a.ravel(), (binx, biny)),shape=size, dtype=dtype).todense()).reshape(size)
else:
binx = accmap.ravel()
zero = np.zeros(len(binx))
out = np.asarray(sparse.coo_matrix((a.ravel(), (binx, zero)),shape=(size,1), dtype=dtype).todense()).reshape(size)
return out
def accumsum2(accmap, a, size):
return np.bincount(accmap.ravel(), a.ravel(), np.array(size).max())
def accum(accmap, a, func=None, size=None, fill_value=0, dtype=None):
"""
@ -2156,7 +2250,7 @@ def bitget(int_type, offset):
return np.bitwise_and(int_type, 1 << offset) >> offset
def gridcount(data, X):
def gridcount(data, X, use_sparse=False):
'''
Returns D-dimensional histogram using linear binning.
@ -2231,25 +2325,28 @@ def gridcount(data, X):
raise ValueError('X does not include whole range of the data!')
csiz = np.repeat(inc, d)
if use_sparse:
acfun = accumsum # faster than accum
else:
acfun = accumsum2 #accum
binx = np.asarray(np.floor((dat - xlo[:, newaxis]) / dx), dtype=int)
w = dx.prod()
abs = np.abs
if d == 1:
x.shape = (-1,)
c = (accum(binx, (x[binx + 1] - dat), size=[inc, ]) +
accum(binx, (dat - x[binx]), size=[inc, ])) / w
elif d == 2:
b2 = binx[1]
b1 = binx[0]
c_ = np.c_
stk = np.vstack
c = (accum(c_[b1, b2] , abs(np.prod(stk([X[0, b1 + 1], X[1, b2 + 1]]) - dat, axis=0)), size=[inc, inc]) +
accum(c_[b1 + 1, b2] , abs(np.prod(stk([X[0, b1], X[1, b2 + 1]]) - dat, axis=0)), size=[inc, inc]) +
accum(c_[b1 , b2 + 1], abs(np.prod(stk([X[0, b1 + 1], X[1, b2]]) - dat, axis=0)), size=[inc, inc]) +
accum(c_[b1 + 1, b2 + 1], abs(np.prod(stk([X[0, b1], X[1, b2]]) - dat, axis=0)), size=[inc, inc])) / w
c = c.T # make sure c is stored in the same way as meshgrid
c = (acfun(binx, (x[binx + 1] - dat), size=[inc, ]) +
acfun(binx+1, (dat - x[binx]), size=[inc, ])) / w
# elif d == 2:
# b2 = binx[1]
# b1 = binx[0]
# c_ = np.c_
# stk = np.vstack
# c = (acfun(c_[b1, b2] , abs(np.prod(stk([X[0, b1 + 1], X[1, b2 + 1]]) - dat, axis=0)), size=[inc, inc]) +
# acfun(c_[b1 + 1, b2] , abs(np.prod(stk([X[0, b1], X[1, b2 + 1]]) - dat, axis=0)), size=[inc, inc]) +
# acfun(c_[b1 , b2 + 1], abs(np.prod(stk([X[0, b1 + 1], X[1, b2]]) - dat, axis=0)), size=[inc, inc]) +
# acfun(c_[b1 + 1, b2 + 1], abs(np.prod(stk([X[0, b1], X[1, b2]]) - dat, axis=0)), size=[inc, inc])) / w
# c = c.T # make sure c is stored in the same way as meshgrid
else: # % d>2
Nc = csiz.prod()
@ -2270,13 +2367,13 @@ def gridcount(data, X):
b1 = np.sum((binx + bt0[one]) * fact1, axis=0) #linear index to c
bt2 = bt0[two] + fact2
b2 = binx + bt2 # linear index to X
c += accum(b1, abs(np.prod(X1[b2] - dat, axis=0)), size=(Nc,))
c += acfun(b1, abs(np.prod(X1[b2] - dat, axis=0)), size=(Nc,))
c = np.reshape(c / w, csiz, order='C')
# TODO: check that the flipping of axis is correct
c = np.reshape(c / w, csiz, order='F')
T = range(d)
T[-2], T[-1] = T[-1], T[-2]
T[1], T[0] = T[0], T[1]
#T[-2], T[-1] = T[-1], T[-2]
c = c.transpose(*T) # make sure c is stored in the same way as meshgrid
return c
@ -2291,7 +2388,7 @@ def kde_demo1():
'''
import scipy.stats as st
x = np.linspace(-4, 4)
x = np.linspace(-4, 4, 101)
x0 = x / 2.0
data = np.random.normal(loc=0, scale=1.0, size=7) #rndnorm(0,1,7,1);
kernel = Kernel('gaus')
@ -2330,7 +2427,6 @@ def kde_demo2():
pylab.figure(0)
f.plot()
pylab.plot(x, st.rayleigh.pdf(x, scale=1), ':')
#plotnorm((data).^(L2)) % gives a straight line => L2 = 0.5 reasonable
@ -2344,153 +2440,36 @@ def kde_demo2():
pylab.figure(0)
def test_gridcount():
import numpy as np
#import wafo.kdetools as wk
from matplotlib import pyplot as plb
data = data_rayleigh()
N = len(data)
x = np.linspace(0,max(data)+1,50)
dx = x[1]-x[0]
c = gridcount(data,x)
ctr = np.array([ 0, 4, 10, 14, 15, 23, 16, 18, 21, 19, 37, 32, 24, 29, 29,
24, 29, 26, 14, 13, 23, 9, 13, 11, 7, 12, 5, 2, 2, 6,
2, 2, 5, 0, 0, 0, 1, 2, 1, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0])
print(np.abs(c-ctr)<1e-13)
pdf = c/dx/N
h = plb.plot(x,c,'.') # 1D histogram
plb.show()
pass
def kde_demo3():
'''Demonstrate the difference between and ordinary-KDE
KDEDEMO3 shows that the transformation KDE is a better estimate for
Rayleigh distributed data around 0 than the ordinary KDE.
'''
import scipy.stats as st
data = st.rayleigh.rvs(scale=1, size=(2,300))
#x = np.linspace(1.5e-3, 5, 55)
data1 = data.reshape((2,-1))
c2 = gridcount(data1, np.vstack((x,x)))
c2t = np.array([ 0, 0.635018844262034, 1.170430267508894, 0.480210926714613,
1.256122839305450, 2.050244222017545, 1.250782602003382,
1.253065702416950, 1.295571917793612, 1.978725535031301,
0.829707237562507, 1.842636337195244, 2.767829900577593,
1.449607074995753, 2.759640664415913, 1.634036650764552,
1.990159690320205, 1.201953891214720, 1.277182991907633,
1.293002868977407, 1.157268941919115, 1.275443411253732,
1.132629693243210, 1.418284942741350, 0.770571572744340,
0.475119743103730, 0.982244208825375, 0.681834272971076,
0.359044910769366, 0.582570635672141, 0.658412627992049,
0.784814887272479, 0.448937228228751, 0.314220262358783,
0, 0, 0, 0.404906838651598, 0.113164094088927,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0])
print(np.abs((c2t-c2.max(axis=-1)))<1e-13)
data4 = data.reshape((4,-1))
x = np.linspace(0,max(data)+1,11)
c4 = gridcount(data4, np.vstack((x,x,x,x)))
print(np.abs((c2t-c2.max(axis=-1)))<1e-13)
def data_rayleigh():
return np.array([1.412423619721313, 0.936610012402041, 3.408880790209544, 0.712493911648517, 1.453856820100018,
1.362971623321745, 0.989738148038997, 0.553839936552347, 0.225638048436888, 1.045606709473107,
0.637908826214993, 1.608606426103143, 0.961884939327567, 1.919572795331000, 1.627957520304931,
1.301044712134641, 0.623895791202139, 2.512180741739924, 0.785268132885580, 2.273629106639021,
0.711768125619732, 0.967169434614618, 0.427942932230904, 1.429667825110794, 0.631194936581454,
0.303636149404372, 1.602725333691658, 0.923957338868325, 1.470119382037774, 0.984169729516054,
1.405066725863521, 0.209225286647146, 2.197407087587261, 1.795680986321718, 1.655186235334962,
1.831529484073858, 0.983242434909240, 1.385965094654130, 1.309069260384021, 1.228928476737294,
0.802097056076411, 0.756115979892675, 1.096194354290486, 0.718886376439466, 1.806619521908829,
2.924438974501607, 0.246782936313644, 1.238666429277650, 0.426858243844038, 1.799972319758650,
3.007697177898959, 0.372270006672035, 2.367882325903836, 0.191545163286195, 1.517565471255659,
1.750004351582044, 1.236013671840509, 2.081323476045300, 2.141346897323470, 1.402378494050162,
0.544698152936965, 0.700923468199988, 0.634137874072855, 0.292299453493133, 0.475611960045215,
1.384390337219331, 2.369715926664043, 0.935586970891954, 1.028299144484800, 2.883486469293792
, 2.412885676436705, 1.502666625449783, 2.982736936434333, 1.706454468843206, 0.906120073100622
, 1.473661960328491, 0.748241351472675, 0.836991325956595, 1.509961488710520, 1.225113736935942
, 1.029890543888216, 1.358608202305835, 1.666359355892623, 1.323592437751299, 1.266885170390769
, 1.323660367761004, 1.197556616382116, 0.415219867081348, 1.594635770596585, 3.335047448446035
, 0.935717067162817, 2.664366406420023, 0.922317019697774, 2.086307246777435, 1.101280854500658
, 1.032916883571698, 0.700796651725546, 0.518227310036530, 0.859641628285530, 1.609352902696174
, 1.747723418451391, 1.538490395884064, 0.140361038832643, 1.925029474333574, 2.260668891490430
, 1.716877040260210, 0.295284687152802, 0.974796888317386, 1.561117460932286, 1.617115585994090
, 0.712684884618426, 0.791728102952554, 1.495252766892452, 1.139399282670031, 1.398150348314015
, 0.734533909397005, 0.624418865181972, 1.881415056762913, 1.706681395455110, 2.334683483141081
, 0.477838065222462, 0.634304509316731, 0.456849600683082, 1.160070279761997, 0.655340613711381
, 2.127121229851198, 0.456835914801069, 0.300568039387414, 1.276598603254562, 1.720804090031422
, 0.864730384700170, 0.628029981916123, 0.909872945858993, 0.686886746420088, 1.194705989905012
, 2.176393257858438, 1.408082540391850, 0.462744617618753, 0.995689247143699, 0.335155890849689
, 1.179590017201302, 1.063149176870603, 3.468688654992744, 1.827129780001552, 1.153130138387220
, 2.120636338813882, 0.544011313217379, 0.994288065215423, 2.290060076679768, 2.233778068583924
, 1.581312813059112, 1.387961284638806, 0.917070930336856, 2.344909067035151, 0.516281292935132
, 1.619570115238485, 1.442087344999343, 1.892443909224431, 1.007935276931834, 1.664682222719219
, 1.899024552783311, 0.882368714153905, 1.267711468232034, 2.781870230167854, 1.262515173989300
, 0.895667370955997, 1.390103843633942, 0.945814188732813, 1.680879252209405, 1.033698343725955
, 1.164434112863078, 1.540520689869044, 2.684068016815589, 0.891215308218909, 1.907325227589101
, 1.639214228101874, 2.483108383603044, 0.254728352176505, 0.939581631904974, 1.474208721908681
, 0.813131900087889, 0.723688300231953, 1.575927348326343, 1.399779277625481, 1.336475769769517
, 1.469760951955162, 0.312162051579979, 0.926191271942077, 1.095698311512132, 0.742466620037192
, 1.584565588783017, 1.969369796694313, 0.813142402654688, 1.620637451940408, 0.544472183788396
, 1.903841273483371, 0.546256895921489, 1.332096299659611, 0.954938400592347, 1.813185033558344
, 1.183839081745172, 1.159783992966029, 2.047421367071099, 0.933411156868096, 1.092543634708061
, 2.573430838154017, 0.294001371116989, 2.687145854798348, 0.647676314841560, 1.483222246093897
, 1.328873011650546, 1.499517291077073, 0.946451616282504, 1.391629977859238, 1.825818674800223
, 0.197207089634922, 0.418570979518484, 2.713292260256486, 1.451678603677107, 0.725222188153537
, 1.016524657331659, 0.510160866644535, 0.790663553688482, 0.772267750711634, 0.897737257071539
, 0.574718435129065, 0.924902659911130, 0.509352052679121, 2.076287755824404, 0.445024255426400
, 2.306443399859831, 1.009151694589026, 0.311646355326560, 0.915552448311802, 1.631979165302650
, 1.779435892929737, 1.254791667465325, 1.522546690241251, 2.117005924369452, 0.335708348510442
, 0.850786945794020, 0.307546485903476, 0.659553530770440, 1.595968673282009, 1.599529339207843
, 2.050409047591333, 1.321597656988126, 0.382901575350795, 2.263023675024229, 1.795160219589414
, 0.820728808594631, 1.252616635345433, 2.893059873111469, 1.585968547208113, 1.911105489168721
, 1.065697540675240, 1.127880912464618, 1.282656038601722, 0.791791712034066, 1.662754292624110
, 1.184021211521453, 1.442739185251488, 0.857673288506446, 0.546518081971571, 1.136176847824479
, 0.948827835556975, 1.761649333500106, 1.740961388239338, 1.486044626143792, 0.535345914616625
, 0.208765940502775, 1.281107790531077, 0.845985407399993, 2.367961441281100, 2.813630157287030
, 0.821877833204895, 1.796411857645166, 2.128114489536385, 1.349167308872121, 2.075721258630550
, 2.399008601572707, 1.262250789152573, 1.614544130176768, 1.311344244094387, 0.228900207318000
, 1.087703540854728, 1.441743192607425, 1.213654375953261, 0.965104247192400, 2.352343682973261
, 1.881070184767099, 1.944757925743782, 0.965470015113788, 1.341190290874416, 2.029803572337272
, 2.328337398097465, 1.485947310986503, 0.680661741126981, 1.456629522069083, 0.386727549117631
, 1.021861509017076, 1.482839980680464, 2.329786461679046, 1.825236378759161, 1.151270272972182
, 1.681465022236889, 1.038893153052472, 2.671305569135296, 2.973463508311512, 1.998091967015353
, 0.992439538152367, 1.101359057223470, 0.752694797719731, 1.751820513743222, 2.070842495255286
, 2.213621940109904, 1.278350678290866, 1.351639733749908, 0.567799782374724, 2.144632385787214
, 1.094123263719430, 0.678615107641789, 2.144341891738539, 1.695846624058156, 2.069396249839028
, 0.819027610733285, 1.495651321040951, 1.477482666605742, 0.511932330475827, 1.022837224533765
, 0.802470556959117, 1.588170058614226, 0.816352471969601, 2.128510415901388, 1.871914791729839
, 0.994323676062132, 1.173849936976207, 1.540652455108271, 1.896308447022061, 1.371611808573705
, 1.307706279079749, 0.888355489837264, 1.104161992788381, 1.581802123863791, 2.077336259709684
, 1.597514520759674, 0.193846187739953, 1.498827901810269, 1.074392126178632, 1.073250683084153
, 0.498942436443271, 1.836126539886937, 0.886372885469560, 0.751884958648598, 0.916116650002177
, 0.970681891155015, 1.257679318479529, 1.284798886225563, 1.003879276488743, 1.007685729946785
, 1.203631029712442, 1.463948632472297, 2.455282398854625, 1.600867640016765, 1.010899145846306
, 1.888399192628552, 0.537142702822369, 0.353191429514348, 0.419544177537439, 0.598339442960937
, 0.885310772269136, 0.847519694333472, 0.153295465546788, 1.246051759006313, 0.447732587957780
, 0.562898114036050, 1.412332385111654, 1.980540530235424, 2.704891701651084, 1.300708887507808
, 3.394236570275002, 1.269967710402906, 1.203787442037781, 0.896098313870595, 1.060303799139334
, 1.163522680114773, 0.383891805234107, 2.091377862339729, 0.365559694796422, 1.070541000579430
, 1.872070722040661, 1.001756457029345, 1.378809939003001, 1.847850278543804, 2.085003935284227
, 2.313122510412947, 0.650676881494584, 0.773551613369587, 2.136102299351586, 1.341515248421647
, 1.183940022628347, 1.377562113620296, 1.850185830133746, 1.232112165168803, 0.671923793165544
, 1.099946548218587, 1.056844894152012, 2.601133375396755, 1.391207328945862, 1.541896787253508
, 1.595966007631807, 0.923057590473980, 1.206415179152940, 1.275536301443908, 0.583420447186398
, 1.285040337652167, 1.540648406694559, 1.054438062631050, 1.902387509769504, 1.621409166908371
, 0.944812793164613, 1.100477476680040, 0.988327442233132, 1.728654388101105, 1.628053244977060
, 1.060760561571943, 1.538416018178277, 2.410108392389236, 1.751316245100324, 1.563790463015108
, 0.481219389518454, 0.994165631555275, 1.337016990968870, 1.109088579526755, 0.321407029232422
, 0.720641073906049, 1.895735773634961, 0.177585824024661, 1.996485240483058, 0.403199585960614
, 1.487121772300537, 1.177769008306152, 0.701273995641151, 1.302101876486422, 0.510537251157601
, 1.491444215081535, 1.352963516576160, 0.339422616073620, 0.340565840833962, 0.575265488888648
, 0.199078454324122, 1.068868838035460, 1.889502831203267, 1.386174255623796, 1.211807597487022
, 1.997063801362690, 0.453444401722453, 2.184735356478338, 0.478137766710008, 1.206426055203951
, 0.555876664495711, 1.280274233919441, 0.095813804344955, 1.706079097312628, 1.943477111398666
, 2.230140630510882, 2.946309044620703, 1.186142019401047, 0.795814141941795, 0.460857387230226
, 1.190772316835832, 1.327362504940310, 1.696595922853605, 0.416190042989537, 1.472083830192951
, 1.206395605479538, 0.612524363189761, 2.362058183247366, 1.336246455616561, 1.077916969428414
, 2.385755851351826, 1.460727990062456, 1.096704997935700, 1.913474394478998, 1.233385699260248,
1.270577147048640, 1.509727846778659, 0.956645085964223, 0.739599713571419, 1.315249583679571,
2.008261585625269, 1.021943728886631, 0.488828195617451, 1.083244894832682, 0.844912313732214,
1.013054512108690, 1.893114294699785, 1.016751451332806, 0.994570044372612, 0.945503828258995])
kde = KDE(data)
f = kde(output='plot', title='Ordinary KDE', plotflag=1)
pylab.figure(0)
f.plot()
pylab.plot(data[0], data[1], '.')
#plotnorm((data).^(L2)) % gives a straight line => L2 = 0.5 reasonable
tkde = TKDE(data, L2=0.5)
ft = tkde.eval_grid_fast(output='plot', title='Transformation KDE', plotflag=1)
pylab.figure(1)
ft.plot()
pylab.plot(data[0],data[1], '.')
pylab.figure(0)
def test_docstrings():
@ -2498,5 +2477,5 @@ def test_docstrings():
doctest.testmod()
if __name__ == '__main__':
#test_docstrings()
test_gridcount()
test_docstrings()

@ -26,15 +26,16 @@ def test0_KDE1D():
array([ 0.2039735 , 0.40252503, 0.54595078, 0.52219649, 0.3906213 ,
0.26381501, 0.16407362, 0.08270612, 0.02991145, 0.00720821])
>>> kde0.eval_grid_fast(x)
array([ 0.32343789, 0.51366167, 0.55643329, 0.43688805, 0.28972471,
0.19445277, 0.12473331, 0.06195215, 0.02087712, 0.00449567])
array([ 0.20729484, 0.39865044, 0.53716945, 0.5169322 , 0.39060223,
0.26441126, 0.16388801, 0.08388527, 0.03227164, 0.00883579])
>>> f = kde0.eval_grid_fast(); f
array([ 0.02076721, 0.0612371 , 0.14515308, 0.27604202, 0.42001793,
0.51464781, 0.52131018, 0.45976136, 0.37621768, 0.29589521,
0.21985316, 0.1473364 , 0.08502256, 0.04063749, 0.0155788 ,
0.00466938])
array([ 0.01149411, 0.03485467, 0.08799292, 0.18568718, 0.32473136,
0.46543163, 0.54532016, 0.53005828, 0.44447651, 0.34119612,
0.25103852, 0.1754952 , 0.11072989, 0.05992731, 0.02687784,
0.00974983])
>>> np.trapz(f,kde0.args)
array([ 0.99416766])
array([ 0.99500101])
'''
def test1_TKDE1D():
'''
@ -167,9 +168,13 @@ def test_KDE2D():
>>> kde0 = wk.KDE(data, hs=0.5, alpha=0.0, inc=16)
>>> kde0.eval_grid(x, x)
array([[ 3.27260963e-02, 4.21654678e-02, 5.85338634e-04],
[ 6.78845466e-02, 1.42195839e-01, 1.41676003e-03],
[ 1.39466746e-04, 4.26983850e-03, 2.52736185e-05]])
>>> kde0.eval_grid_fast(x, x)
array([[ 0.08670654, 0.12577712, 0.00808478],
[ 0.1411195 , 0.24160579, 0.01816001],
[ 0.0031541 , 0.01553967, 0.00114854]])
'''
def test_smooth_params():
@ -197,16 +202,16 @@ def test_smooth_params():
[ -2.68892467e-02, 3.91283306e-01, 2.38654678e-02],
[ 3.18932448e-04, 2.38654678e-02, 4.05123874e-01]])
>>> gauss.hscv(data)
array([ 0.16858959, 0.33034332, 0.3046287 ])
array([ 0.16858959, 0.32739383, 0.3046287 ])
>>> gauss.hstt(data)
array([ 0.18196282, 0.51090571, 0.1111913 ])
array([ 0.18099075, 0.50409881, 0.11018912])
>>> gauss.hste(data)
array([ 0.1683984 , 0.29693232, 0.17974833])
array([ 0.16750009, 0.29059113, 0.17994255])
>>> gauss.hldpi(data)
array([ 0.17426948, 0.33672307, 0.31240374])
array([ 0.1732289 , 0.33159097, 0.3107633 ])
'''
def test_gridcount_1D():
'''
@ -219,10 +224,13 @@ def test_gridcount_1D():
>>> x = np.linspace(0, max(data.ravel()) + 1, 10)
>>> dx = x[1] - x[0]
>>> c = wk.gridcount(data, x)
>>> c = wk.gridcount(data, x, use_sparse=False)
>>> c
array([ 1., 6., 7., 2., 1., 2., 1., 0., 0., 0.])
array([ 0.78762626, 1.77520717, 7.99190087, 4.04054449, 1.67156643,
2.38228499, 1.05933195, 0.29153785, 0. , 0. ])
>>> wk.gridcount(data, x, use_sparse=True)
array([ 0.78762626, 1.77520717, 7.99190087, 4.04054449, 1.67156643,
2.38228499, 1.05933195, 0.29153785, 0. , 0. ])
h = plb.plot(x, c, '.') # 1D histogram
@ -281,19 +289,19 @@ def test_gridcount_3D():
>>> x = np.linspace(0, max(data.ravel()) + 1, 3)
>>> dx = x[1] - x[0]
>>> X = np.vstack((x, x, x))
>>> c = wk.gridcount(data, X)
>>> c = wk.gridcount(data, X, use_sparse=True)
>>> c
array([[[ 8.74229894e-01, 1.44969128e+00, 7.49265424e-02],
[ 1.94778915e+00, 2.28951650e+00, 8.53886762e-02],
[ 1.08429416e-01, 1.10905565e-01, 4.16196568e-04]],
array([[[ 8.74229894e-01, 1.27910940e+00, 1.42033973e-01],
[ 1.94778915e+00, 2.59536282e+00, 3.28213680e-01],
[ 1.08429416e-01, 1.69571495e-01, 7.48896775e-03]],
<BLANKLINE>
[[ 1.27910940e+00, 2.58396370e+00, 2.18142488e-01],
[ 2.59536282e+00, 4.49653348e+00, 3.73415131e-01],
[ 1.69571495e-01, 3.18733817e-01, 1.62218824e-02]],
[[ 1.44969128e+00, 2.58396370e+00, 2.45459949e-01],
[ 2.28951650e+00, 4.49653348e+00, 2.73167915e-01],
[ 1.10905565e-01, 3.18733817e-01, 1.12880816e-02]],
<BLANKLINE>
[[ 1.42033973e-01, 2.45459949e-01, 0.00000000e+00],
[ 3.28213680e-01, 2.73167915e-01, 0.00000000e+00],
[ 7.48896775e-03, 1.12880816e-02, 0.00000000e+00]]])
[[ 7.49265424e-02, 2.18142488e-01, 0.00000000e+00],
[ 8.53886762e-02, 3.73415131e-01, 0.00000000e+00],
[ 4.16196568e-04, 1.62218824e-02, 0.00000000e+00]]])
'''
def test_gridcount_4D():
@ -314,43 +322,43 @@ def test_gridcount_4D():
>>> X = np.vstack((x, x, x, x))
>>> c = wk.gridcount(data, X)
>>> c
array([[[[ 1.77163904e-01, 3.41883175e-01, 3.13810608e-03],
[ 1.83770124e-01, 3.58861043e-01, 7.05946179e-03],
[ 0.00000000e+00, 2.91835067e-03, 6.38695629e-04]],
array([[[[ 1.77163904e-01, 1.87720108e-01, 0.00000000e+00],
[ 5.72573585e-01, 6.09557834e-01, 0.00000000e+00],
[ 3.48549923e-03, 4.05931870e-02, 0.00000000e+00]],
<BLANKLINE>
[[ 5.72573585e-01, 5.72071865e-01, 6.71606255e-03],
[ 4.35845892e-01, 8.80697705e-01, 1.09099593e-01],
[ 0.00000000e+00, 3.63686503e-02, 1.00358265e-02]],
[[ 1.83770124e-01, 2.56357594e-01, 0.00000000e+00],
[ 4.35845892e-01, 6.14958970e-01, 0.00000000e+00],
[ 3.07662204e-03, 3.58312786e-02, 0.00000000e+00]],
<BLANKLINE>
[[ 3.48549923e-03, 3.46939323e-03, 0.00000000e+00],
[ 3.07662204e-03, 2.22868504e-01, 6.61257395e-02],
[ 0.00000000e+00, 1.88555613e-02, 5.67244468e-03]]],
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]],
<BLANKLINE>
<BLANKLINE>
[[[ 1.87720108e-01, 5.97977973e-01, 2.11731327e-02],
[ 2.56357594e-01, 6.28962785e-01, 5.44614852e-02],
[ 0.00000000e+00, 2.60268355e-02, 5.69610302e-03]],
[[[ 3.41883175e-01, 5.97977973e-01, 0.00000000e+00],
[ 5.72071865e-01, 8.58566538e-01, 0.00000000e+00],
[ 3.46939323e-03, 4.04056116e-02, 0.00000000e+00]],
<BLANKLINE>
[[ 6.09557834e-01, 8.58566538e-01, 4.53139824e-02],
[ 6.14958970e-01, 1.47373158e+00, 1.95935584e-01],
[ 0.00000000e+00, 1.07959459e-01, 2.44053065e-02]],
[[ 3.58861043e-01, 6.28962785e-01, 0.00000000e+00],
[ 8.80697705e-01, 1.47373158e+00, 0.00000000e+00],
[ 2.22868504e-01, 1.18008528e-01, 0.00000000e+00]],
<BLANKLINE>
[[ 4.05931870e-02, 4.04056116e-02, 0.00000000e+00],
[ 3.58312786e-02, 1.18008528e-01, 2.47717418e-02],
[ 0.00000000e+00, 7.06358976e-03, 2.12498697e-03]]],
[[ 2.91835067e-03, 2.60268355e-02, 0.00000000e+00],
[ 3.63686503e-02, 1.07959459e-01, 0.00000000e+00],
[ 1.88555613e-02, 7.06358976e-03, 0.00000000e+00]]],
<BLANKLINE>
<BLANKLINE>
[[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[[[ 3.13810608e-03, 2.11731327e-02, 0.00000000e+00],
[ 6.71606255e-03, 4.53139824e-02, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],
<BLANKLINE>
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],
[[ 7.05946179e-03, 5.44614852e-02, 0.00000000e+00],
[ 1.09099593e-01, 1.95935584e-01, 0.00000000e+00],
[ 6.61257395e-02, 2.47717418e-02, 0.00000000e+00]],
<BLANKLINE>
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]]])
[[ 6.38695629e-04, 5.69610302e-03, 0.00000000e+00],
[ 1.00358265e-02, 2.44053065e-02, 0.00000000e+00],
[ 5.67244468e-03, 2.12498697e-03, 0.00000000e+00]]]])
h = plb.plot(x, c, '.') # 1D histogram

@ -138,10 +138,15 @@ class AxisLabels:
self.xlab = xlab
self.ylab = ylab
self.zlab = zlab
def __repr__(self):
return self.__str__()
def __str__(self):
return '%s\n%s\n%s\n%s\n' % (self.title, self.xlab, self.ylab, self.zlab)
def copy(self):
newcopy = empty_copy(self)
newcopy.__dict__.update(self.__dict__)
return newcopy
def labelfig(self):
try:
h1 = plotbackend.title(self.title)
@ -169,7 +174,6 @@ class Plotter_1d(object):
step : stair-step plot
scatter : scatter plot
"""
def __init__(self, plotmethod='plot'):
self.plotfun = None
if plotmethod is None:
@ -179,11 +183,12 @@ class Plotter_1d(object):
self.plotfun = getattr(plotbackend, plotmethod)
except:
pass
def show(self):
plotbackend.show()
def plot(self, wdata, *args, **kwds):
plotflag = kwds.pop('plotflag', None)
plotflag = kwds.pop('plotflag', False)
if plotflag:
h1 = self._plot(plotflag, wdata, **kwds)
else:
@ -201,6 +206,7 @@ class Plotter_1d(object):
dataCI = ()
h1 = plot1d(x, data, dataCI, plotflag, *args, **kwds)
return h1
def plot1d(args, data, dataCI, plotflag, *varargin, **kwds):
plottype = np.mod(plotflag, 10)
@ -223,10 +229,10 @@ def plot1d(args, data, dataCI, plotflag, *varargin, **kwds):
else:
H = plotbackend.fill_between(args, data, *varargin, **kwds);
scale = plotscale(plotflag);
logXscale = any(scale == 'x');
logYscale = any(scale == 'y');
logZscale = any(scale == 'z');
scale = plotscale(plotflag)
logXscale = 'x' in scale
logYscale = 'y' in scale
logZscale = 'z' in scale
ax = plotbackend.gca()
if logXscale:
plotbackend.setp(ax, xscale='log')
@ -346,11 +352,12 @@ def plot2d(wdata, plotflag, *args, **kwds):
else:
args1 = tuple((wdata.args,)) + (wdata.data,) + args
if plotflag in (1, 6, 7, 8, 9):
PL = 0
if hasattr(f, 'cl') and len(f.cl) > 0: # check if contour levels is submitted
CL = f.cl
if hasattr(f, 'pl'):
PL = f.pl # levels defines quantile levels? 0=no 1=yes
isPL = False
if hasattr(f, 'clevels') and len(f.clevels) > 0: # check if contour levels is submitted
CL = f.clevels
isPL = hasattr(f, 'plevels') and f.plevels is not None
if isPL:
PL = f.plevels # levels defines quantile levels? 0=no 1=yes
else:
dmax = np.max(f.data)
dmin = np.min(f.data)
@ -367,10 +374,8 @@ def plot2d(wdata, plotflag, *args, **kwds):
if ncl > 12:
ncl = 12
warnings.warn('Only the first 12 levels will be listed in table.')
if PL:
clvals, isPL = PL[:ncl], True
else:
clvals, isPL = clvec[:ncl], False
clvals = PL[:ncl] if isPL else clvec[:ncl]
unused_axcl = cltext(clvals, percent=isPL) # print contour level text
elif any(plotflag == [7, 9]):
plotbackend.clabel(h)
@ -390,20 +395,19 @@ def plot2d(wdata, plotflag, *args, **kwds):
plotbackend.colorbar(h)
else:
raise ValueError('unknown option for plotflag')
#if any(plotflag==(2:5))
# shading(shad);
#end
# pass
def test_docstrings():
import doctest
doctest.testmod()
def main():
pass
if __name__ == '__main__':
if True: #False : #
import doctest
doctest.testmod()
else:
main()
test_docstrings()
#main()

Loading…
Cancel
Save