diff --git a/wafo/kdetools/kdetools.py b/wafo/kdetools/kdetools.py index be5728e..f6981ed 100644 --- a/wafo/kdetools/kdetools.py +++ b/wafo/kdetools/kdetools.py @@ -212,18 +212,18 @@ class _KDE(object): pass def get_args(self, xmin=None, xmax=None): - if xmin is None: - xmin = self.xmin - else: - xmin = [min(i, j) for i, j in zip(xmin, self.xmin)] - if xmax is None: - xmax = self.xmax - else: - xmax = [max(i, j) for i, j in zip(xmax, self.xmax)] + sxmin = self.xmin + if xmin is not None: + sxmin = np.minimum(xmin, sxmin) + + sxmax = self.xmax + if xmax is not None: + sxmax = np.maximum(xmax, sxmax) + args = [] inc = self.inc for i in range(self.d): - args.append(np.linspace(xmin[i], xmax[i], inc)) + args.append(np.linspace(sxmin[i], sxmax[i], inc)) return args def eval_grid_fast(self, *args, **kwds): @@ -523,7 +523,7 @@ class TKDE(_KDE): return self._eval_grid_fun(self._eval_grid_fast, *args, **kwds) def _interpolate(self, points, f, *args, **kwds): - ipoints = meshgrid(*args) if self.d > 1 else args + ipoints = meshgrid(*args) # if self.d > 1 else args for i in range(self.d): points[i].shape = -1, points = np.asarray(points).T @@ -536,20 +536,25 @@ class TKDE(_KDE): return fi * (fi > 0) return fi - def _eval_grid_fast(self, *args, **kwds): - if self.L2 is None: - f = self.tkde.eval_grid_fast(*args, **kwds) - self.args = self.tkde.args - return f + def _get_targs(self, args): targs = [] if len(args): targs0 = self._dat2gaus(list(args)) xmin = [min(t) for t in targs0] xmax = [max(t) for t in targs0] targs = self.tkde.get_args(xmin, xmax) + return targs + + def _eval_grid_fast(self, *args, **kwds): + if self.L2 is None: + f = self.tkde.eval_grid_fast(*args, **kwds) + self.args = self.tkde.args + return f + targs = self._get_targs(args) tf = self.tkde.eval_grid_fast(*targs) + self.args = self._gaus2dat(list(self.tkde.args)) - points = meshgrid(*self.args) if self.d > 1 else self.args + points = meshgrid(*self.args) f = self._scale_pdf(tf, points) if len(args): return self._interpolate(points, f, *args, **kwds) @@ -560,7 +565,7 @@ class TKDE(_KDE): return self.tkde.eval_grid(*args, **kwds) targs = self._dat2gaus(list(args)) tf = self.tkde.eval_grid(*targs, **kwds) - points = meshgrid(*args) if self.d > 1 else list(args) + points = meshgrid(*args) f = self._scale_pdf(tf, points) return f @@ -697,7 +702,7 @@ class KDE(_KDE): t = np.trapz(f, x) """ @staticmethod - def _make_grid(dx, d, inc): + def _make_flat_grid(dx, d, inc): Xn = [] x0 = np.linspace(-inc, inc, 2 * inc + 1) for i in range(d): @@ -707,10 +712,9 @@ class KDE(_KDE): for i in range(d): Xnc[i].shape = (-1,) - return Xnc + return np.vstack(Xnc) def _kernel_weights(self, Xn, dx, d, inc): - # Obtain the kernel weights. kw = self.kernel(Xn) norm_fact0 = (kw.sum() * dx.prod() * self.n) norm_fact = (self._norm_factor * self.kernel.norm_factor(d, self.n)) @@ -729,16 +733,16 @@ class KDE(_KDE): d, inc = X.shape dx = X[:, 1] - X[:, 0] - Xnc = self._make_grid(dx, d, inc) + Xnc = self._make_flat_grid(dx, d, inc) - Xn = np.dot(self._inv_hs, np.vstack(Xnc)) + Xn = np.dot(self._inv_hs, Xnc) kw = self._kernel_weights(Xn, dx, d, inc) r = kwds.get('r', 0) if r != 0: - kw *= np.vstack(Xnc) ** r if d > 1 else Xnc[0] ** r - shape0 = (2 * inc, ) * d - kw.shape = shape0 + fun = self._moment_fun(r) + kw *= fun(np.vstack(Xnc)) + kw.shape = (2 * inc, ) * d kw = np.fft.ifftshift(kw) y = kwds.get('y', 1.0) @@ -748,7 +752,7 @@ class KDE(_KDE): # Find the binned kernel weights, c. c = gridcount(self.dataset, X, y=y) # Perform the convolution. - z = np.real(ifftn(fftn(c, s=shape0) * fftn(kw))) + z = np.real(ifftn(fftn(c, s=kw.shape) * fftn(kw))) ix = (slice(0, inc),) * d if r == 0: