Fixed failing test for piecewise

master
pbrod 9 years ago
parent c78cf86ae4
commit f8edea49d5

@ -164,13 +164,15 @@ def piecewise(condlist, funclist, xi=None, fill_value=0.0, args=(), **kw):
if xi is None: if xi is None:
arrays = () arrays = ()
dtype = np.result_type(*funclist) dtype = np.result_type(*funclist)
shape = condlist[0].shape
else: else:
if not isinstance(xi, tuple): if not isinstance(xi, tuple):
xi = (xi,) xi = (xi,)
arrays = np.broadcast_arrays(*xi) arrays = np.broadcast_arrays(*xi)
dtype = np.result_type(*arrays) dtype = np.result_type(*arrays)
shape = arrays[0].shape
out = valarray(condlist[0].shape, fill_value, dtype) out = valarray(shape, fill_value, dtype)
for cond, func in zip(condlist, funclist): for cond, func in zip(condlist, funclist):
if isinstance(func, collections.Callable): if isinstance(func, collections.Callable):
temp = tuple(np.extract(cond, arr) for arr in arrays) + args temp = tuple(np.extract(cond, arr) for arr in arrays) + args

@ -516,8 +516,8 @@ class TestPiecewise(TestCase):
x = 5 x = 5
y = piecewise([[True], [False]], [1, 0], x) y = piecewise([[True], [False]], [1, 0], x)
assert_(y.ndim == 0)
assert_(y == 1) assert_(y == 1)
assert_(y.ndim == 0)
def test_abs_function(self): def test_abs_function(self):
x = np.linspace(-2.5, 2.5, 6) x = np.linspace(-2.5, 2.5, 6)

Loading…
Cancel
Save