diff --git a/wafo/misc.py b/wafo/misc.py index 591ff91..8448dce 100644 --- a/wafo/misc.py +++ b/wafo/misc.py @@ -164,13 +164,15 @@ def piecewise(condlist, funclist, xi=None, fill_value=0.0, args=(), **kw): if xi is None: arrays = () dtype = np.result_type(*funclist) + shape = condlist[0].shape else: if not isinstance(xi, tuple): xi = (xi,) arrays = np.broadcast_arrays(*xi) dtype = np.result_type(*arrays) + shape = arrays[0].shape - out = valarray(condlist[0].shape, fill_value, dtype) + out = valarray(shape, fill_value, dtype) for cond, func in zip(condlist, funclist): if isinstance(func, collections.Callable): temp = tuple(np.extract(cond, arr) for arr in arrays) + args diff --git a/wafo/tests/test_misc.py b/wafo/tests/test_misc.py index 37ccc25..4df73fa 100644 --- a/wafo/tests/test_misc.py +++ b/wafo/tests/test_misc.py @@ -516,8 +516,8 @@ class TestPiecewise(TestCase): x = 5 y = piecewise([[True], [False]], [1, 0], x) - assert_(y.ndim == 0) assert_(y == 1) + assert_(y.ndim == 0) def test_abs_function(self): x = np.linspace(-2.5, 2.5, 6)