From 4b21cc83c2b998fcf06bf86de8ec12bf592ee175 Mon Sep 17 00:00:00 2001 From: pbrod Date: Sun, 15 Jan 2017 13:42:00 +0100 Subject: [PATCH] Simplified code --- wafo/transform/models.py | 63 +++++++++++++++++----------------------- 1 file changed, 27 insertions(+), 36 deletions(-) diff --git a/wafo/transform/models.py b/wafo/transform/models.py index b9460fb..462b578 100644 --- a/wafo/transform/models.py +++ b/wafo/transform/models.py @@ -12,7 +12,7 @@ from numpy import (sqrt, atleast_1d, abs, imag, sign, where, cos, arccos, ceil, expm1, log1p, pi) import numpy as np import warnings -from .core import TrCommon, TrData +from wafo.transform.core import TrCommon, TrData __all__ = ['TrHermite', 'TrLinear', 'TrOchi'] _EPS = np.finfo(float).eps @@ -529,50 +529,41 @@ class TrOchi(TrCommon2): np.put(x1, idx, -fun(-gab * x2[idx]) / gab) return x1 - def _backward(self, xn, y2, gab, idx): - return self._transform(log1p, xn, y2, gab, idx) - - def _forward(self, y2, xn, gab, idx): - return self._transform(expm1, y2, xn, gab, idx) + def _backward(self, yn): + ga, gb, sigma2, mean2 = self._get_par() + y2 = sigma2 * yn.ravel() + mean2 + igp, = where(0 <= y2) + igm, = where(y2 < 0) + xn = y2.copy() + xn = self._transform(log1p, xn, y2, ga, igp) + xn = self._transform(log1p, xn, y2, gb, igm) + return xn - def _dat2gauss(self, x, *xi): - if len(xi) > 0: - raise ValueError('Transforming derivatives is not implemented!') + def _forward(self, xn): ga, gb, sigma2, mean2 = self._get_par() - mean = self.mean - sigma = self.sigma - xn = atleast_1d(x) - shape0 = xn.shape - xn = (xn.ravel() - mean) / sigma + y2 = xn.copy() igp, = where(0 <= xn) igm, = where(xn < 0) + y2 = self._transform(expm1, y2, xn, ga, igp) + y2 = self._transform(expm1, y2, xn, gb, igm) + return (y2 - mean2) / sigma2 - y2 = xn.copy() - y2 = self._forward(y2, xn, ga, igp) - y2 = self._forward(y2, xn, gb, igm) + def _dat2gauss(self, x, *xi): + if len(xi) > 0: + raise ValueError('Transforming derivatives is not implemented!') - y2.shape = shape0 - return (y2 - mean2) * self.ysigma / sigma2 + self.ymean + xn = (atleast_1d(x) - self.mean) / self.sigma + shape0 = xn.shape + yn = np.reshape(self._forward(xn.ravel()), shape0) + return yn * self.ysigma + self.ymean def _gauss2dat(self, y, *yi): if len(yi) > 0: raise ValueError('Transforming derivatives is not implemented!') - - ga, gb, sigma2, mean2 = self._get_par() - mean = self.mean - sigma = self.sigma - yn = (atleast_1d(y) - self.ymean) / self.ysigma - y2 = sigma2 * yn.ravel() + mean2 - - igp, = where(0 <= y2) - igm, = where(y2 < 0) - xn = y2.copy() - xn = self._backward(xn, y2, ga, igp) - xn = self._backward(xn, y2, gb, igm) - - xn.shape = yn.shape - return sigma * xn + mean + shape0 = yn.shape + xn = np.reshape(self._backward(yn.ravel()), shape0) + return xn * self.sigma + self.mean def main(): @@ -590,7 +581,7 @@ def main(): if __name__ == '__main__': if True: # False: # - import doctest - doctest.testmod() + from wafo.testing import test_docstrings + test_docstrings(__file__) else: main()