diff --git a/wafo/transform/models.py b/wafo/transform/models.py index 5b7266b..0d0c97a 100644 --- a/wafo/transform/models.py +++ b/wafo/transform/models.py @@ -387,19 +387,24 @@ class TrLinear(TrCommon2): """ - def _dat2gauss(self, x, *xi): - sratio = atleast_1d(self.ysigma / self.sigma) - y = (atleast_1d(x) - self.mean) * sratio + self.ymean + def _transform(self, ymean, ysigma, xmean, xsigma, x, xi): + sratio = atleast_1d(ysigma / xsigma) + y = (atleast_1d(x) - xmean) * sratio + ymean if len(xi) > 0: - y = [y, ] + [ix * sratio for ix in xi] + y = [y] + [ix * sratio for ix in xi] return y + def _dat2gauss(self, x, *xi): + ysigma, xsigma = self.ysigma, self.sigma + ymean, xmean = self.ymean, self.mean + + return self._transform(ymean, ysigma, xmean, xsigma, x, xi) + def _gauss2dat(self, y, *yi): - sratio = atleast_1d(self.sigma / self.ysigma) - x = (atleast_1d(y) - self.ymean) * sratio + self.mean - if len(yi) > 0: - x = [x, ] + [iy * sratio for iy in yi] - return x + ysigma, xsigma = self.ysigma, self.sigma + ymean, xmean = self.ymean, self.mean + + return self._transform(xmean, xsigma, ymean, ysigma, y, yi) class TrOchi(TrCommon2): @@ -524,6 +529,11 @@ class TrOchi(TrCommon2): mean2 = self._phat[5] return ga, gb, sigma2, mean2 + def _forward(self, g, xn, gab, idx): + if gab != 0: + np.put(g, idx, (-expm1(-gab * xn[idx])) / gab) + return g + def _dat2gauss(self, x, *xi): if len(xi) > 0: raise ValueError('Transforming derivatives is not implemented!') @@ -537,16 +547,17 @@ class TrOchi(TrCommon2): igm, = where(xn < 0) g = xn.copy() - - if ga != 0: - np.put(g, igp, (-expm1(-ga * xn[igp])) / ga) - - if gb != 0: - np.put(g, igm, (-expm1(-gb * xn[igm])) / gb) + g = self._forward(g, xn, ga, igp) + g = self._forward(g, xn, gb, igm) g.shape = shape0 return (g - mean2) * self.ysigma / sigma2 + self.ymean + def _backward(self, xn, gab, idx): + if gab != 0: + np.put(xn, idx, -log1p(-gab * xn[idx]) / gab) + return xn + def _gauss2dat(self, y, *yi): if len(yi) > 0: raise ValueError('Transforming derivatives is not implemented!') @@ -561,11 +572,8 @@ class TrOchi(TrCommon2): igp, = where(0 <= xn) igm, = where(xn < 0) - if ga != 0: - np.put(xn, igp, -log1p(-ga * xn[igp]) / ga) - - if gb != 0: - np.put(xn, igm, -log1p(-gb * xn[igm]) / gb) + xn = self._backward(xn, ga, igp) + xn = self._backward(xn, gb, igm) xn.shape = yn.shape return sigma * xn + mean