Fixed some bugs

master
Per.Andreas.Brodtkorb 13 years ago
parent bac74a93b5
commit 9dceca985e

@ -13,7 +13,7 @@ from __future__ import division
from itertools import product from itertools import product
from misc import tranproc #, trangood from misc import tranproc #, trangood
from numpy import pi, sqrt, atleast_2d, exp, newaxis #@UnresolvedImport from numpy import pi, sqrt, atleast_2d, exp, newaxis #@UnresolvedImport
from scipy import interpolate, linalg from scipy import interpolate, linalg, sparse
from scipy.special import gamma from scipy.special import gamma
from wafo.misc import meshgrid from wafo.misc import meshgrid
from wafo.wafodata import WafoData from wafo.wafodata import WafoData
@ -128,11 +128,11 @@ class _KDE(object):
if self.xmin is None: if self.xmin is None:
self.xmin = amin - offset self.xmin = amin - offset
else: else:
self.xmin = self.xmin * np.ones(self.d) self.xmin = self.xmin * np.ones((self.d,1))
if self.xmax is None: if self.xmax is None:
self.xmax = amax + offset self.xmax = amax + offset
else: else:
self.xmax = self.xmax * np.ones(self.d) self.xmax = self.xmax * np.ones((self.d,1))
def eval_grid_fast(self, *args, **kwds): def eval_grid_fast(self, *args, **kwds):
"""Evaluate the estimated pdf on a grid. """Evaluate the estimated pdf on a grid.
@ -187,14 +187,20 @@ class _KDE(object):
else: else:
titlestr = 'Kernel density estimate (%s)' % self.kernel.name titlestr = 'Kernel density estimate (%s)' % self.kernel.name
kwds2 = dict(title=titlestr) kwds2 = dict(title=titlestr)
kwds2['plot_kwds'] = dict(plotflag=1)
kwds2.update(**kwds) kwds2.update(**kwds)
if self.d == 1: if self.d == 1:
args = args[0] args = args[0]
elif self.d > 1: wdata = WafoData(f, args, **kwds2)
if self.d > 1:
PL = np.r_[10:90:20, 95, 99, 99.9] PL = np.r_[10:90:20, 95, 99, 99.9]
ql = qlevels(f, p=PL) try:
kwds2.setdefault('levels', ql) ql = qlevels(f, p=PL)
return WafoData(f, args, **kwds2) wdata.clevels = ql
wdata.plevels = PL
except:
pass
return wdata
def _check_shape(self, points): def _check_shape(self, points):
points = atleast_2d(points) points = atleast_2d(points)
@ -329,17 +335,17 @@ class TKDE(_KDE):
tdataset = self._dat2gaus(self.dataset) tdataset = self._dat2gaus(self.dataset)
xmin = self.xmin xmin = self.xmin
if xmin is not None: if xmin is not None:
xmin = self._dat2gaus(xmin) xmin = self._dat2gaus(np.reshape(xmin,(-1,1)))
xmax = self.xmax xmax = self.xmax
if xmax is not None: if xmax is not None:
xmax = self._dat2gaus(xmax) xmax = self._dat2gaus(np.reshape(xmax,(-1,1)))
self.tkde = KDE(tdataset, self.hs, self.kernel, self.alpha, xmin, xmax, self.tkde = KDE(tdataset, self.hs, self.kernel, self.alpha, xmin, xmax,
self.inc) self.inc)
def _check_xmin(self): def _check_xmin(self):
if self.L2 is not None: if self.L2 is not None:
amin = self.dataset.min(axis= -1) amin = self.dataset.min(axis= -1)
L2 = np.atleast_1d(self.L2) * np.ones(self.d) # default no transformation L2 = np.atleast_1d(self.L2) * np.ones(self.d) # default no transformation
self.xmin = np.where(L2 != 1, np.maximum(self.xmin, amin / 100.0), self.xmin) self.xmin = np.where(L2 != 1, np.maximum(self.xmin, amin / 100.0), self.xmin).reshape((-1,1))
def _dat2gaus(self, points): def _dat2gaus(self, points):
if self.L2 is None: if self.L2 is None:
@ -422,7 +428,7 @@ class TKDE(_KDE):
fi = pdf(*args) fi = pdf(*args)
self.args = args self.args = args
#fi.shape = ipoints[0].shape #fi.shape = ipoints[0].shape
return fi return fi*(fi>0)
return f return f
def _eval_grid(self, *args): def _eval_grid(self, *args):
if self.L2 is None: if self.L2 is None:
@ -602,7 +608,6 @@ class KDE(_KDE):
def _eval_grid_fast(self, *args): def _eval_grid_fast(self, *args):
# TODO: This does not work correctly yet! Check it.
X = np.vstack(args) X = np.vstack(args)
d, inc = X.shape d, inc = X.shape
dx = X[:, 1] - X[:, 0] dx = X[:, 1] - X[:, 0]
@ -1631,6 +1636,95 @@ def mkernel(X, kernel):
fun = _MKERNEL_DICT[kernel[:4]] fun = _MKERNEL_DICT[kernel[:4]]
return fun(np.atleast_2d(X)) return fun(np.atleast_2d(X))
def accumsum(accmap, a, size=None, dtype=None):
"""
A sum accumulation function
Parameters
----------
accmap : ndarray
This is the "accumulation map". It maps input (i.e. indices into
`a`) to their destination in the output array. The first `a.ndim`
dimensions of `accmap` must be the same as `a.shape`. That is,
`accmap.shape[:a.ndim]` must equal `a.shape`. For example, if `a`
has shape (15,4), then `accmap.shape[:2]` must equal (15,4). In this
case `accmap[i,j]` gives the index into the output array where
element (i,j) of `a` is to be accumulated. If the output is, say,
a 2D, then `accmap` must have shape (15,4,2). The value in the
last dimension give indices into the output array. If the output is
1D, then the shape of `accmap` can be either (15,4) or (15,4,1)
a : ndarray
The input data to be accumulated.
size : ndarray or None
The size of the output array. If None, the size will be determined
from `accmap`.
dtype : numpy data type, or None
The data type of the output array. If None, the data type of
`a` is used.
Returns
-------
out : ndarray
The accumulated results.
The shape of `out` is `size` if `size` is given. Otherwise the
shape is determined by the (lexicographically) largest indices of
the output found in `accmap`.
Examples
--------
>>> from numpy import array, prod
>>> a = array([[1,2,3],[4,-1,6],[-1,8,9]])
>>> a
array([[ 1, 2, 3],
[ 4, -1, 6],
[-1, 8, 9]])
>>> # Sum the diagonals.
>>> accmap = array([[0,1,2],[2,0,1],[1,2,0]])
>>> s = accum(accmap, a)
>>> s
array([ 9, 7, 15])
>>> # A 2D output, from sub-arrays with shapes and positions like this:
>>> # [ (2,2) (2,1)]
>>> # [ (1,2) (1,1)]
>>> accmap = array([
... [[0,0],[0,0],[0,1]],
... [[0,0],[0,0],[0,1]],
... [[1,0],[1,0],[1,1]]])
>>> # Accumulate using a product.
>>> accum(accmap, a, func=prod, dtype=float)
array([[ -8., 18.],
[ -8., 9.]])
>>> # Same accmap, but create an array of lists of values.
>>> accum(accmap, a, func=lambda x: x, dtype='O')
array([[[1, 2, 4, -1], [3, 6]],
[[-1, 8], [9]]], dtype=object)
"""
# Check for bad arguments and handle the defaults.
if accmap.shape[:a.ndim] != a.shape:
raise ValueError("The initial dimensions of accmap must be the same as a.shape")
if dtype is None:
dtype = a.dtype
adims = tuple(range(a.ndim))
if size is None:
size = 1 + np.squeeze(np.apply_over_axes(np.max, accmap, axes=adims))
size = np.atleast_1d(size)
if len(size)>1:
binx = accmap[:,0]
biny = accmap[:,1]
out = np.asarray(sparse.coo_matrix((a.ravel(), (binx, biny)),shape=size, dtype=dtype).todense()).reshape(size)
else:
binx = accmap.ravel()
zero = np.zeros(len(binx))
out = np.asarray(sparse.coo_matrix((a.ravel(), (binx, zero)),shape=(size,1), dtype=dtype).todense()).reshape(size)
return out
def accumsum2(accmap, a, size):
return np.bincount(accmap.ravel(), a.ravel(), np.array(size).max())
def accum(accmap, a, func=None, size=None, fill_value=0, dtype=None): def accum(accmap, a, func=None, size=None, fill_value=0, dtype=None):
""" """
@ -2156,7 +2250,7 @@ def bitget(int_type, offset):
return np.bitwise_and(int_type, 1 << offset) >> offset return np.bitwise_and(int_type, 1 << offset) >> offset
def gridcount(data, X): def gridcount(data, X, use_sparse=False):
''' '''
Returns D-dimensional histogram using linear binning. Returns D-dimensional histogram using linear binning.
@ -2231,25 +2325,28 @@ def gridcount(data, X):
raise ValueError('X does not include whole range of the data!') raise ValueError('X does not include whole range of the data!')
csiz = np.repeat(inc, d) csiz = np.repeat(inc, d)
if use_sparse:
acfun = accumsum # faster than accum
else:
acfun = accumsum2 #accum
binx = np.asarray(np.floor((dat - xlo[:, newaxis]) / dx), dtype=int) binx = np.asarray(np.floor((dat - xlo[:, newaxis]) / dx), dtype=int)
w = dx.prod() w = dx.prod()
abs = np.abs abs = np.abs
if d == 1: if d == 1:
x.shape = (-1,) x.shape = (-1,)
c = (accum(binx, (x[binx + 1] - dat), size=[inc, ]) + c = (acfun(binx, (x[binx + 1] - dat), size=[inc, ]) +
accum(binx, (dat - x[binx]), size=[inc, ])) / w acfun(binx+1, (dat - x[binx]), size=[inc, ])) / w
elif d == 2: # elif d == 2:
b2 = binx[1] # b2 = binx[1]
b1 = binx[0] # b1 = binx[0]
c_ = np.c_ # c_ = np.c_
stk = np.vstack # stk = np.vstack
c = (accum(c_[b1, b2] , abs(np.prod(stk([X[0, b1 + 1], X[1, b2 + 1]]) - dat, axis=0)), size=[inc, inc]) + # c = (acfun(c_[b1, b2] , abs(np.prod(stk([X[0, b1 + 1], X[1, b2 + 1]]) - dat, axis=0)), size=[inc, inc]) +
accum(c_[b1 + 1, b2] , abs(np.prod(stk([X[0, b1], X[1, b2 + 1]]) - dat, axis=0)), size=[inc, inc]) + # acfun(c_[b1 + 1, b2] , abs(np.prod(stk([X[0, b1], X[1, b2 + 1]]) - dat, axis=0)), size=[inc, inc]) +
accum(c_[b1 , b2 + 1], abs(np.prod(stk([X[0, b1 + 1], X[1, b2]]) - dat, axis=0)), size=[inc, inc]) + # acfun(c_[b1 , b2 + 1], abs(np.prod(stk([X[0, b1 + 1], X[1, b2]]) - dat, axis=0)), size=[inc, inc]) +
accum(c_[b1 + 1, b2 + 1], abs(np.prod(stk([X[0, b1], X[1, b2]]) - dat, axis=0)), size=[inc, inc])) / w # acfun(c_[b1 + 1, b2 + 1], abs(np.prod(stk([X[0, b1], X[1, b2]]) - dat, axis=0)), size=[inc, inc])) / w
c = c.T # make sure c is stored in the same way as meshgrid # c = c.T # make sure c is stored in the same way as meshgrid
else: # % d>2 else: # % d>2
Nc = csiz.prod() Nc = csiz.prod()
@ -2270,13 +2367,13 @@ def gridcount(data, X):
b1 = np.sum((binx + bt0[one]) * fact1, axis=0) #linear index to c b1 = np.sum((binx + bt0[one]) * fact1, axis=0) #linear index to c
bt2 = bt0[two] + fact2 bt2 = bt0[two] + fact2
b2 = binx + bt2 # linear index to X b2 = binx + bt2 # linear index to X
c += accum(b1, abs(np.prod(X1[b2] - dat, axis=0)), size=(Nc,)) c += acfun(b1, abs(np.prod(X1[b2] - dat, axis=0)), size=(Nc,))
c = np.reshape(c / w, csiz, order='C') c = np.reshape(c / w, csiz, order='F')
# TODO: check that the flipping of axis is correct
T = range(d) T = range(d)
T[-2], T[-1] = T[-1], T[-2] T[1], T[0] = T[0], T[1]
#T[-2], T[-1] = T[-1], T[-2]
c = c.transpose(*T) # make sure c is stored in the same way as meshgrid c = c.transpose(*T) # make sure c is stored in the same way as meshgrid
return c return c
@ -2291,7 +2388,7 @@ def kde_demo1():
''' '''
import scipy.stats as st import scipy.stats as st
x = np.linspace(-4, 4) x = np.linspace(-4, 4, 101)
x0 = x / 2.0 x0 = x / 2.0
data = np.random.normal(loc=0, scale=1.0, size=7) #rndnorm(0,1,7,1); data = np.random.normal(loc=0, scale=1.0, size=7) #rndnorm(0,1,7,1);
kernel = Kernel('gaus') kernel = Kernel('gaus')
@ -2330,7 +2427,6 @@ def kde_demo2():
pylab.figure(0) pylab.figure(0)
f.plot() f.plot()
pylab.plot(x, st.rayleigh.pdf(x, scale=1), ':') pylab.plot(x, st.rayleigh.pdf(x, scale=1), ':')
#plotnorm((data).^(L2)) % gives a straight line => L2 = 0.5 reasonable #plotnorm((data).^(L2)) % gives a straight line => L2 = 0.5 reasonable
@ -2343,154 +2439,37 @@ def kde_demo2():
pylab.plot(x, st.rayleigh.pdf(x, scale=1), ':') pylab.plot(x, st.rayleigh.pdf(x, scale=1), ':')
pylab.figure(0) pylab.figure(0)
def kde_demo3():
'''Demonstrate the difference between and ordinary-KDE
def test_gridcount(): KDEDEMO3 shows that the transformation KDE is a better estimate for
import numpy as np Rayleigh distributed data around 0 than the ordinary KDE.
#import wafo.kdetools as wk '''
from matplotlib import pyplot as plb import scipy.stats as st
data = data_rayleigh() data = st.rayleigh.rvs(scale=1, size=(2,300))
N = len(data)
x = np.linspace(0,max(data)+1,50) #x = np.linspace(1.5e-3, 5, 55)
dx = x[1]-x[0]
kde = KDE(data)
c = gridcount(data,x) f = kde(output='plot', title='Ordinary KDE', plotflag=1)
pylab.figure(0)
ctr = np.array([ 0, 4, 10, 14, 15, 23, 16, 18, 21, 19, 37, 32, 24, 29, 29, f.plot()
24, 29, 26, 14, 13, 23, 9, 13, 11, 7, 12, 5, 2, 2, 6,
2, 2, 5, 0, 0, 0, 1, 2, 1, 0, 0, 0, 0, 0, 0, pylab.plot(data[0], data[1], '.')
0, 0, 0, 0, 0])
print(np.abs(c-ctr)<1e-13) #plotnorm((data).^(L2)) % gives a straight line => L2 = 0.5 reasonable
pdf = c/dx/N
h = plb.plot(x,c,'.') # 1D histogram tkde = TKDE(data, L2=0.5)
plb.show() ft = tkde.eval_grid_fast(output='plot', title='Transformation KDE', plotflag=1)
pass
data1 = data.reshape((2,-1)) pylab.figure(1)
c2 = gridcount(data1, np.vstack((x,x))) ft.plot()
c2t = np.array([ 0, 0.635018844262034, 1.170430267508894, 0.480210926714613, pylab.plot(data[0],data[1], '.')
1.256122839305450, 2.050244222017545, 1.250782602003382,
1.253065702416950, 1.295571917793612, 1.978725535031301, pylab.figure(0)
0.829707237562507, 1.842636337195244, 2.767829900577593,
1.449607074995753, 2.759640664415913, 1.634036650764552,
1.990159690320205, 1.201953891214720, 1.277182991907633,
1.293002868977407, 1.157268941919115, 1.275443411253732,
1.132629693243210, 1.418284942741350, 0.770571572744340,
0.475119743103730, 0.982244208825375, 0.681834272971076,
0.359044910769366, 0.582570635672141, 0.658412627992049,
0.784814887272479, 0.448937228228751, 0.314220262358783,
0, 0, 0, 0.404906838651598, 0.113164094088927,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0])
print(np.abs((c2t-c2.max(axis=-1)))<1e-13)
data4 = data.reshape((4,-1))
x = np.linspace(0,max(data)+1,11)
c4 = gridcount(data4, np.vstack((x,x,x,x)))
print(np.abs((c2t-c2.max(axis=-1)))<1e-13)
def data_rayleigh():
return np.array([1.412423619721313, 0.936610012402041, 3.408880790209544, 0.712493911648517, 1.453856820100018,
1.362971623321745, 0.989738148038997, 0.553839936552347, 0.225638048436888, 1.045606709473107,
0.637908826214993, 1.608606426103143, 0.961884939327567, 1.919572795331000, 1.627957520304931,
1.301044712134641, 0.623895791202139, 2.512180741739924, 0.785268132885580, 2.273629106639021,
0.711768125619732, 0.967169434614618, 0.427942932230904, 1.429667825110794, 0.631194936581454,
0.303636149404372, 1.602725333691658, 0.923957338868325, 1.470119382037774, 0.984169729516054,
1.405066725863521, 0.209225286647146, 2.197407087587261, 1.795680986321718, 1.655186235334962,
1.831529484073858, 0.983242434909240, 1.385965094654130, 1.309069260384021, 1.228928476737294,
0.802097056076411, 0.756115979892675, 1.096194354290486, 0.718886376439466, 1.806619521908829,
2.924438974501607, 0.246782936313644, 1.238666429277650, 0.426858243844038, 1.799972319758650,
3.007697177898959, 0.372270006672035, 2.367882325903836, 0.191545163286195, 1.517565471255659,
1.750004351582044, 1.236013671840509, 2.081323476045300, 2.141346897323470, 1.402378494050162,
0.544698152936965, 0.700923468199988, 0.634137874072855, 0.292299453493133, 0.475611960045215,
1.384390337219331, 2.369715926664043, 0.935586970891954, 1.028299144484800, 2.883486469293792
, 2.412885676436705, 1.502666625449783, 2.982736936434333, 1.706454468843206, 0.906120073100622
, 1.473661960328491, 0.748241351472675, 0.836991325956595, 1.509961488710520, 1.225113736935942
, 1.029890543888216, 1.358608202305835, 1.666359355892623, 1.323592437751299, 1.266885170390769
, 1.323660367761004, 1.197556616382116, 0.415219867081348, 1.594635770596585, 3.335047448446035
, 0.935717067162817, 2.664366406420023, 0.922317019697774, 2.086307246777435, 1.101280854500658
, 1.032916883571698, 0.700796651725546, 0.518227310036530, 0.859641628285530, 1.609352902696174
, 1.747723418451391, 1.538490395884064, 0.140361038832643, 1.925029474333574, 2.260668891490430
, 1.716877040260210, 0.295284687152802, 0.974796888317386, 1.561117460932286, 1.617115585994090
, 0.712684884618426, 0.791728102952554, 1.495252766892452, 1.139399282670031, 1.398150348314015
, 0.734533909397005, 0.624418865181972, 1.881415056762913, 1.706681395455110, 2.334683483141081
, 0.477838065222462, 0.634304509316731, 0.456849600683082, 1.160070279761997, 0.655340613711381
, 2.127121229851198, 0.456835914801069, 0.300568039387414, 1.276598603254562, 1.720804090031422
, 0.864730384700170, 0.628029981916123, 0.909872945858993, 0.686886746420088, 1.194705989905012
, 2.176393257858438, 1.408082540391850, 0.462744617618753, 0.995689247143699, 0.335155890849689
, 1.179590017201302, 1.063149176870603, 3.468688654992744, 1.827129780001552, 1.153130138387220
, 2.120636338813882, 0.544011313217379, 0.994288065215423, 2.290060076679768, 2.233778068583924
, 1.581312813059112, 1.387961284638806, 0.917070930336856, 2.344909067035151, 0.516281292935132
, 1.619570115238485, 1.442087344999343, 1.892443909224431, 1.007935276931834, 1.664682222719219
, 1.899024552783311, 0.882368714153905, 1.267711468232034, 2.781870230167854, 1.262515173989300
, 0.895667370955997, 1.390103843633942, 0.945814188732813, 1.680879252209405, 1.033698343725955
, 1.164434112863078, 1.540520689869044, 2.684068016815589, 0.891215308218909, 1.907325227589101
, 1.639214228101874, 2.483108383603044, 0.254728352176505, 0.939581631904974, 1.474208721908681
, 0.813131900087889, 0.723688300231953, 1.575927348326343, 1.399779277625481, 1.336475769769517
, 1.469760951955162, 0.312162051579979, 0.926191271942077, 1.095698311512132, 0.742466620037192
, 1.584565588783017, 1.969369796694313, 0.813142402654688, 1.620637451940408, 0.544472183788396
, 1.903841273483371, 0.546256895921489, 1.332096299659611, 0.954938400592347, 1.813185033558344
, 1.183839081745172, 1.159783992966029, 2.047421367071099, 0.933411156868096, 1.092543634708061
, 2.573430838154017, 0.294001371116989, 2.687145854798348, 0.647676314841560, 1.483222246093897
, 1.328873011650546, 1.499517291077073, 0.946451616282504, 1.391629977859238, 1.825818674800223
, 0.197207089634922, 0.418570979518484, 2.713292260256486, 1.451678603677107, 0.725222188153537
, 1.016524657331659, 0.510160866644535, 0.790663553688482, 0.772267750711634, 0.897737257071539
, 0.574718435129065, 0.924902659911130, 0.509352052679121, 2.076287755824404, 0.445024255426400
, 2.306443399859831, 1.009151694589026, 0.311646355326560, 0.915552448311802, 1.631979165302650
, 1.779435892929737, 1.254791667465325, 1.522546690241251, 2.117005924369452, 0.335708348510442
, 0.850786945794020, 0.307546485903476, 0.659553530770440, 1.595968673282009, 1.599529339207843
, 2.050409047591333, 1.321597656988126, 0.382901575350795, 2.263023675024229, 1.795160219589414
, 0.820728808594631, 1.252616635345433, 2.893059873111469, 1.585968547208113, 1.911105489168721
, 1.065697540675240, 1.127880912464618, 1.282656038601722, 0.791791712034066, 1.662754292624110
, 1.184021211521453, 1.442739185251488, 0.857673288506446, 0.546518081971571, 1.136176847824479
, 0.948827835556975, 1.761649333500106, 1.740961388239338, 1.486044626143792, 0.535345914616625
, 0.208765940502775, 1.281107790531077, 0.845985407399993, 2.367961441281100, 2.813630157287030
, 0.821877833204895, 1.796411857645166, 2.128114489536385, 1.349167308872121, 2.075721258630550
, 2.399008601572707, 1.262250789152573, 1.614544130176768, 1.311344244094387, 0.228900207318000
, 1.087703540854728, 1.441743192607425, 1.213654375953261, 0.965104247192400, 2.352343682973261
, 1.881070184767099, 1.944757925743782, 0.965470015113788, 1.341190290874416, 2.029803572337272
, 2.328337398097465, 1.485947310986503, 0.680661741126981, 1.456629522069083, 0.386727549117631
, 1.021861509017076, 1.482839980680464, 2.329786461679046, 1.825236378759161, 1.151270272972182
, 1.681465022236889, 1.038893153052472, 2.671305569135296, 2.973463508311512, 1.998091967015353
, 0.992439538152367, 1.101359057223470, 0.752694797719731, 1.751820513743222, 2.070842495255286
, 2.213621940109904, 1.278350678290866, 1.351639733749908, 0.567799782374724, 2.144632385787214
, 1.094123263719430, 0.678615107641789, 2.144341891738539, 1.695846624058156, 2.069396249839028
, 0.819027610733285, 1.495651321040951, 1.477482666605742, 0.511932330475827, 1.022837224533765
, 0.802470556959117, 1.588170058614226, 0.816352471969601, 2.128510415901388, 1.871914791729839
, 0.994323676062132, 1.173849936976207, 1.540652455108271, 1.896308447022061, 1.371611808573705
, 1.307706279079749, 0.888355489837264, 1.104161992788381, 1.581802123863791, 2.077336259709684
, 1.597514520759674, 0.193846187739953, 1.498827901810269, 1.074392126178632, 1.073250683084153
, 0.498942436443271, 1.836126539886937, 0.886372885469560, 0.751884958648598, 0.916116650002177
, 0.970681891155015, 1.257679318479529, 1.284798886225563, 1.003879276488743, 1.007685729946785
, 1.203631029712442, 1.463948632472297, 2.455282398854625, 1.600867640016765, 1.010899145846306
, 1.888399192628552, 0.537142702822369, 0.353191429514348, 0.419544177537439, 0.598339442960937
, 0.885310772269136, 0.847519694333472, 0.153295465546788, 1.246051759006313, 0.447732587957780
, 0.562898114036050, 1.412332385111654, 1.980540530235424, 2.704891701651084, 1.300708887507808
, 3.394236570275002, 1.269967710402906, 1.203787442037781, 0.896098313870595, 1.060303799139334
, 1.163522680114773, 0.383891805234107, 2.091377862339729, 0.365559694796422, 1.070541000579430
, 1.872070722040661, 1.001756457029345, 1.378809939003001, 1.847850278543804, 2.085003935284227
, 2.313122510412947, 0.650676881494584, 0.773551613369587, 2.136102299351586, 1.341515248421647
, 1.183940022628347, 1.377562113620296, 1.850185830133746, 1.232112165168803, 0.671923793165544
, 1.099946548218587, 1.056844894152012, 2.601133375396755, 1.391207328945862, 1.541896787253508
, 1.595966007631807, 0.923057590473980, 1.206415179152940, 1.275536301443908, 0.583420447186398
, 1.285040337652167, 1.540648406694559, 1.054438062631050, 1.902387509769504, 1.621409166908371
, 0.944812793164613, 1.100477476680040, 0.988327442233132, 1.728654388101105, 1.628053244977060
, 1.060760561571943, 1.538416018178277, 2.410108392389236, 1.751316245100324, 1.563790463015108
, 0.481219389518454, 0.994165631555275, 1.337016990968870, 1.109088579526755, 0.321407029232422
, 0.720641073906049, 1.895735773634961, 0.177585824024661, 1.996485240483058, 0.403199585960614
, 1.487121772300537, 1.177769008306152, 0.701273995641151, 1.302101876486422, 0.510537251157601
, 1.491444215081535, 1.352963516576160, 0.339422616073620, 0.340565840833962, 0.575265488888648
, 0.199078454324122, 1.068868838035460, 1.889502831203267, 1.386174255623796, 1.211807597487022
, 1.997063801362690, 0.453444401722453, 2.184735356478338, 0.478137766710008, 1.206426055203951
, 0.555876664495711, 1.280274233919441, 0.095813804344955, 1.706079097312628, 1.943477111398666
, 2.230140630510882, 2.946309044620703, 1.186142019401047, 0.795814141941795, 0.460857387230226
, 1.190772316835832, 1.327362504940310, 1.696595922853605, 0.416190042989537, 1.472083830192951
, 1.206395605479538, 0.612524363189761, 2.362058183247366, 1.336246455616561, 1.077916969428414
, 2.385755851351826, 1.460727990062456, 1.096704997935700, 1.913474394478998, 1.233385699260248,
1.270577147048640, 1.509727846778659, 0.956645085964223, 0.739599713571419, 1.315249583679571,
2.008261585625269, 1.021943728886631, 0.488828195617451, 1.083244894832682, 0.844912313732214,
1.013054512108690, 1.893114294699785, 1.016751451332806, 0.994570044372612, 0.945503828258995])
def test_docstrings(): def test_docstrings():
@ -2498,5 +2477,5 @@ def test_docstrings():
doctest.testmod() doctest.testmod()
if __name__ == '__main__': if __name__ == '__main__':
#test_docstrings() test_docstrings()
test_gridcount()

