diff --git a/wafo/numba_misc.py b/wafo/numba_misc.py index 9ef00d0..81a3c69 100644 --- a/wafo/numba_misc.py +++ b/wafo/numba_misc.py @@ -4,12 +4,14 @@ Created on 6. okt. 2016 @author: pab ''' from __future__ import absolute_import, division -from numba import jit, float64, int64, int32, int8, void +from numba import guvectorize, jit, float64, int64, int32, int8, void import numpy as np -@jit(void(int64, int64, int64, int64[:], int8[:], int64, int64)) -def _find_first_cross(ix, start, dcross, ind, y, v, n): +@guvectorize([(int64[:], int8[:], int64[:])], '(n),(n)->(3)') +def _find_first_cross(ind, y, out): + ix, dcross, start, v = 0, 0, 0, 0 + n = len(y) if y[0] < v: dcross = -1 # first is a up-crossing elif y[0] > v: @@ -28,14 +30,16 @@ def _find_first_cross(ix, start, dcross, ind, y, v, n): ix += 1 dcross = 1 # The next crossing is a down-crossing break + out[0] = ix + out[1] = dcross + out[2] = start @jit(int64(int64[:], int8[:])) def _findcross(ind, y): - ix, dcross, start, v = 0, 0, 0, 0 + v = 0 + ix, dcross, start = _find_first_cross(ind, y) n = len(y) - _find_first_cross(ix, start, dcross, ind, y, v, n) - for i in range(start, n - 1): if ((dcross == -1 and y[i] <= v and v < y[i + 1]) or (dcross == 1 and v <= y[i] and y[i + 1] < v)):