diff --git a/wafo/wave_theory/dispersion_relation.py b/wafo/wave_theory/dispersion_relation.py index 86b5d64..138a362 100644 --- a/wafo/wave_theory/dispersion_relation.py +++ b/wafo/wave_theory/dispersion_relation.py @@ -9,13 +9,23 @@ import numpy as np from wafo.misc import lazywhere from numpy import (atleast_1d, sqrt, ones_like, zeros_like, arctan2, where, tanh, sin, cos, sign, inf, - flatnonzero, finfo, cosh, abs) + flatnonzero, finfo, cosh) __all__ = ['k2w', 'w2k'] +def _assert(cond, msg): + if not cond: + raise ValueError(msg) + + +def _assert_warn(cond, msg): + if not cond: + warnings.warn(msg) + + def k2w(k1, k2=0e0, h=inf, g=9.81, u1=0e0, u2=0e0): - ''' Translates from wave number to frequency + """ Translates from wave number to frequency using the dispersion relation Parameters @@ -61,7 +71,7 @@ def k2w(k1, k2=0e0, h=inf, g=9.81, u1=0e0, u2=0e0): array([ 0.3132092 , 1.43530485, 2.00551739]) >>> wsd.k2w(arange(0.01,.5,0.2),h=20)[0] array([ 0.13914927, 1.43498213, 2.00551724]) - ''' + """ k1i, k2i, hi, gi, u1i, u2i = atleast_1d(k1, k2, h, g, u1, u2) @@ -75,21 +85,19 @@ def k2w(k1, k2=0e0, h=inf, g=9.81, u1=0e0, u2=0e0): k = sqrt(k1i ** 2 + k2i ** 2) w = where(k > 0, ku1 + ku2 + sqrt(gi * k * tanh(k * hi)), 0.0) - cond = (w < 0) - if np.any(cond): - txt0 = ''' - Waves and current are in opposite directions - making some of the frequencies negative. - Here we are forcing the negative frequencies to zero. - ''' - warnings.warn(txt0) - w = where(cond, 0.0, w) # force w to zero + cond = (0 <= w) + _assert_warn(np.all(cond), """ + Waves and current are in opposite directions + making some of the frequencies negative. + Here we are forcing the negative frequencies to zero. + """) + w = where(cond, w, 0.0) # force w to zero return w, theta def w2k(w, theta=0.0, h=inf, g=9.81, count_limit=100): - ''' + """ Translates from frequency to wave number using the dispersion relation @@ -136,7 +144,7 @@ def w2k(w, theta=0.0, h=inf, g=9.81, count_limit=100): See also -------- k2w - ''' + """ wi, th, hi, gi = atleast_1d(w, theta, h, g) if wi.size == 0: @@ -147,10 +155,8 @@ def w2k(w, theta=0.0, h=inf, g=9.81, count_limit=100): k2 = k * sin(th) * gi[0] / gi[-1] # size np x nf k1 = k * cos(th) return k1, k2 - - if gi.size > 1: - raise ValueError('Finite depth in combination with 3D normalization' + - ' (len(g)=2) is not implemented yet.') + _assert(gi.size == 1, 'Finite depth in combination with 3D normalization' + ' (len(g)=2) is not implemented yet.') find = flatnonzero eps = finfo(float).eps @@ -193,9 +199,8 @@ def w2k(w, theta=0.0, h=inf, g=9.81, count_limit=100): np.abs(hn) > sqrt(eps)) count += 1 - if count == count_limit: - warnings.warn('W2K did not converge. The maximum error in the ' + - 'last step was: %13.8f' % max(hn[ix])) + _assert_warn(count < count_limit, 'W2K did not converge. ' + 'Maximum error in the last step was: %13.8f' % max(hn[ix])) k.shape = oshape