diff --git a/pywafo/src/wafo/kdetools.py b/pywafo/src/wafo/kdetools.py index 5f99046..dacae61 100644 --- a/pywafo/src/wafo/kdetools.py +++ b/pywafo/src/wafo/kdetools.py @@ -13,7 +13,7 @@ from __future__ import division from itertools import product from misc import tranproc #, trangood from numpy import pi, sqrt, atleast_2d, exp, newaxis #@UnresolvedImport -from scipy import interpolate, linalg +from scipy import interpolate, linalg, sparse from scipy.special import gamma from wafo.misc import meshgrid from wafo.wafodata import WafoData @@ -128,11 +128,11 @@ class _KDE(object): if self.xmin is None: self.xmin = amin - offset else: - self.xmin = self.xmin * np.ones(self.d) + self.xmin = self.xmin * np.ones((self.d,1)) if self.xmax is None: self.xmax = amax + offset else: - self.xmax = self.xmax * np.ones(self.d) + self.xmax = self.xmax * np.ones((self.d,1)) def eval_grid_fast(self, *args, **kwds): """Evaluate the estimated pdf on a grid. @@ -187,14 +187,20 @@ class _KDE(object): else: titlestr = 'Kernel density estimate (%s)' % self.kernel.name kwds2 = dict(title=titlestr) + kwds2['plot_kwds'] = dict(plotflag=1) kwds2.update(**kwds) if self.d == 1: args = args[0] - elif self.d > 1: + wdata = WafoData(f, args, **kwds2) + if self.d > 1: PL = np.r_[10:90:20, 95, 99, 99.9] - ql = qlevels(f, p=PL) - kwds2.setdefault('levels', ql) - return WafoData(f, args, **kwds2) + try: + ql = qlevels(f, p=PL) + wdata.clevels = ql + wdata.plevels = PL + except: + pass + return wdata def _check_shape(self, points): points = atleast_2d(points) @@ -329,17 +335,17 @@ class TKDE(_KDE): tdataset = self._dat2gaus(self.dataset) xmin = self.xmin if xmin is not None: - xmin = self._dat2gaus(xmin) + xmin = self._dat2gaus(np.reshape(xmin,(-1,1))) xmax = self.xmax if xmax is not None: - xmax = self._dat2gaus(xmax) + xmax = self._dat2gaus(np.reshape(xmax,(-1,1))) self.tkde = KDE(tdataset, self.hs, self.kernel, self.alpha, xmin, xmax, self.inc) def _check_xmin(self): if self.L2 is not None: amin = self.dataset.min(axis= -1) L2 = np.atleast_1d(self.L2) * np.ones(self.d) # default no transformation - self.xmin = np.where(L2 != 1, np.maximum(self.xmin, amin / 100.0), self.xmin) + self.xmin = np.where(L2 != 1, np.maximum(self.xmin, amin / 100.0), self.xmin).reshape((-1,1)) def _dat2gaus(self, points): if self.L2 is None: @@ -422,7 +428,7 @@ class TKDE(_KDE): fi = pdf(*args) self.args = args #fi.shape = ipoints[0].shape - return fi + return fi*(fi>0) return f def _eval_grid(self, *args): if self.L2 is None: @@ -602,7 +608,6 @@ class KDE(_KDE): def _eval_grid_fast(self, *args): - # TODO: This does not work correctly yet! Check it. X = np.vstack(args) d, inc = X.shape dx = X[:, 1] - X[:, 0] @@ -1631,6 +1636,95 @@ def mkernel(X, kernel): fun = _MKERNEL_DICT[kernel[:4]] return fun(np.atleast_2d(X)) +def accumsum(accmap, a, size=None, dtype=None): + """ + A sum accumulation function + + Parameters + ---------- + accmap : ndarray + This is the "accumulation map". It maps input (i.e. indices into + `a`) to their destination in the output array. The first `a.ndim` + dimensions of `accmap` must be the same as `a.shape`. That is, + `accmap.shape[:a.ndim]` must equal `a.shape`. For example, if `a` + has shape (15,4), then `accmap.shape[:2]` must equal (15,4). In this + case `accmap[i,j]` gives the index into the output array where + element (i,j) of `a` is to be accumulated. If the output is, say, + a 2D, then `accmap` must have shape (15,4,2). The value in the + last dimension give indices into the output array. If the output is + 1D, then the shape of `accmap` can be either (15,4) or (15,4,1) + a : ndarray + The input data to be accumulated. + size : ndarray or None + The size of the output array. If None, the size will be determined + from `accmap`. + dtype : numpy data type, or None + The data type of the output array. If None, the data type of + `a` is used. + + Returns + ------- + out : ndarray + The accumulated results. + + The shape of `out` is `size` if `size` is given. Otherwise the + shape is determined by the (lexicographically) largest indices of + the output found in `accmap`. + + + Examples + -------- + >>> from numpy import array, prod + >>> a = array([[1,2,3],[4,-1,6],[-1,8,9]]) + >>> a + array([[ 1, 2, 3], + [ 4, -1, 6], + [-1, 8, 9]]) + >>> # Sum the diagonals. + >>> accmap = array([[0,1,2],[2,0,1],[1,2,0]]) + >>> s = accum(accmap, a) + >>> s + array([ 9, 7, 15]) + >>> # A 2D output, from sub-arrays with shapes and positions like this: + >>> # [ (2,2) (2,1)] + >>> # [ (1,2) (1,1)] + >>> accmap = array([ + ... [[0,0],[0,0],[0,1]], + ... [[0,0],[0,0],[0,1]], + ... [[1,0],[1,0],[1,1]]]) + >>> # Accumulate using a product. + >>> accum(accmap, a, func=prod, dtype=float) + array([[ -8., 18.], + [ -8., 9.]]) + >>> # Same accmap, but create an array of lists of values. + >>> accum(accmap, a, func=lambda x: x, dtype='O') + array([[[1, 2, 4, -1], [3, 6]], + [[-1, 8], [9]]], dtype=object) + """ + + # Check for bad arguments and handle the defaults. + if accmap.shape[:a.ndim] != a.shape: + raise ValueError("The initial dimensions of accmap must be the same as a.shape") + if dtype is None: + dtype = a.dtype + + adims = tuple(range(a.ndim)) + if size is None: + size = 1 + np.squeeze(np.apply_over_axes(np.max, accmap, axes=adims)) + size = np.atleast_1d(size) + if len(size)>1: + binx = accmap[:,0] + biny = accmap[:,1] + out = np.asarray(sparse.coo_matrix((a.ravel(), (binx, biny)),shape=size, dtype=dtype).todense()).reshape(size) + else: + binx = accmap.ravel() + zero = np.zeros(len(binx)) + out = np.asarray(sparse.coo_matrix((a.ravel(), (binx, zero)),shape=(size,1), dtype=dtype).todense()).reshape(size) + + return out + +def accumsum2(accmap, a, size): + return np.bincount(accmap.ravel(), a.ravel(), np.array(size).max()) def accum(accmap, a, func=None, size=None, fill_value=0, dtype=None): """ @@ -2156,7 +2250,7 @@ def bitget(int_type, offset): return np.bitwise_and(int_type, 1 << offset) >> offset -def gridcount(data, X): +def gridcount(data, X, use_sparse=False): ''' Returns D-dimensional histogram using linear binning. @@ -2231,25 +2325,28 @@ def gridcount(data, X): raise ValueError('X does not include whole range of the data!') csiz = np.repeat(inc, d) - + if use_sparse: + acfun = accumsum # faster than accum + else: + acfun = accumsum2 #accum binx = np.asarray(np.floor((dat - xlo[:, newaxis]) / dx), dtype=int) w = dx.prod() abs = np.abs if d == 1: x.shape = (-1,) - c = (accum(binx, (x[binx + 1] - dat), size=[inc, ]) + - accum(binx, (dat - x[binx]), size=[inc, ])) / w - elif d == 2: - b2 = binx[1] - b1 = binx[0] - c_ = np.c_ - stk = np.vstack - c = (accum(c_[b1, b2] , abs(np.prod(stk([X[0, b1 + 1], X[1, b2 + 1]]) - dat, axis=0)), size=[inc, inc]) + - accum(c_[b1 + 1, b2] , abs(np.prod(stk([X[0, b1], X[1, b2 + 1]]) - dat, axis=0)), size=[inc, inc]) + - accum(c_[b1 , b2 + 1], abs(np.prod(stk([X[0, b1 + 1], X[1, b2]]) - dat, axis=0)), size=[inc, inc]) + - accum(c_[b1 + 1, b2 + 1], abs(np.prod(stk([X[0, b1], X[1, b2]]) - dat, axis=0)), size=[inc, inc])) / w - c = c.T # make sure c is stored in the same way as meshgrid + c = (acfun(binx, (x[binx + 1] - dat), size=[inc, ]) + + acfun(binx+1, (dat - x[binx]), size=[inc, ])) / w +# elif d == 2: +# b2 = binx[1] +# b1 = binx[0] +# c_ = np.c_ +# stk = np.vstack +# c = (acfun(c_[b1, b2] , abs(np.prod(stk([X[0, b1 + 1], X[1, b2 + 1]]) - dat, axis=0)), size=[inc, inc]) + +# acfun(c_[b1 + 1, b2] , abs(np.prod(stk([X[0, b1], X[1, b2 + 1]]) - dat, axis=0)), size=[inc, inc]) + +# acfun(c_[b1 , b2 + 1], abs(np.prod(stk([X[0, b1 + 1], X[1, b2]]) - dat, axis=0)), size=[inc, inc]) + +# acfun(c_[b1 + 1, b2 + 1], abs(np.prod(stk([X[0, b1], X[1, b2]]) - dat, axis=0)), size=[inc, inc])) / w +# c = c.T # make sure c is stored in the same way as meshgrid else: # % d>2 Nc = csiz.prod() @@ -2270,13 +2367,13 @@ def gridcount(data, X): b1 = np.sum((binx + bt0[one]) * fact1, axis=0) #linear index to c bt2 = bt0[two] + fact2 b2 = binx + bt2 # linear index to X - c += accum(b1, abs(np.prod(X1[b2] - dat, axis=0)), size=(Nc,)) + c += acfun(b1, abs(np.prod(X1[b2] - dat, axis=0)), size=(Nc,)) - c = np.reshape(c / w, csiz, order='C') - # TODO: check that the flipping of axis is correct + c = np.reshape(c / w, csiz, order='F') T = range(d) - T[-2], T[-1] = T[-1], T[-2] + T[1], T[0] = T[0], T[1] + #T[-2], T[-1] = T[-1], T[-2] c = c.transpose(*T) # make sure c is stored in the same way as meshgrid return c @@ -2291,7 +2388,7 @@ def kde_demo1(): ''' import scipy.stats as st - x = np.linspace(-4, 4) + x = np.linspace(-4, 4, 101) x0 = x / 2.0 data = np.random.normal(loc=0, scale=1.0, size=7) #rndnorm(0,1,7,1); kernel = Kernel('gaus') @@ -2330,7 +2427,6 @@ def kde_demo2(): pylab.figure(0) f.plot() - pylab.plot(x, st.rayleigh.pdf(x, scale=1), ':') #plotnorm((data).^(L2)) % gives a straight line => L2 = 0.5 reasonable @@ -2343,154 +2439,37 @@ def kde_demo2(): pylab.plot(x, st.rayleigh.pdf(x, scale=1), ':') pylab.figure(0) + +def kde_demo3(): + '''Demonstrate the difference between and ordinary-KDE -def test_gridcount(): - import numpy as np - #import wafo.kdetools as wk - from matplotlib import pyplot as plb - data = data_rayleigh() - N = len(data) - x = np.linspace(0,max(data)+1,50) - dx = x[1]-x[0] - - c = gridcount(data,x) - - ctr = np.array([ 0, 4, 10, 14, 15, 23, 16, 18, 21, 19, 37, 32, 24, 29, 29, - 24, 29, 26, 14, 13, 23, 9, 13, 11, 7, 12, 5, 2, 2, 6, - 2, 2, 5, 0, 0, 0, 1, 2, 1, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0]) - print(np.abs(c-ctr)<1e-13) - pdf = c/dx/N - h = plb.plot(x,c,'.') # 1D histogram - plb.show() - pass - - data1 = data.reshape((2,-1)) - c2 = gridcount(data1, np.vstack((x,x))) - - c2t = np.array([ 0, 0.635018844262034, 1.170430267508894, 0.480210926714613, - 1.256122839305450, 2.050244222017545, 1.250782602003382, - 1.253065702416950, 1.295571917793612, 1.978725535031301, - 0.829707237562507, 1.842636337195244, 2.767829900577593, - 1.449607074995753, 2.759640664415913, 1.634036650764552, - 1.990159690320205, 1.201953891214720, 1.277182991907633, - 1.293002868977407, 1.157268941919115, 1.275443411253732, - 1.132629693243210, 1.418284942741350, 0.770571572744340, - 0.475119743103730, 0.982244208825375, 0.681834272971076, - 0.359044910769366, 0.582570635672141, 0.658412627992049, - 0.784814887272479, 0.448937228228751, 0.314220262358783, - 0, 0, 0, 0.404906838651598, 0.113164094088927, - 0, 0, 0, 0, - 0, 0, 0, 0, - 0, 0, 0]) - print(np.abs((c2t-c2.max(axis=-1)))<1e-13) - - data4 = data.reshape((4,-1)) - x = np.linspace(0,max(data)+1,11) - c4 = gridcount(data4, np.vstack((x,x,x,x))) - print(np.abs((c2t-c2.max(axis=-1)))<1e-13) - -def data_rayleigh(): - return np.array([1.412423619721313, 0.936610012402041, 3.408880790209544, 0.712493911648517, 1.453856820100018, - 1.362971623321745, 0.989738148038997, 0.553839936552347, 0.225638048436888, 1.045606709473107, - 0.637908826214993, 1.608606426103143, 0.961884939327567, 1.919572795331000, 1.627957520304931, - 1.301044712134641, 0.623895791202139, 2.512180741739924, 0.785268132885580, 2.273629106639021, - 0.711768125619732, 0.967169434614618, 0.427942932230904, 1.429667825110794, 0.631194936581454, - 0.303636149404372, 1.602725333691658, 0.923957338868325, 1.470119382037774, 0.984169729516054, - 1.405066725863521, 0.209225286647146, 2.197407087587261, 1.795680986321718, 1.655186235334962, - 1.831529484073858, 0.983242434909240, 1.385965094654130, 1.309069260384021, 1.228928476737294, - 0.802097056076411, 0.756115979892675, 1.096194354290486, 0.718886376439466, 1.806619521908829, - 2.924438974501607, 0.246782936313644, 1.238666429277650, 0.426858243844038, 1.799972319758650, - 3.007697177898959, 0.372270006672035, 2.367882325903836, 0.191545163286195, 1.517565471255659, - 1.750004351582044, 1.236013671840509, 2.081323476045300, 2.141346897323470, 1.402378494050162, - 0.544698152936965, 0.700923468199988, 0.634137874072855, 0.292299453493133, 0.475611960045215, - 1.384390337219331, 2.369715926664043, 0.935586970891954, 1.028299144484800, 2.883486469293792 -, 2.412885676436705, 1.502666625449783, 2.982736936434333, 1.706454468843206, 0.906120073100622 -, 1.473661960328491, 0.748241351472675, 0.836991325956595, 1.509961488710520, 1.225113736935942 -, 1.029890543888216, 1.358608202305835, 1.666359355892623, 1.323592437751299, 1.266885170390769 -, 1.323660367761004, 1.197556616382116, 0.415219867081348, 1.594635770596585, 3.335047448446035 -, 0.935717067162817, 2.664366406420023, 0.922317019697774, 2.086307246777435, 1.101280854500658 -, 1.032916883571698, 0.700796651725546, 0.518227310036530, 0.859641628285530, 1.609352902696174 -, 1.747723418451391, 1.538490395884064, 0.140361038832643, 1.925029474333574, 2.260668891490430 -, 1.716877040260210, 0.295284687152802, 0.974796888317386, 1.561117460932286, 1.617115585994090 -, 0.712684884618426, 0.791728102952554, 1.495252766892452, 1.139399282670031, 1.398150348314015 -, 0.734533909397005, 0.624418865181972, 1.881415056762913, 1.706681395455110, 2.334683483141081 -, 0.477838065222462, 0.634304509316731, 0.456849600683082, 1.160070279761997, 0.655340613711381 -, 2.127121229851198, 0.456835914801069, 0.300568039387414, 1.276598603254562, 1.720804090031422 -, 0.864730384700170, 0.628029981916123, 0.909872945858993, 0.686886746420088, 1.194705989905012 -, 2.176393257858438, 1.408082540391850, 0.462744617618753, 0.995689247143699, 0.335155890849689 -, 1.179590017201302, 1.063149176870603, 3.468688654992744, 1.827129780001552, 1.153130138387220 -, 2.120636338813882, 0.544011313217379, 0.994288065215423, 2.290060076679768, 2.233778068583924 -, 1.581312813059112, 1.387961284638806, 0.917070930336856, 2.344909067035151, 0.516281292935132 -, 1.619570115238485, 1.442087344999343, 1.892443909224431, 1.007935276931834, 1.664682222719219 -, 1.899024552783311, 0.882368714153905, 1.267711468232034, 2.781870230167854, 1.262515173989300 -, 0.895667370955997, 1.390103843633942, 0.945814188732813, 1.680879252209405, 1.033698343725955 -, 1.164434112863078, 1.540520689869044, 2.684068016815589, 0.891215308218909, 1.907325227589101 -, 1.639214228101874, 2.483108383603044, 0.254728352176505, 0.939581631904974, 1.474208721908681 -, 0.813131900087889, 0.723688300231953, 1.575927348326343, 1.399779277625481, 1.336475769769517 -, 1.469760951955162, 0.312162051579979, 0.926191271942077, 1.095698311512132, 0.742466620037192 -, 1.584565588783017, 1.969369796694313, 0.813142402654688, 1.620637451940408, 0.544472183788396 -, 1.903841273483371, 0.546256895921489, 1.332096299659611, 0.954938400592347, 1.813185033558344 -, 1.183839081745172, 1.159783992966029, 2.047421367071099, 0.933411156868096, 1.092543634708061 -, 2.573430838154017, 0.294001371116989, 2.687145854798348, 0.647676314841560, 1.483222246093897 -, 1.328873011650546, 1.499517291077073, 0.946451616282504, 1.391629977859238, 1.825818674800223 -, 0.197207089634922, 0.418570979518484, 2.713292260256486, 1.451678603677107, 0.725222188153537 -, 1.016524657331659, 0.510160866644535, 0.790663553688482, 0.772267750711634, 0.897737257071539 -, 0.574718435129065, 0.924902659911130, 0.509352052679121, 2.076287755824404, 0.445024255426400 -, 2.306443399859831, 1.009151694589026, 0.311646355326560, 0.915552448311802, 1.631979165302650 -, 1.779435892929737, 1.254791667465325, 1.522546690241251, 2.117005924369452, 0.335708348510442 -, 0.850786945794020, 0.307546485903476, 0.659553530770440, 1.595968673282009, 1.599529339207843 -, 2.050409047591333, 1.321597656988126, 0.382901575350795, 2.263023675024229, 1.795160219589414 -, 0.820728808594631, 1.252616635345433, 2.893059873111469, 1.585968547208113, 1.911105489168721 -, 1.065697540675240, 1.127880912464618, 1.282656038601722, 0.791791712034066, 1.662754292624110 -, 1.184021211521453, 1.442739185251488, 0.857673288506446, 0.546518081971571, 1.136176847824479 -, 0.948827835556975, 1.761649333500106, 1.740961388239338, 1.486044626143792, 0.535345914616625 -, 0.208765940502775, 1.281107790531077, 0.845985407399993, 2.367961441281100, 2.813630157287030 -, 0.821877833204895, 1.796411857645166, 2.128114489536385, 1.349167308872121, 2.075721258630550 -, 2.399008601572707, 1.262250789152573, 1.614544130176768, 1.311344244094387, 0.228900207318000 -, 1.087703540854728, 1.441743192607425, 1.213654375953261, 0.965104247192400, 2.352343682973261 -, 1.881070184767099, 1.944757925743782, 0.965470015113788, 1.341190290874416, 2.029803572337272 -, 2.328337398097465, 1.485947310986503, 0.680661741126981, 1.456629522069083, 0.386727549117631 -, 1.021861509017076, 1.482839980680464, 2.329786461679046, 1.825236378759161, 1.151270272972182 -, 1.681465022236889, 1.038893153052472, 2.671305569135296, 2.973463508311512, 1.998091967015353 -, 0.992439538152367, 1.101359057223470, 0.752694797719731, 1.751820513743222, 2.070842495255286 -, 2.213621940109904, 1.278350678290866, 1.351639733749908, 0.567799782374724, 2.144632385787214 -, 1.094123263719430, 0.678615107641789, 2.144341891738539, 1.695846624058156, 2.069396249839028 -, 0.819027610733285, 1.495651321040951, 1.477482666605742, 0.511932330475827, 1.022837224533765 -, 0.802470556959117, 1.588170058614226, 0.816352471969601, 2.128510415901388, 1.871914791729839 -, 0.994323676062132, 1.173849936976207, 1.540652455108271, 1.896308447022061, 1.371611808573705 -, 1.307706279079749, 0.888355489837264, 1.104161992788381, 1.581802123863791, 2.077336259709684 -, 1.597514520759674, 0.193846187739953, 1.498827901810269, 1.074392126178632, 1.073250683084153 -, 0.498942436443271, 1.836126539886937, 0.886372885469560, 0.751884958648598, 0.916116650002177 -, 0.970681891155015, 1.257679318479529, 1.284798886225563, 1.003879276488743, 1.007685729946785 -, 1.203631029712442, 1.463948632472297, 2.455282398854625, 1.600867640016765, 1.010899145846306 -, 1.888399192628552, 0.537142702822369, 0.353191429514348, 0.419544177537439, 0.598339442960937 -, 0.885310772269136, 0.847519694333472, 0.153295465546788, 1.246051759006313, 0.447732587957780 -, 0.562898114036050, 1.412332385111654, 1.980540530235424, 2.704891701651084, 1.300708887507808 -, 3.394236570275002, 1.269967710402906, 1.203787442037781, 0.896098313870595, 1.060303799139334 -, 1.163522680114773, 0.383891805234107, 2.091377862339729, 0.365559694796422, 1.070541000579430 -, 1.872070722040661, 1.001756457029345, 1.378809939003001, 1.847850278543804, 2.085003935284227 -, 2.313122510412947, 0.650676881494584, 0.773551613369587, 2.136102299351586, 1.341515248421647 -, 1.183940022628347, 1.377562113620296, 1.850185830133746, 1.232112165168803, 0.671923793165544 -, 1.099946548218587, 1.056844894152012, 2.601133375396755, 1.391207328945862, 1.541896787253508 -, 1.595966007631807, 0.923057590473980, 1.206415179152940, 1.275536301443908, 0.583420447186398 -, 1.285040337652167, 1.540648406694559, 1.054438062631050, 1.902387509769504, 1.621409166908371 -, 0.944812793164613, 1.100477476680040, 0.988327442233132, 1.728654388101105, 1.628053244977060 -, 1.060760561571943, 1.538416018178277, 2.410108392389236, 1.751316245100324, 1.563790463015108 -, 0.481219389518454, 0.994165631555275, 1.337016990968870, 1.109088579526755, 0.321407029232422 -, 0.720641073906049, 1.895735773634961, 0.177585824024661, 1.996485240483058, 0.403199585960614 -, 1.487121772300537, 1.177769008306152, 0.701273995641151, 1.302101876486422, 0.510537251157601 -, 1.491444215081535, 1.352963516576160, 0.339422616073620, 0.340565840833962, 0.575265488888648 -, 0.199078454324122, 1.068868838035460, 1.889502831203267, 1.386174255623796, 1.211807597487022 -, 1.997063801362690, 0.453444401722453, 2.184735356478338, 0.478137766710008, 1.206426055203951 -, 0.555876664495711, 1.280274233919441, 0.095813804344955, 1.706079097312628, 1.943477111398666 -, 2.230140630510882, 2.946309044620703, 1.186142019401047, 0.795814141941795, 0.460857387230226 -, 1.190772316835832, 1.327362504940310, 1.696595922853605, 0.416190042989537, 1.472083830192951 -, 1.206395605479538, 0.612524363189761, 2.362058183247366, 1.336246455616561, 1.077916969428414 -, 2.385755851351826, 1.460727990062456, 1.096704997935700, 1.913474394478998, 1.233385699260248, -1.270577147048640, 1.509727846778659, 0.956645085964223, 0.739599713571419, 1.315249583679571, -2.008261585625269, 1.021943728886631, 0.488828195617451, 1.083244894832682, 0.844912313732214, -1.013054512108690, 1.893114294699785, 1.016751451332806, 0.994570044372612, 0.945503828258995]) + KDEDEMO3 shows that the transformation KDE is a better estimate for + Rayleigh distributed data around 0 than the ordinary KDE. + ''' + import scipy.stats as st + data = st.rayleigh.rvs(scale=1, size=(2,300)) + + #x = np.linspace(1.5e-3, 5, 55) + + kde = KDE(data) + f = kde(output='plot', title='Ordinary KDE', plotflag=1) + pylab.figure(0) + f.plot() + + pylab.plot(data[0], data[1], '.') + + #plotnorm((data).^(L2)) % gives a straight line => L2 = 0.5 reasonable + + tkde = TKDE(data, L2=0.5) + ft = tkde.eval_grid_fast(output='plot', title='Transformation KDE', plotflag=1) + + + pylab.figure(1) + ft.plot() + + pylab.plot(data[0],data[1], '.') + + pylab.figure(0) def test_docstrings(): @@ -2498,5 +2477,5 @@ def test_docstrings(): doctest.testmod() if __name__ == '__main__': - #test_docstrings() - test_gridcount() + test_docstrings() + \ No newline at end of file diff --git a/pywafo/src/wafo/test/test_kdetools.py b/pywafo/src/wafo/test/test_kdetools.py index f589f87..324a035 100644 --- a/pywafo/src/wafo/test/test_kdetools.py +++ b/pywafo/src/wafo/test/test_kdetools.py @@ -26,15 +26,16 @@ def test0_KDE1D(): array([ 0.2039735 , 0.40252503, 0.54595078, 0.52219649, 0.3906213 , 0.26381501, 0.16407362, 0.08270612, 0.02991145, 0.00720821]) >>> kde0.eval_grid_fast(x) - array([ 0.32343789, 0.51366167, 0.55643329, 0.43688805, 0.28972471, - 0.19445277, 0.12473331, 0.06195215, 0.02087712, 0.00449567]) + array([ 0.20729484, 0.39865044, 0.53716945, 0.5169322 , 0.39060223, + 0.26441126, 0.16388801, 0.08388527, 0.03227164, 0.00883579]) + >>> f = kde0.eval_grid_fast(); f - array([ 0.02076721, 0.0612371 , 0.14515308, 0.27604202, 0.42001793, - 0.51464781, 0.52131018, 0.45976136, 0.37621768, 0.29589521, - 0.21985316, 0.1473364 , 0.08502256, 0.04063749, 0.0155788 , - 0.00466938]) + array([ 0.01149411, 0.03485467, 0.08799292, 0.18568718, 0.32473136, + 0.46543163, 0.54532016, 0.53005828, 0.44447651, 0.34119612, + 0.25103852, 0.1754952 , 0.11072989, 0.05992731, 0.02687784, + 0.00974983]) >>> np.trapz(f,kde0.args) - array([ 0.99416766]) + array([ 0.99500101]) ''' def test1_TKDE1D(): ''' @@ -167,9 +168,13 @@ def test_KDE2D(): >>> kde0 = wk.KDE(data, hs=0.5, alpha=0.0, inc=16) >>> kde0.eval_grid(x, x) - + array([[ 3.27260963e-02, 4.21654678e-02, 5.85338634e-04], + [ 6.78845466e-02, 1.42195839e-01, 1.41676003e-03], + [ 1.39466746e-04, 4.26983850e-03, 2.52736185e-05]]) >>> kde0.eval_grid_fast(x, x) - + array([[ 0.08670654, 0.12577712, 0.00808478], + [ 0.1411195 , 0.24160579, 0.01816001], + [ 0.0031541 , 0.01553967, 0.00114854]]) ''' def test_smooth_params(): @@ -197,16 +202,16 @@ def test_smooth_params(): [ -2.68892467e-02, 3.91283306e-01, 2.38654678e-02], [ 3.18932448e-04, 2.38654678e-02, 4.05123874e-01]]) >>> gauss.hscv(data) - array([ 0.16858959, 0.33034332, 0.3046287 ]) + array([ 0.16858959, 0.32739383, 0.3046287 ]) >>> gauss.hstt(data) - array([ 0.18196282, 0.51090571, 0.1111913 ]) + array([ 0.18099075, 0.50409881, 0.11018912]) >>> gauss.hste(data) - array([ 0.1683984 , 0.29693232, 0.17974833]) + array([ 0.16750009, 0.29059113, 0.17994255]) >>> gauss.hldpi(data) - array([ 0.17426948, 0.33672307, 0.31240374]) + array([ 0.1732289 , 0.33159097, 0.3107633 ]) ''' def test_gridcount_1D(): ''' @@ -219,11 +224,14 @@ def test_gridcount_1D(): >>> x = np.linspace(0, max(data.ravel()) + 1, 10) >>> dx = x[1] - x[0] - >>> c = wk.gridcount(data, x) + >>> c = wk.gridcount(data, x, use_sparse=False) >>> c - array([ 1., 6., 7., 2., 1., 2., 1., 0., 0., 0.]) - - + array([ 0.78762626, 1.77520717, 7.99190087, 4.04054449, 1.67156643, + 2.38228499, 1.05933195, 0.29153785, 0. , 0. ]) + >>> wk.gridcount(data, x, use_sparse=True) + array([ 0.78762626, 1.77520717, 7.99190087, 4.04054449, 1.67156643, + 2.38228499, 1.05933195, 0.29153785, 0. , 0. ]) + h = plb.plot(x, c, '.') # 1D histogram h1 = plb.plot(x, c / dx / N) # 1D probability density plot @@ -281,19 +289,19 @@ def test_gridcount_3D(): >>> x = np.linspace(0, max(data.ravel()) + 1, 3) >>> dx = x[1] - x[0] >>> X = np.vstack((x, x, x)) - >>> c = wk.gridcount(data, X) + >>> c = wk.gridcount(data, X, use_sparse=True) >>> c - array([[[ 8.74229894e-01, 1.44969128e+00, 7.49265424e-02], - [ 1.94778915e+00, 2.28951650e+00, 8.53886762e-02], - [ 1.08429416e-01, 1.10905565e-01, 4.16196568e-04]], + array([[[ 8.74229894e-01, 1.27910940e+00, 1.42033973e-01], + [ 1.94778915e+00, 2.59536282e+00, 3.28213680e-01], + [ 1.08429416e-01, 1.69571495e-01, 7.48896775e-03]], - [[ 1.27910940e+00, 2.58396370e+00, 2.18142488e-01], - [ 2.59536282e+00, 4.49653348e+00, 3.73415131e-01], - [ 1.69571495e-01, 3.18733817e-01, 1.62218824e-02]], + [[ 1.44969128e+00, 2.58396370e+00, 2.45459949e-01], + [ 2.28951650e+00, 4.49653348e+00, 2.73167915e-01], + [ 1.10905565e-01, 3.18733817e-01, 1.12880816e-02]], - [[ 1.42033973e-01, 2.45459949e-01, 0.00000000e+00], - [ 3.28213680e-01, 2.73167915e-01, 0.00000000e+00], - [ 7.48896775e-03, 1.12880816e-02, 0.00000000e+00]]]) + [[ 7.49265424e-02, 2.18142488e-01, 0.00000000e+00], + [ 8.53886762e-02, 3.73415131e-01, 0.00000000e+00], + [ 4.16196568e-04, 1.62218824e-02, 0.00000000e+00]]]) ''' def test_gridcount_4D(): @@ -314,43 +322,43 @@ def test_gridcount_4D(): >>> X = np.vstack((x, x, x, x)) >>> c = wk.gridcount(data, X) >>> c - array([[[[ 1.77163904e-01, 3.41883175e-01, 3.13810608e-03], - [ 1.83770124e-01, 3.58861043e-01, 7.05946179e-03], - [ 0.00000000e+00, 2.91835067e-03, 6.38695629e-04]], + array([[[[ 1.77163904e-01, 1.87720108e-01, 0.00000000e+00], + [ 5.72573585e-01, 6.09557834e-01, 0.00000000e+00], + [ 3.48549923e-03, 4.05931870e-02, 0.00000000e+00]], - [[ 5.72573585e-01, 5.72071865e-01, 6.71606255e-03], - [ 4.35845892e-01, 8.80697705e-01, 1.09099593e-01], - [ 0.00000000e+00, 3.63686503e-02, 1.00358265e-02]], + [[ 1.83770124e-01, 2.56357594e-01, 0.00000000e+00], + [ 4.35845892e-01, 6.14958970e-01, 0.00000000e+00], + [ 3.07662204e-03, 3.58312786e-02, 0.00000000e+00]], - [[ 3.48549923e-03, 3.46939323e-03, 0.00000000e+00], - [ 3.07662204e-03, 2.22868504e-01, 6.61257395e-02], - [ 0.00000000e+00, 1.88555613e-02, 5.67244468e-03]]], + [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], + [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], + [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]], - [[[ 1.87720108e-01, 5.97977973e-01, 2.11731327e-02], - [ 2.56357594e-01, 6.28962785e-01, 5.44614852e-02], - [ 0.00000000e+00, 2.60268355e-02, 5.69610302e-03]], + [[[ 3.41883175e-01, 5.97977973e-01, 0.00000000e+00], + [ 5.72071865e-01, 8.58566538e-01, 0.00000000e+00], + [ 3.46939323e-03, 4.04056116e-02, 0.00000000e+00]], - [[ 6.09557834e-01, 8.58566538e-01, 4.53139824e-02], - [ 6.14958970e-01, 1.47373158e+00, 1.95935584e-01], - [ 0.00000000e+00, 1.07959459e-01, 2.44053065e-02]], + [[ 3.58861043e-01, 6.28962785e-01, 0.00000000e+00], + [ 8.80697705e-01, 1.47373158e+00, 0.00000000e+00], + [ 2.22868504e-01, 1.18008528e-01, 0.00000000e+00]], - [[ 4.05931870e-02, 4.04056116e-02, 0.00000000e+00], - [ 3.58312786e-02, 1.18008528e-01, 2.47717418e-02], - [ 0.00000000e+00, 7.06358976e-03, 2.12498697e-03]]], + [[ 2.91835067e-03, 2.60268355e-02, 0.00000000e+00], + [ 3.63686503e-02, 1.07959459e-01, 0.00000000e+00], + [ 1.88555613e-02, 7.06358976e-03, 0.00000000e+00]]], - [[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], - [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], + [[[ 3.13810608e-03, 2.11731327e-02, 0.00000000e+00], + [ 6.71606255e-03, 4.53139824e-02, 0.00000000e+00], [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]], - [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], - [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], - [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]], + [[ 7.05946179e-03, 5.44614852e-02, 0.00000000e+00], + [ 1.09099593e-01, 1.95935584e-01, 0.00000000e+00], + [ 6.61257395e-02, 2.47717418e-02, 0.00000000e+00]], - [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], - [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00], - [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00]]]]) + [[ 6.38695629e-04, 5.69610302e-03, 0.00000000e+00], + [ 1.00358265e-02, 2.44053065e-02, 0.00000000e+00], + [ 5.67244468e-03, 2.12498697e-03, 0.00000000e+00]]]]) h = plb.plot(x, c, '.') # 1D histogram diff --git a/pywafo/src/wafo/wafodata.py b/pywafo/src/wafo/wafodata.py index f9cfe5d..062cec1 100644 --- a/pywafo/src/wafo/wafodata.py +++ b/pywafo/src/wafo/wafodata.py @@ -137,11 +137,16 @@ class AxisLabels: self.title = title self.xlab = xlab self.ylab = ylab - self.zlab = zlab + self.zlab = zlab + def __repr__(self): + return self.__str__() + def __str__(self): + return '%s\n%s\n%s\n%s\n' % (self.title, self.xlab, self.ylab, self.zlab) def copy(self): newcopy = empty_copy(self) newcopy.__dict__.update(self.__dict__) return newcopy + def labelfig(self): try: h1 = plotbackend.title(self.title) @@ -169,7 +174,6 @@ class Plotter_1d(object): step : stair-step plot scatter : scatter plot """ - def __init__(self, plotmethod='plot'): self.plotfun = None if plotmethod is None: @@ -179,11 +183,12 @@ class Plotter_1d(object): self.plotfun = getattr(plotbackend, plotmethod) except: pass + def show(self): plotbackend.show() def plot(self, wdata, *args, **kwds): - plotflag = kwds.pop('plotflag', None) + plotflag = kwds.pop('plotflag', False) if plotflag: h1 = self._plot(plotflag, wdata, **kwds) else: @@ -201,6 +206,7 @@ class Plotter_1d(object): dataCI = () h1 = plot1d(x, data, dataCI, plotflag, *args, **kwds) return h1 + def plot1d(args, data, dataCI, plotflag, *varargin, **kwds): plottype = np.mod(plotflag, 10) @@ -223,10 +229,10 @@ def plot1d(args, data, dataCI, plotflag, *varargin, **kwds): else: H = plotbackend.fill_between(args, data, *varargin, **kwds); - scale = plotscale(plotflag); - logXscale = any(scale == 'x'); - logYscale = any(scale == 'y'); - logZscale = any(scale == 'z'); + scale = plotscale(plotflag) + logXscale = 'x' in scale + logYscale = 'y' in scale + logZscale = 'z' in scale ax = plotbackend.gca() if logXscale: plotbackend.setp(ax, xscale='log') @@ -346,11 +352,12 @@ def plot2d(wdata, plotflag, *args, **kwds): else: args1 = tuple((wdata.args,)) + (wdata.data,) + args if plotflag in (1, 6, 7, 8, 9): - PL = 0 - if hasattr(f, 'cl') and len(f.cl) > 0: # check if contour levels is submitted - CL = f.cl - if hasattr(f, 'pl'): - PL = f.pl # levels defines quantile levels? 0=no 1=yes + isPL = False + if hasattr(f, 'clevels') and len(f.clevels) > 0: # check if contour levels is submitted + CL = f.clevels + isPL = hasattr(f, 'plevels') and f.plevels is not None + if isPL: + PL = f.plevels # levels defines quantile levels? 0=no 1=yes else: dmax = np.max(f.data) dmin = np.min(f.data) @@ -367,10 +374,8 @@ def plot2d(wdata, plotflag, *args, **kwds): if ncl > 12: ncl = 12 warnings.warn('Only the first 12 levels will be listed in table.') - if PL: - clvals, isPL = PL[:ncl], True - else: - clvals, isPL = clvec[:ncl], False + + clvals = PL[:ncl] if isPL else clvec[:ncl] unused_axcl = cltext(clvals, percent=isPL) # print contour level text elif any(plotflag == [7, 9]): plotbackend.clabel(h) @@ -390,20 +395,19 @@ def plot2d(wdata, plotflag, *args, **kwds): plotbackend.colorbar(h) else: raise ValueError('unknown option for plotflag') - - #if any(plotflag==(2:5)) # shading(shad); #end # pass - + +def test_docstrings(): + import doctest + doctest.testmod() + def main(): pass if __name__ == '__main__': - if True: #False : # - import doctest - doctest.testmod() - else: - main() + test_docstrings() + #main()