diff --git a/wafo/dctpack.py b/wafo/dctpack.py index 3741e52..9a01198 100644 --- a/wafo/dctpack.py +++ b/wafo/dctpack.py @@ -1,10 +1,10 @@ import numpy as np -from scipy.fftpack import dct as _dct -from scipy.fftpack import idct as _idct +from scipy.fftpack import dct as _dct, idct as _idct +from scipy.fftpack import dst as _dst, idst as _idst import os path = os.path.dirname(os.path.realpath(__file__)) -__all__ = ['dct', 'idct', 'dctn', 'idctn'] +__all__ = ['dct', 'idct', 'dctn', 'idctn', 'dst', 'idst', 'dstn', 'idstn'] def dct(x, type=2, n=None, axis=-1, norm='ortho'): # @ReservedAssignment @@ -112,6 +112,10 @@ def dct(x, type=2, n=None, axis=-1, norm='ortho'): # @ReservedAssignment return _dct(x, type, n, axis, norm) +def dst(x, type=2, n=None, axis=-1, norm='ortho'): # @ReservedAssignment + return _dst(x, type, n, axis, norm) + + def idct(x, type=2, n=None, axis=-1, norm='ortho'): # @ReservedAssignment ''' Return the Inverse Discrete Cosine Transform of an arbitrary type sequence. @@ -152,6 +156,10 @@ def idct(x, type=2, n=None, axis=-1, norm='ortho'): # @ReservedAssignment return _idct(x, type, n, axis, norm) +def idst(x, type=2, n=None, axis=-1, norm='ortho'): # @ReservedAssignment + return _idst(x, type, n, axis, norm) + + def _get_shape(y, shape, axes): if shape is None: if axes is None: @@ -268,6 +276,12 @@ def dctn(x, type=2, shape=None, axes=None, # @ReservedAssignment return _raw_dctn(y, type, shape, axes, norm, dct) +def dstn(x, type=2, shape=None, axes=None, # @ReservedAssignment + norm='ortho'): + y = np.atleast_1d(x) + return _raw_dctn(y, type, shape, axes, norm, dst) + + def idctn(x, type=2, shape=None, axes=None, # @ReservedAssignment norm='ortho'): '''Return inverse N-D Discrete Cosine Transform of array x. @@ -282,6 +296,12 @@ def idctn(x, type=2, shape=None, axes=None, # @ReservedAssignment return _raw_dctn(y, type, shape, axes, norm, idct) +def idstn(x, type=2, shape=None, axes=None, # @ReservedAssignment + norm='ortho'): + y = np.atleast_1d(x) + return _raw_dctn(y, type, shape, axes, norm, idst) + + def num_leading_ones(x): first = 0 for i, xi in enumerate(x):