Simplified code

master
pbrod 8 years ago
parent ebed654ddb
commit 46403b1a5f

@ -387,19 +387,24 @@ class TrLinear(TrCommon2):
""" """
def _dat2gauss(self, x, *xi): def _transform(self, ymean, ysigma, xmean, xsigma, x, xi):
sratio = atleast_1d(self.ysigma / self.sigma) sratio = atleast_1d(ysigma / xsigma)
y = (atleast_1d(x) - self.mean) * sratio + self.ymean y = (atleast_1d(x) - xmean) * sratio + ymean
if len(xi) > 0: if len(xi) > 0:
y = [y, ] + [ix * sratio for ix in xi] y = [y] + [ix * sratio for ix in xi]
return y return y
def _dat2gauss(self, x, *xi):
ysigma, xsigma = self.ysigma, self.sigma
ymean, xmean = self.ymean, self.mean
return self._transform(ymean, ysigma, xmean, xsigma, x, xi)
def _gauss2dat(self, y, *yi): def _gauss2dat(self, y, *yi):
sratio = atleast_1d(self.sigma / self.ysigma) ysigma, xsigma = self.ysigma, self.sigma
x = (atleast_1d(y) - self.ymean) * sratio + self.mean ymean, xmean = self.ymean, self.mean
if len(yi) > 0:
x = [x, ] + [iy * sratio for iy in yi] return self._transform(xmean, xsigma, ymean, ysigma, y, yi)
return x
class TrOchi(TrCommon2): class TrOchi(TrCommon2):
@ -524,6 +529,11 @@ class TrOchi(TrCommon2):
mean2 = self._phat[5] mean2 = self._phat[5]
return ga, gb, sigma2, mean2 return ga, gb, sigma2, mean2
def _forward(self, g, xn, gab, idx):
if gab != 0:
np.put(g, idx, (-expm1(-gab * xn[idx])) / gab)
return g
def _dat2gauss(self, x, *xi): def _dat2gauss(self, x, *xi):
if len(xi) > 0: if len(xi) > 0:
raise ValueError('Transforming derivatives is not implemented!') raise ValueError('Transforming derivatives is not implemented!')
@ -537,16 +547,17 @@ class TrOchi(TrCommon2):
igm, = where(xn < 0) igm, = where(xn < 0)
g = xn.copy() g = xn.copy()
g = self._forward(g, xn, ga, igp)
if ga != 0: g = self._forward(g, xn, gb, igm)
np.put(g, igp, (-expm1(-ga * xn[igp])) / ga)
if gb != 0:
np.put(g, igm, (-expm1(-gb * xn[igm])) / gb)
g.shape = shape0 g.shape = shape0
return (g - mean2) * self.ysigma / sigma2 + self.ymean return (g - mean2) * self.ysigma / sigma2 + self.ymean
def _backward(self, xn, gab, idx):
if gab != 0:
np.put(xn, idx, -log1p(-gab * xn[idx]) / gab)
return xn
def _gauss2dat(self, y, *yi): def _gauss2dat(self, y, *yi):
if len(yi) > 0: if len(yi) > 0:
raise ValueError('Transforming derivatives is not implemented!') raise ValueError('Transforming derivatives is not implemented!')
@ -561,11 +572,8 @@ class TrOchi(TrCommon2):
igp, = where(0 <= xn) igp, = where(0 <= xn)
igm, = where(xn < 0) igm, = where(xn < 0)
if ga != 0: xn = self._backward(xn, ga, igp)
np.put(xn, igp, -log1p(-ga * xn[igp]) / ga) xn = self._backward(xn, gb, igm)
if gb != 0:
np.put(xn, igm, -log1p(-gb * xn[igm]) / gb)
xn.shape = yn.shape xn.shape = yn.shape
return sigma * xn + mean return sigma * xn + mean

Loading…
Cancel
Save