From 5a948ffbbc86f37d1386e355a6c8d07ff851a775 Mon Sep 17 00:00:00 2001 From: Per A Brodtkorb Date: Thu, 29 Dec 2016 12:24:08 +0100 Subject: [PATCH] Small updates. --- wafo/kdetools/kernels.py | 19 ++++---- wafo/kdetools/tests/test_gridding.py | 66 +++++++++++++++------------- 2 files changed, 44 insertions(+), 41 deletions(-) diff --git a/wafo/kdetools/kernels.py b/wafo/kdetools/kernels.py index 4c2ca49..560d479 100644 --- a/wafo/kdetools/kernels.py +++ b/wafo/kdetools/kernels.py @@ -9,8 +9,9 @@ import warnings import numpy as np from numpy import pi, sqrt, exp, percentile from numpy.fft import fft, ifft -from scipy import optimize, linalg +from scipy import optimize from scipy.special import gamma +from scipy.linalg import sqrtm from wafo.misc import tranproc # , trangood from wafo.kdetools.gridding import gridcount from wafo.dctpack import dct @@ -289,8 +290,8 @@ class _Kernel(object): def kernel(self, x): return self._kernel(np.atleast_2d(x)) - def deriv4_6_8_10(self, t, numout=4): - raise NotImplementedError('Method not implemented for this kernel!') +# def deriv4_6_8_10(self, t, numout=4): +# raise NotImplementedError('Method not implemented for this kernel!') def get_ste_constant(self, n): mu2, R = self.stats[:2] @@ -542,7 +543,7 @@ class Kernel(object): """ - def __init__(self, name, fun='hste'): # 'hns'): + def __init__(self, name, fun='hste'): self.kernel = _MKERNEL_DICT[name[:4]] self.get_smoothing = getattr(self, fun) @@ -681,7 +682,7 @@ class Kernel(object): """ return self.hns(data) / 0.93 - def _hmns_scale(self, d): + def _hmns_scale(self, n, d): name = self.name short_name = name[:4].lower() scale_dict = dict(epan=(8.0 * (d + 4.0) * (2 * sqrt(pi)) ** d / @@ -694,7 +695,7 @@ class Kernel(object): if d > 2 and short_name in ['biwe', 'triw']: raise NotImplementedError('Not implemented for d>2 and ' 'kernel {}'.format(name)) - return scale_dict[short_name] + return scale_dict[short_name] * n ** (-1. / (d + 4)) def hmns(self, data): """Returns Multivariate Normal Scale Estimate of Smoothing Parameter. @@ -742,9 +743,7 @@ class Kernel(object): d, n = a.shape if d == 1: return self.hns(data) - scale = self._hmns_scale(d) - cov_a = np.cov(a) - return scale * linalg.sqrtm(cov_a).real * n ** (-1. / (d + 4)) + return self._hmns_scale(n, d) * np.real(sqrtm(np.cov(a))) @staticmethod def _get_g(k_order_2, mu2, psi_order, n, order): @@ -1069,7 +1068,7 @@ class Kernel(object): True import matplotlib.pyplot as plt - plt.plot(hvec,score) + plt.plot(hvec, score) See also: hste, hbcv, hboot, hos, hldpi, hlscv, hstt, kde, kdefun diff --git a/wafo/kdetools/tests/test_gridding.py b/wafo/kdetools/tests/test_gridding.py index 8e8db21..b4fa4bf 100644 --- a/wafo/kdetools/tests/test_gridding.py +++ b/wafo/kdetools/tests/test_gridding.py @@ -19,6 +19,7 @@ class TestKdeTools(unittest.TestCase): dx = x[1] - x[0] c = wkg.gridcount(data, x) + assert_allclose(c.sum(), len(data)) assert_allclose(c, [0.1430937435034, 5.864465648665, 9.418694957317207, 2.9154367000439, 0.6583089504704, 0.0, @@ -34,6 +35,7 @@ class TestKdeTools(unittest.TestCase): dx = x[1] - x[0] X = np.vstack((x, x)) c = wkg.gridcount(data, X) + assert_allclose(c.sum(), N) assert_allclose(c, [[0.38922806, 0.8987982, 0.34676493, 0.21042807, 0.], [1.15012203, 5.16513541, 3.19250588, 0.55420752, 0.], @@ -52,6 +54,7 @@ class TestKdeTools(unittest.TestCase): dx = x[1] - x[0] X = np.vstack((x, x, x)) c = wkg.gridcount(data, X) + assert_allclose(c.sum(), N) assert_allclose(c, [[[8.74229894e-01, 1.27910940e+00, 1.42033973e-01], [1.94778915e+00, 2.59536282e+00, 3.28213680e-01], @@ -69,43 +72,44 @@ class TestKdeTools(unittest.TestCase): @staticmethod def test_gridcount_4d(): - N = 20 - data = np.reshape(DATA2D, (4, -1)) + N = 10 + data = np.reshape(DATA2D, (4, N)) x = np.linspace(0, max(np.ravel(data)) + 1, 3) dx = x[1] - x[0] X = np.vstack((x, x, x, x)) c = wkg.gridcount(data, X) - assert_allclose(c, - [[[[1.77163904e-01, 1.87720108e-01, 0.0], - [5.72573585e-01, 6.09557834e-01, 0.0], - [3.48549923e-03, 4.05931870e-02, 0.0]], - [[1.83770124e-01, 2.56357594e-01, 0.0], - [4.35845892e-01, 6.14958970e-01, 0.0], - [3.07662204e-03, 3.58312786e-02, 0.0]], - [[0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0]]], - [[[3.41883175e-01, 5.97977973e-01, 0.0], - [5.72071865e-01, 8.58566538e-01, 0.0], - [3.46939323e-03, 4.04056116e-02, 0.0]], - [[3.58861043e-01, 6.28962785e-01, 0.0], - [8.80697705e-01, 1.47373158e+00, 0.0], - [2.22868504e-01, 1.18008528e-01, 0.0]], - [[2.91835067e-03, 2.60268355e-02, 0.0], - [3.63686503e-02, 1.07959459e-01, 0.0], - [1.88555613e-02, 7.06358976e-03, 0.0]]], - [[[3.13810608e-03, 2.11731327e-02, 0.0], - [6.71606255e-03, 4.53139824e-02, 0.0], - [0.0, 0.0, 0.0]], - [[7.05946179e-03, 5.44614852e-02, 0.0], - [1.09099593e-01, 1.95935584e-01, 0.0], - [6.61257395e-02, 2.47717418e-02, 0.0]], - [[6.38695629e-04, 5.69610302e-03, 0.0], - [1.00358265e-02, 2.44053065e-02, 0.0], - [5.67244468e-03, 2.12498697e-03, 0.0]]]]) + truth = [[[[1.77163904e-01, 1.87720108e-01, 0.0], + [5.72573585e-01, 6.09557834e-01, 0.0], + [3.48549923e-03, 4.05931870e-02, 0.0]], + [[1.83770124e-01, 2.56357594e-01, 0.0], + [4.35845892e-01, 6.14958970e-01, 0.0], + [3.07662204e-03, 3.58312786e-02, 0.0]], + [[0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0]]], + [[[3.41883175e-01, 5.97977973e-01, 0.0], + [5.72071865e-01, 8.58566538e-01, 0.0], + [3.46939323e-03, 4.04056116e-02, 0.0]], + [[3.58861043e-01, 6.28962785e-01, 0.0], + [8.80697705e-01, 1.47373158e+00, 0.0], + [2.22868504e-01, 1.18008528e-01, 0.0]], + [[2.91835067e-03, 2.60268355e-02, 0.0], + [3.63686503e-02, 1.07959459e-01, 0.0], + [1.88555613e-02, 7.06358976e-03, 0.0]]], + [[[3.13810608e-03, 2.11731327e-02, 0.0], + [6.71606255e-03, 4.53139824e-02, 0.0], + [0.0, 0.0, 0.0]], + [[7.05946179e-03, 5.44614852e-02, 0.0], + [1.09099593e-01, 1.95935584e-01, 0.0], + [6.61257395e-02, 2.47717418e-02, 0.0]], + [[6.38695629e-04, 5.69610302e-03, 0.0], + [1.00358265e-02, 2.44053065e-02, 0.0], + [5.67244468e-03, 2.12498697e-03, 0.0]]]] + assert_allclose(c.sum(), N) + assert_allclose(c, truth) t = np.trapz(np.trapz(np.trapz(np.trapz(c / dx**4 / N, x), x), x), x) - assert_allclose(t, 0.21183518274521254) + assert_allclose(t, 0.4236703654904251) if __name__ == "__main__":