diff --git a/wafo/gaussian.py b/wafo/gaussian.py index 2a3a896..1d72e7a 100644 --- a/wafo/gaussian.py +++ b/wafo/gaussian.py @@ -7,7 +7,6 @@ from scipy.special import ndtr as cdfnorm, ndtri as invnorm from scipy.special import erfc import warnings import numpy as np -from .misc import common_shape try: from . import mvn # @UnresolvedImport @@ -836,10 +835,9 @@ def cdfnorm2d(b1, b2, r): # pullman, wa 99164-3113 # email : alangenz@wsu.edu - cshape = common_shape(b1, b2, r, shape=[1, ]) - one = ones(cshape) - - h, k, r = (-b1 * one).ravel(), (-b2 * one).ravel(), (r * one).ravel() + b1, b2, r = np.broadcast_arrays(b1, b2, r) + cshape = b1.shape + h, k, r = -b1.ravel(), -b2.ravel(), r.ravel() bvn = where(abs(r) > 1, nan, 0.0) diff --git a/wafo/misc.py b/wafo/misc.py index f6beab9..c765aea 100644 --- a/wafo/misc.py +++ b/wafo/misc.py @@ -1764,48 +1764,10 @@ def common_shape(*args, ** kwds): -------- broadcast, broadcast_arrays ''' - args = [asarray(x) for x in args] - shapes = [x.shape for x in args] shape = kwds.get('shape') - if shape is not None: - if not isinstance(shape, (list, tuple)): - shape = (shape,) - shapes.append(tuple(shape)) - if len(set(shapes)) == 1: - # Common case where nothing needs to be broadcasted. - return tuple(shapes[0]) - shapes = [list(s) for s in shapes] - nds = [len(s) for s in shapes] - biggest = max(nds) - # Go through each array and prepend dimensions of length 1 to each of the - # shapes in order to make the number of dimensions equal. - for i in range(len(shapes)): - diff = biggest - nds[i] - if diff > 0: - shapes[i] = [1] * diff + shapes[i] - - # Check each dimension for compatibility. A dimension length of 1 is - # accepted as compatible with any other length. - c_shape = [] - for axis in range(biggest): - lengths = [s[axis] for s in shapes] - unique = set(lengths + [1]) - if len(unique) > 2: - # There must be at least two non-1 lengths for this axis. - raise ValueError("shape mismatch: two or more arrays have " - "incompatible dimensions on axis %r." % (axis,)) - elif len(unique) == 2: - # There is exactly one non-1 length. - # The common shape will take this value. - unique.remove(1) - new_length = unique.pop() - c_shape.append(new_length) - else: - # Every array has a length of 1 on this axis. Strides can be left - # alone as nothing is broadcasted. - c_shape.append(1) - - return tuple(c_shape) + x0 = 1 if shape is None else np.ones(shape) + x1 = np.broadcast(x0, *args) + return tuple(x1.shape) def argsreduce(condition, * args): @@ -2787,7 +2749,7 @@ def num2pistr(x, n=3, numerator_max=10, denominator_max=10): return fmt % x -def fourier(data, t=None, T=None, m=None, n=None, method='trapz'): +def fourier(data, t=None, period=None, m=None, n=None, method='trapz'): ''' Returns Fourier coefficients. @@ -2797,7 +2759,7 @@ def fourier(data, t=None, T=None, m=None, n=None, method='trapz'): vector or matrix of row vectors with data points shape p x n. t : array-like vector with n values indexed from 1 to N. - T : real scalar, (default T = t[-1]-t[0]) + period : real scalar, (default T = t[-1]-t[0]) primitive period of signal, i.e., smallest period. m : scalar integer defines no of harmonics desired (default M = N) @@ -2845,19 +2807,12 @@ def fourier(data, t=None, T=None, m=None, n=None, method='trapz'): ''' x = np.atleast_2d(data) p, n = x.shape - if t is None: - t = np.arange(n) - else: - t = np.atleast_1d(t) + t = np.arange(n) if t is None else np.atleast_1d(t) n = len(t) if n is None else n - m = n if n is None else m - T = t[-1] - t[0] if T is None else T - - if method.startswith('trapz'): - intfun = trapz - elif method.startswith('simp'): - intfun = simps + m = n if m is None else m + period = t[-1] - t[0] if period is None else period + intfun = trapz if method.startswith('trapz') else simps # Define the vectors for computing the Fourier coefficients t.shape = (1, -1) @@ -2866,8 +2821,7 @@ def fourier(data, t=None, T=None, m=None, n=None, method='trapz'): a[0] = intfun(x, t, axis=-1) # Compute M-1 more coefficients - tmp = 2 * pi * t / T - # tmp = 2*pi*(0:N-1).'/(N-1); + tmp = 2 * pi * t / period for i in range(1, m): a[i] = intfun(x * cos(i * tmp), t, axis=-1) b[i] = intfun(x * sin(i * tmp), t, axis=-1)