Simplified code

master
pbrod 8 years ago
parent 460ae6f819
commit 4b21cc83c2

@ -12,7 +12,7 @@ from numpy import (sqrt, atleast_1d, abs, imag, sign, where, cos, arccos, ceil,
expm1, log1p, pi) expm1, log1p, pi)
import numpy as np import numpy as np
import warnings import warnings
from .core import TrCommon, TrData from wafo.transform.core import TrCommon, TrData
__all__ = ['TrHermite', 'TrLinear', 'TrOchi'] __all__ = ['TrHermite', 'TrLinear', 'TrOchi']
_EPS = np.finfo(float).eps _EPS = np.finfo(float).eps
@ -529,50 +529,41 @@ class TrOchi(TrCommon2):
np.put(x1, idx, -fun(-gab * x2[idx]) / gab) np.put(x1, idx, -fun(-gab * x2[idx]) / gab)
return x1 return x1
def _backward(self, xn, y2, gab, idx): def _backward(self, yn):
return self._transform(log1p, xn, y2, gab, idx) ga, gb, sigma2, mean2 = self._get_par()
y2 = sigma2 * yn.ravel() + mean2
def _forward(self, y2, xn, gab, idx): igp, = where(0 <= y2)
return self._transform(expm1, y2, xn, gab, idx) igm, = where(y2 < 0)
xn = y2.copy()
xn = self._transform(log1p, xn, y2, ga, igp)
xn = self._transform(log1p, xn, y2, gb, igm)
return xn
def _dat2gauss(self, x, *xi): def _forward(self, xn):
if len(xi) > 0:
raise ValueError('Transforming derivatives is not implemented!')
ga, gb, sigma2, mean2 = self._get_par() ga, gb, sigma2, mean2 = self._get_par()
mean = self.mean y2 = xn.copy()
sigma = self.sigma
xn = atleast_1d(x)
shape0 = xn.shape
xn = (xn.ravel() - mean) / sigma
igp, = where(0 <= xn) igp, = where(0 <= xn)
igm, = where(xn < 0) igm, = where(xn < 0)
y2 = self._transform(expm1, y2, xn, ga, igp)
y2 = self._transform(expm1, y2, xn, gb, igm)
return (y2 - mean2) / sigma2
y2 = xn.copy() def _dat2gauss(self, x, *xi):
y2 = self._forward(y2, xn, ga, igp) if len(xi) > 0:
y2 = self._forward(y2, xn, gb, igm) raise ValueError('Transforming derivatives is not implemented!')
y2.shape = shape0 xn = (atleast_1d(x) - self.mean) / self.sigma
return (y2 - mean2) * self.ysigma / sigma2 + self.ymean shape0 = xn.shape
yn = np.reshape(self._forward(xn.ravel()), shape0)
return yn * self.ysigma + self.ymean
def _gauss2dat(self, y, *yi): def _gauss2dat(self, y, *yi):
if len(yi) > 0: if len(yi) > 0:
raise ValueError('Transforming derivatives is not implemented!') raise ValueError('Transforming derivatives is not implemented!')
ga, gb, sigma2, mean2 = self._get_par()
mean = self.mean
sigma = self.sigma
yn = (atleast_1d(y) - self.ymean) / self.ysigma yn = (atleast_1d(y) - self.ymean) / self.ysigma
y2 = sigma2 * yn.ravel() + mean2 shape0 = yn.shape
xn = np.reshape(self._backward(yn.ravel()), shape0)
igp, = where(0 <= y2) return xn * self.sigma + self.mean
igm, = where(y2 < 0)
xn = y2.copy()
xn = self._backward(xn, y2, ga, igp)
xn = self._backward(xn, y2, gb, igm)
xn.shape = yn.shape
return sigma * xn + mean
def main(): def main():
@ -590,7 +581,7 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
if True: # False: # if True: # False: #
import doctest from wafo.testing import test_docstrings
doctest.testmod() test_docstrings(__file__)
else: else:
main() main()

Loading…
Cancel
Save