Fixed tests.

master
Per A Brodtkorb 7 years ago
parent a148caa586
commit f2ddf0010a

@ -181,10 +181,10 @@ def accum(accmap, a, func=None, shape=None, fill_value=0, dtype=None):
def create_array_of_python_lists(accmap, a, shape): def create_array_of_python_lists(accmap, a, shape):
vals = np.empty(shape, dtype='O') vals = np.empty(shape, dtype='O')
for s in product(*[np.arange(k) for k in shape]): for s in product(*[list(range(k)) for k in shape]):
vals[s] = [] vals[s] = []
for s in product(*[np.arange(k) for k in a.shape]): for s in product(*[list(range(k)) for k in a.shape]):
indx = tuple(accmap[s]) indx = tuple(accmap[s])
val = a[s] val = a[s]
vals[indx].append(val) vals[indx].append(val)

@ -94,7 +94,7 @@ def valarray(shape, value=np.NaN, typecode=None):
return out return out
def piecewise(condlist, funclist, xi=None, fillvalue=0.0, args=(), **kw): def piecewise(condlist, funclist, xi=None, fill_value=0.0, args=(), **kw):
""" """
Evaluate a piecewise-defined function. Evaluate a piecewise-defined function.
@ -193,8 +193,8 @@ def piecewise(condlist, funclist, xi=None, fillvalue=0.0, args=(), **kw):
def check_shapes(condlist, funclist): def check_shapes(condlist, funclist):
nc, nf = len(condlist), len(funclist) nc, nf = len(condlist), len(funclist)
_assert(nc in [nf - 1, nf], "function list and condition list" _assert(nc in [nf - 1, nf],
" must be the same length") "function list and condition list must be the same length")
check_shapes(condlist, funclist) check_shapes(condlist, funclist)
@ -210,7 +210,7 @@ def piecewise(condlist, funclist, xi=None, fillvalue=0.0, args=(), **kw):
arrays = np.broadcast_arrays(*xi) arrays = np.broadcast_arrays(*xi)
dtype = np.result_type(*arrays) dtype = np.result_type(*arrays)
out = valarray(condlist[0].shape, fillvalue, dtype) out = valarray(condlist[0].shape, fill_value, dtype)
for cond, func in zip(condlist, funclist): for cond, func in zip(condlist, funclist):
if cond.any(): if cond.any():
if isinstance(func, collections.Callable): if isinstance(func, collections.Callable):

@ -4,7 +4,7 @@ Created on 6. okt. 2016
@author: pab @author: pab
""" """
from __future__ import absolute_import, division from __future__ import absolute_import, division
from numba import guvectorize, jit, float64, int64, int32, int8, void from numba import jit, float64, int64, int32, int8, void
import numpy as np import numpy as np

Loading…
Cancel
Save