diff --git a/wafo/transform/models.py b/wafo/transform/models.py index bbc82b7..b9460fb 100644 --- a/wafo/transform/models.py +++ b/wafo/transform/models.py @@ -395,16 +395,12 @@ class TrLinear(TrCommon2): return y def _dat2gauss(self, x, *xi): - ysigma, xsigma = self.ysigma, self.sigma - ymean, xmean = self.ymean, self.mean - - return self._transform(ymean, ysigma, xmean, xsigma, x, xi) + return self._transform(self.ymean, self.ysigma, + self.mean, self.sigma, x, xi) def _gauss2dat(self, y, *yi): - ysigma, xsigma = self.ysigma, self.sigma - ymean, xmean = self.ymean, self.mean - - return self._transform(xmean, xsigma, ymean, ysigma, y, yi) + return self._transform(self.mean, self.sigma, + self.ymean, self.ysigma, y, yi) class TrOchi(TrCommon2): @@ -512,7 +508,6 @@ class TrOchi(TrCommon2): mean2 = a * sig22 / my2 # % choose the smallest mean self._phat = [sigma1, mean1, gam_a, gam_b, sigma2, mean2] - return def _get_par(self): ''' @@ -529,10 +524,16 @@ class TrOchi(TrCommon2): mean2 = self._phat[5] return ga, gb, sigma2, mean2 - def _forward(self, g, xn, gab, idx): + def _transform(self, fun, x1, x2, gab, idx): if gab != 0: - np.put(g, idx, (-expm1(-gab * xn[idx])) / gab) - return g + np.put(x1, idx, -fun(-gab * x2[idx]) / gab) + return x1 + + def _backward(self, xn, y2, gab, idx): + return self._transform(log1p, xn, y2, gab, idx) + + def _forward(self, y2, xn, gab, idx): + return self._transform(expm1, y2, xn, gab, idx) def _dat2gauss(self, x, *xi): if len(xi) > 0: @@ -546,17 +547,12 @@ class TrOchi(TrCommon2): igp, = where(0 <= xn) igm, = where(xn < 0) - g = xn.copy() - g = self._forward(g, xn, ga, igp) - g = self._forward(g, xn, gb, igm) + y2 = xn.copy() + y2 = self._forward(y2, xn, ga, igp) + y2 = self._forward(y2, xn, gb, igm) - g.shape = shape0 - return (g - mean2) * self.ysigma / sigma2 + self.ymean - - def _backward(self, xn, gab, idx): - if gab != 0: - np.put(xn, idx, -log1p(-gab * xn[idx]) / gab) - return xn + y2.shape = shape0 + return (y2 - mean2) * self.ysigma / sigma2 + self.ymean def _gauss2dat(self, y, *yi): if len(yi) > 0: @@ -567,13 +563,13 @@ class TrOchi(TrCommon2): sigma = self.sigma yn = (atleast_1d(y) - self.ymean) / self.ysigma - xn = sigma2 * yn.ravel() + mean2 - - igp, = where(0 <= xn) - igm, = where(xn < 0) + y2 = sigma2 * yn.ravel() + mean2 - xn = self._backward(xn, ga, igp) - xn = self._backward(xn, gb, igm) + igp, = where(0 <= y2) + igm, = where(y2 < 0) + xn = y2.copy() + xn = self._backward(xn, y2, ga, igp) + xn = self._backward(xn, y2, gb, igm) xn.shape = yn.shape return sigma * xn + mean