@ -26,15 +26,16 @@ def test0_KDE1D():
array([ 0.2039735 , 0.40252503, 0.54595078, 0.52219649, 0.3906213 , array([ 0.2039735 , 0.40252503, 0.54595078, 0.52219649, 0.3906213 ,
0.26381501, 0.16407362, 0.08270612, 0.02991145, 0.00720821]) 0.26381501, 0.16407362, 0.08270612, 0.02991145, 0.00720821])
>>> kde0.eval_grid_fast(x) >>> kde0.eval_grid_fast(x)
array([ 0.32343789, 0.51366167, 0.55643329, 0.43688805, 0.28972471, array([ 0.20729484, 0.39865044, 0.53716945, 0.5169322 , 0.39060223,
0.19445277, 0.12473331, 0.06195215, 0.02087712, 0.00449567]) 0.26441126, 0.16388801, 0.08388527, 0.03227164, 0.00883579])
>>> f = kde0.eval_grid_fast(); f >>> f = kde0.eval_grid_fast(); f
array([ 0.02076721, 0.0612371 , 0.14515308, 0.27604202, 0.42001793, array([ 0.01149411, 0.03485467, 0.08799292, 0.18568718, 0.32473136,
0.51464781, 0.52131018, 0.45976136, 0.37621768, 0.29589521, 0.46543163, 0.54532016, 0.53005828, 0.44447651, 0.34119612,
0.21985316, 0.1473364 , 0.08502256, 0.04063749, 0.0155788 , 0.25103852, 0.1754952 , 0.11072989, 0.05992731, 0.02687784,
0.00466938]) 0.00974983])
>>> np.trapz(f,kde0.args) >>> np.trapz(f,kde0.args)
array([ 0.99416766]) array([ 0.99500101])
''' '''
def test1_TKDE1D(): def test1_TKDE1D():
''' '''
@ -167,9 +168,13 @@ def test_KDE2D():
>>> kde0 = wk.KDE(data, hs=0.5, alpha=0.0, inc=16) >>> kde0 = wk.KDE(data, hs=0.5, alpha=0.0, inc=16)
>>> kde0.eval_grid(x, x) >>> kde0.eval_grid(x, x)
array([[ 3.27260963e-02, 4.21654678e-02, 5.85338634e-04],
[ 6.78845466e-02, 1.42195839e-01, 1.41676003e-03],
[ 1.39466746e-04, 4.26983850e-03, 2.52736185e-05]])
>>> kde0.eval_grid_fast(x, x) >>> kde0.eval_grid_fast(x, x)
array([[ 0.08670654, 0.12577712, 0.00808478],
[ 0.1411195 , 0.24160579, 0.01816001],
[ 0.0031541 , 0.01553967, 0.00114854]])
''' '''
def test_smooth_params(): def test_smooth_params():
@ -197,16 +202,16 @@ def test_smooth_params():
[ -2.68892467e-02, 3.91283306e-01, 2.38654678e-02], [ -2.68892467e-02, 3.91283306e-01, 2.38654678e-02],
[ 3.18932448e-04, 2.38654678e-02, 4.05123874e-01]]) [ 3.18932448e-04, 2.38654678e-02, 4.05123874e-01]])
>>> gauss.hscv(data) >>> gauss.hscv(data)
array([ 0.16858959, 0.33034332, 0.3046287 ]) array([ 0.16858959, 0.32739383, 0.3046287 ])
>>> gauss.hstt(data) >>> gauss.hstt(data)
array([ 0.18196282, 0.51090571, 0.1111913 ]) array([ 0.18099075, 0.50409881, 0.11018912])
>>> gauss.hste(data) >>> gauss.hste(data)
array([ 0.1683984 , 0.29693232, 0.17974833]) array([ 0.16750009, 0.29059113, 0.17994255])
>>> gauss.hldpi(data) >>> gauss.hldpi(data)
array([ 0.17426948, 0.33672307, 0.31240374]) array([ 0.1732289 , 0.33159097, 0.3107633 ])
''' '''
def test_gridcount_1D(): def test_gridcount_1D():
''' '''
@ -219,11 +224,14 @@ def test_gridcount_1D():
>>> x = np.linspace(0, max(data.ravel()) + 1, 10) >>> x = np.linspace(0, max(data.ravel()) + 1, 10)
>>> dx = x[1] - x[0] >>> dx = x[1] - x[0]
>>> c = wk.gridcount(data, x) >>> c = wk.gridcount(data, x, use_sparse=False)
>>> c >>> c
array([ 1., 6., 7., 2., 1., 2., 1., 0., 0., 0.]) array([ 0.78762626, 1.77520717, 7.99190087, 4.04054449, 1.67156643,
2.38228499, 1.05933195, 0.29153785, 0. , 0. ])
>>> wk.gridcount(data, x, use_sparse=True)
array([ 0.78762626, 1.77520717, 7.99190087, 4.04054449, 1.67156643,
2.38228499, 1.05933195, 0.29153785, 0. , 0. ])
h = plb.plot(x, c, '.') # 1D histogram h = plb.plot(x, c, '.') # 1D histogram
h1 = plb.plot(x, c / dx / N) # 1D probability density plot h1 = plb.plot(x, c / dx / N) # 1D probability density plot
@ -281,19 +289,19 @@ def test_gridcount_3D():
>>> x = np.linspace(0, max(data.ravel()) + 1, 3) >>> x = np.linspace(0, max(data.ravel()) + 1, 3)
>>> dx = x[1] - x[0] >>> dx = x[1] - x[0]
>>> X = np.vstack((x, x, x)) >>> X = np.vstack((x, x, x))
>>> c = wk.gridcount(data, X) >>> c = wk.gridcount(data, X, use_sparse=True)
>>> c >>> c
array([[[ 8.74229894e-01, 1.44969128e+00, 7.49265424e-02], array([[[ 8.74229894e-01, 1.27910940e+00, 1.42033973e-01],
[ 1.94778915e+00, 2.28951650e+00, 8.53886762e-02], [ 1.94778915e+00, 2.59536282e+00, 3.28213680e-01],
[ 1.08429416e-01, 1.10905565e-01, 4.16196568e-04]], [ 1.08429416e-01, 1.69571495e-01, 7.48896775e-03]],
<BLANKLINE> <BLANKLINE>
[[ 1.27910940e+00, 2.58396370e+00, 2.18142488e-01], [[ 1.44969128e+00, 2.58396370e+00, 2.45459949e-01],
[ 2.59536282e+00, 4.49653348e+00, 3.73415131e-01], [ 2.28951650e+00, 4.49653348e+00, 2.73167915e-01],
[ 1.69571495e-01, 3.18733817e-01, 1.62218824e-02]], [ 1.10905565e-01, 3.18733817e-01, 1.12880816e-02]],
<BLANKLINE> <BLANKLINE>
[[ 1.42033973e-01, 2.45459949e-01, 0.00000000e+00], [[ 7.49265424e-02, 2.18142488e-01, 0.00000000e+00],
[ 3.28213680e-01, 2.73167915e-01, 0.00000000e+00], [ 8.53886762e-02, 3.73415131e-01, 0.00000000e+00],
[ 7.48896775e-03, 1.12880816e-02, 0.00000000e+00]]]) [ 4.16196568e-04, 1.62218824e-02, 0.00000000e+00]]])
''' '''
def test_gridcount_4D(): def test_gridcount_4D():
@ -314,43 +322,43 @@ def test_gridcount_4D():
>>> X = np.vstack((x, x, x, x)) >>> X = np.vstack((x, x, x, x))
>>> c = wk.gridcount(data, X) >>> c = wk.gridcount(data, X)
>>> c >>> c
array([[[[ 1.77163904e-01, 3.41883175e-01, 3.13810608e-03], array([[[[ 1.77163904e-01, 1.87720108e-01, 0.00000000e+00],
[ 1.83770124e-01, 3.58861043e-01, 7.05946179e-03], [ 5.72573585e-01, 6.09557834e-01, 0.00000000e+00],
[ 0.00000000e+00, 2.91835067e-03, 6.38695629e-04]], [ 3.48549923e-03, 4.05931870e-02, 0.00000000e+00]],
<BLANKLINE> <BLANKLINE>
[[ 5.72573585e-01, 5.72071865e-01, 6.71606255e-03], [[ 1.83770124e-01, 2.56357594e-01, 0.00000000e+00],
[ 4.35845892e-01, 8.80697705e-01, 1.09099593e-01], [ 4.35845892e-01, 6.14958970e-01, 0.00000000e+00],
[ 0.00000000e+00, 3.63686503e-02, 1.00358265e-02]], [ 3.07662204e-03, 3.58312786e-02, 0.00000000e+00]],
<BLANKLINE> <BLANKLINE>
[[ 3.48549923e-03, 3.46939323e-03, 0.00000000e+00], [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[ 3.07662204e-03, 2.22868504e-01, 6.61257395e-02], [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
[ 0.00000000e+00, 1.88555613e-02, 5.67244468e-03]]], [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]],
<BLANKLINE> <BLANKLINE>
<BLANKLINE> <BLANKLINE>
[[[ 1.87720108e-01, 5.97977973e-01, 2.11731327e-02], [[[ 3.41883175e-01, 5.97977973e-01, 0.00000000e+00],
[ 2.56357594e-01, 6.28962785e-01, 5.44614852e-02], [ 5.72071865e-01, 8.58566538e-01, 0.00000000e+00],
[ 0.00000000e+00, 2.60268355e-02, 5.69610302e-03]], [ 3.46939323e-03, 4.04056116e-02, 0.00000000e+00]],
<BLANKLINE> <BLANKLINE>
[[ 6.09557834e-01, 8.58566538e-01, 4.53139824e-02], [[ 3.58861043e-01, 6.28962785e-01, 0.00000000e+00],
[ 6.14958970e-01, 1.47373158e+00, 1.95935584e-01], [ 8.80697705e-01, 1.47373158e+00, 0.00000000e+00],
[ 0.00000000e+00, 1.07959459e-01, 2.44053065e-02]], [ 2.22868504e-01, 1.18008528e-01, 0.00000000e+00]],
<BLANKLINE> <BLANKLINE>
[[ 4.05931870e-02, 4.04056116e-02, 0.00000000e+00], [[ 2.91835067e-03, 2.60268355e-02, 0.00000000e+00],
[ 3.58312786e-02, 1.18008528e-01, 2.47717418e-02], [ 3.63686503e-02, 1.07959459e-01, 0.00000000e+00],
[ 0.00000000e+00, 7.06358976e-03, 2.12498697e-03]]], [ 1.88555613e-02, 7.06358976e-03, 0.00000000e+00]]],
<BLANKLINE> <BLANKLINE>
<BLANKLINE> <BLANKLINE>
[[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [[[ 3.13810608e-03, 2.11731327e-02, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 6.71606255e-03, 4.53139824e-02, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]], [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]],
<BLANKLINE> <BLANKLINE>
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [[ 7.05946179e-03, 5.44614852e-02, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 1.09099593e-01, 1.95935584e-01, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]], [ 6.61257395e-02, 2.47717418e-02, 0.00000000e+00]],
<BLANKLINE> <BLANKLINE>
[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [[ 6.38695629e-04, 5.69610302e-03, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], [ 1.00358265e-02, 2.44053065e-02, 0.00000000e+00],
[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]]]) [ 5.67244468e-03, 2.12498697e-03, 0.00000000e+00]]]])
h = plb.plot(x, c, '.') # 1D histogram h = plb.plot(x, c, '.') # 1D histogram

@ -137,11 +137,16 @@ class AxisLabels:
self.title = title self.title = title
self.xlab = xlab self.xlab = xlab
self.ylab = ylab self.ylab = ylab
self.zlab = zlab self.zlab = zlab
def __repr__(self):
return self.__str__()
def __str__(self):
return '%s\n%s\n%s\n%s\n' % (self.title, self.xlab, self.ylab, self.zlab)
def copy(self): def copy(self):
newcopy = empty_copy(self) newcopy = empty_copy(self)
newcopy.__dict__.update(self.__dict__) newcopy.__dict__.update(self.__dict__)
return newcopy return newcopy
def labelfig(self): def labelfig(self):
try: try:
h1 = plotbackend.title(self.title) h1 = plotbackend.title(self.title)
@ -169,7 +174,6 @@ class Plotter_1d(object):
step : stair-step plot step : stair-step plot
scatter : scatter plot scatter : scatter plot
""" """
def __init__(self, plotmethod='plot'): def __init__(self, plotmethod='plot'):
self.plotfun = None self.plotfun = None
if plotmethod is None: if plotmethod is None:
@ -179,11 +183,12 @@ class Plotter_1d(object):
self.plotfun = getattr(plotbackend, plotmethod) self.plotfun = getattr(plotbackend, plotmethod)
except: except:
pass pass
def show(self): def show(self):
plotbackend.show() plotbackend.show()
def plot(self, wdata, *args, **kwds): def plot(self, wdata, *args, **kwds):
plotflag = kwds.pop('plotflag', None) plotflag = kwds.pop('plotflag', False)
if plotflag: if plotflag:
h1 = self._plot(plotflag, wdata, **kwds) h1 = self._plot(plotflag, wdata, **kwds)
else: else:
@ -201,6 +206,7 @@ class Plotter_1d(object):
dataCI = () dataCI = ()
h1 = plot1d(x, data, dataCI, plotflag, *args, **kwds) h1 = plot1d(x, data, dataCI, plotflag, *args, **kwds)
return h1 return h1
def plot1d(args, data, dataCI, plotflag, *varargin, **kwds): def plot1d(args, data, dataCI, plotflag, *varargin, **kwds):
plottype = np.mod(plotflag, 10) plottype = np.mod(plotflag, 10)
@ -223,10 +229,10 @@ def plot1d(args, data, dataCI, plotflag, *varargin, **kwds):
else: else:
H = plotbackend.fill_between(args, data, *varargin, **kwds); H = plotbackend.fill_between(args, data, *varargin, **kwds);
scale = plotscale(plotflag); scale = plotscale(plotflag)
logXscale = any(scale == 'x'); logXscale = 'x' in scale
logYscale = any(scale == 'y'); logYscale = 'y' in scale
logZscale = any(scale == 'z'); logZscale = 'z' in scale
ax = plotbackend.gca() ax = plotbackend.gca()
if logXscale: if logXscale:
plotbackend.setp(ax, xscale='log') plotbackend.setp(ax, xscale='log')
@ -346,11 +352,12 @@ def plot2d(wdata, plotflag, *args, **kwds):
else: else:
args1 = tuple((wdata.args,)) + (wdata.data,) + args args1 = tuple((wdata.args,)) + (wdata.data,) + args
if plotflag in (1, 6, 7, 8, 9): if plotflag in (1, 6, 7, 8, 9):
PL = 0 isPL = False
if hasattr(f, 'cl') and len(f.cl) > 0: # check if contour levels is submitted if hasattr(f, 'clevels') and len(f.clevels) > 0: # check if contour levels is submitted
CL = f.cl CL = f.clevels
if hasattr(f, 'pl'): isPL = hasattr(f, 'plevels') and f.plevels is not None
PL = f.pl # levels defines quantile levels? 0=no 1=yes if isPL:
PL = f.plevels # levels defines quantile levels? 0=no 1=yes
else: else:
dmax = np.max(f.data) dmax = np.max(f.data)
dmin = np.min(f.data) dmin = np.min(f.data)
@ -367,10 +374,8 @@ def plot2d(wdata, plotflag, *args, **kwds):
if ncl > 12: if ncl > 12:
ncl = 12 ncl = 12
warnings.warn('Only the first 12 levels will be listed in table.') warnings.warn('Only the first 12 levels will be listed in table.')
if PL:
clvals, isPL = PL[:ncl], True clvals = PL[:ncl] if isPL else clvec[:ncl]
else:
clvals, isPL = clvec[:ncl], False
unused_axcl = cltext(clvals, percent=isPL) # print contour level text unused_axcl = cltext(clvals, percent=isPL) # print contour level text
elif any(plotflag == [7, 9]): elif any(plotflag == [7, 9]):
plotbackend.clabel(h) plotbackend.clabel(h)
@ -390,20 +395,19 @@ def plot2d(wdata, plotflag, *args, **kwds):
plotbackend.colorbar(h) plotbackend.colorbar(h)
else: else:
raise ValueError('unknown option for plotflag') raise ValueError('unknown option for plotflag')
#if any(plotflag==(2:5)) #if any(plotflag==(2:5))
# shading(shad); # shading(shad);
#end #end
# pass # pass
def test_docstrings():
import doctest
doctest.testmod()
def main(): def main():
pass pass
if __name__ == '__main__': if __name__ == '__main__':
if True: #False : # test_docstrings()
import doctest #main()
doctest.testmod()
else:
main()

Loading…
Cancel
Save