Fixed a bug in TransformEstimator

master
Per A Brodtkorb 8 years ago
parent 3961bedcbb
commit 5da3701f5e

@ -546,7 +546,7 @@ class SmoothSpline(PPform):
self.linear_extrapolate(output=False) self.linear_extrapolate(output=False)
def _compute_coefs(self, xx, yy, p=None, var=1): def _compute_coefs(self, xx, yy, p=None, var=1):
x, y = np.atleast_1d(xx, yy) x, y, var = np.atleast_1d(xx, yy, var)
x = x.ravel() x = x.ravel()
dx = np.diff(x) dx = np.diff(x)
must_sort = (dx < 0).any() must_sort = (dx < 0).any()
@ -573,7 +573,7 @@ class SmoothSpline(PPform):
dydx = np.diff(y) / dx dydx = np.diff(y) / dx
if (n == 2): # % straight line if (n == 2): # straight line
coefs = np.vstack([dydx.ravel(), y[0, :]]) coefs = np.vstack([dydx.ravel(), y[0, :]])
else: else:

@ -7,7 +7,7 @@ from __future__ import absolute_import
from .core import TrData from .core import TrData
from .models import TrHermite, TrOchi, TrLinear from .models import TrHermite, TrOchi, TrLinear
from ..stats import edf, skew, kurtosis from ..stats import edf, skew, kurtosis
from ..interpolate import SmoothSpline from ..interpolate import SmoothSpline, interp1d
from scipy.special import ndtri as invnorm from scipy.special import ndtri as invnorm
from scipy.integrate import cumtrapz from scipy.integrate import cumtrapz
import warnings import warnings
@ -109,6 +109,7 @@ class TransformEstimator(object):
if (dy <= 0).any(): if (dy <= 0).any():
dy[dy > 0] = eps dy[dy > 0] = eps
gvar = -(np.hstack((dy, 0)) + np.hstack((0, dy))) / 2 + eps gvar = -(np.hstack((dy, 0)) + np.hstack((0, dy))) / 2 + eps
gvar = interp1d(tr.args, gvar)(tr_raw.args)
pp_tr = SmoothSpline(tr_raw.args, tr_raw.data, p=1, pp_tr = SmoothSpline(tr_raw.args, tr_raw.data, p=1,
lin_extrap=self.linextrap, lin_extrap=self.linextrap,
var=ix * gvar) var=ix * gvar)

Loading…
Cancel
Save