From c4e5bfd66a1299fb61a6ae8ac2d3ae54ab53bf4e Mon Sep 17 00:00:00 2001 From: pbrod Date: Wed, 25 May 2016 21:04:04 +0200 Subject: [PATCH] Deleted namedtuple, fixed bug in piecewise --- wafo/misc.py | 23 ++++--- wafo/namedtuple.py | 144 ---------------------------------------- wafo/tests/test_misc.py | 44 ++++++------ 3 files changed, 35 insertions(+), 176 deletions(-) delete mode 100644 wafo/namedtuple.py diff --git a/wafo/misc.py b/wafo/misc.py index 1d900f8..c113601 100644 --- a/wafo/misc.py +++ b/wafo/misc.py @@ -54,7 +54,7 @@ def valarray(shape, value=np.NaN, typecode=None): return out -def piecewise(xi, condlist, funclist, fill_value=0.0, args=(), **kw): +def piecewise(condlist, funclist, xi=None, fill_value=0.0, args=(), **kw): """ Evaluate a piecewise-defined function. @@ -63,8 +63,6 @@ def piecewise(xi, condlist, funclist, fill_value=0.0, args=(), **kw): Parameters ---------- - xi : tuple - input arguments to the functions in funclist, i.e., (x0, x1,...., xn) condlist : list of bool arrays Each boolean array corresponds to a function in `funclist`. Wherever `condlist[i]` is True, `funclist[i](x0,x1,...,xn)` is used as the @@ -81,6 +79,8 @@ def piecewise(xi, condlist, funclist, fill_value=0.0, args=(), **kw): or a scalar value as output. If, instead of a callable, a scalar is provided then a constant function (``lambda x: scalar``) is assumed. + xi : tuple + input arguments to the functions in funclist, i.e., (x0, x1,...., xn) fill_value : scalar fill value for out of range values. Default 0. args : tuple, optional @@ -157,23 +157,26 @@ def piecewise(xi, condlist, funclist, fill_value=0.0, args=(), **kw): " must be the same length") check_shapes(condlist, funclist) - if not isinstance(xi, tuple): - xi = (xi,) condlist = np.broadcast_arrays(*condlist) if len(condlist) == len(funclist)-1: condlist.append(otherwise_condition(condlist)) - - arrays = np.broadcast_arrays(*xi) - dtype = np.result_type(*arrays) + if xi is None: + arrays = condlist + dtype = np.result_type(*funclist) + else: + if not isinstance(xi, tuple): + xi = (xi,) + arrays = np.broadcast_arrays(*xi) + dtype = np.result_type(*arrays) out = valarray(arrays[0].shape, fill_value, dtype) for cond, func in zip(condlist, funclist): if isinstance(func, collections.Callable): temp = tuple(np.extract(cond, arr) for arr in arrays) + args np.place(out, cond, func(*temp, **kw)) - else: # func is a scalar value - np.place(out, cond, func) + else: # func is a scalar value or a list + np.putmask(out, cond, func) return out diff --git a/wafo/namedtuple.py b/wafo/namedtuple.py deleted file mode 100644 index 857bd30..0000000 --- a/wafo/namedtuple.py +++ /dev/null @@ -1,144 +0,0 @@ -from operator import itemgetter as _itemgetter -from keyword import iskeyword as _iskeyword -import sys as _sys - - -def namedtuple(typename, field_names, verbose=False): - """Returns a new subclass of tuple with named fields. - - >>> Point = namedtuple('Point', 'x y') - >>> Point.__doc__ # docstring for the new class - 'Point(x, y)' - >>> p = Point(11, y=22) # instantiate with positional args or keywords - >>> p[0] + p[1] # indexable like a plain tuple - 33 - >>> x, y = p # unpack like a regular tuple - >>> x, y - (11, 22) - >>> p.x + p.y # fields also accessable by name - 33 - >>> d = p._asdict() # convert to a dictionary - >>> d['x'] - 11 - >>> Point(**d) # convert from a dictionary - Point(x=11, y=22) - >>> p._replace(x=100) # _replace() is like str.replace() but targets named fields - Point(x=100, y=22) - - """ - - # Parse and validate the field names. Validation serves two purposes, - # generating informative error messages and preventing template injection - # attacks. - if isinstance(field_names, basestring): - # names separated by whitespace and/or commas - field_names = field_names.replace(',', ' ').split() - field_names = tuple(field_names) - for name in (typename,) + field_names: - if not min(c.isalnum() or c == '_' for c in name): - raise ValueError( - 'Type names and field names can only contain alphanumeric ' + - 'characters and underscores: %r' % name) - if _iskeyword(name): - raise ValueError( - 'Type names and field names cannot be a keyword: %r' % name) - if name[0].isdigit(): - raise ValueError('Type names and field names cannot start ' + - 'with a number: %r' % name) - seen_names = set() - for name in field_names: - if name.startswith('_'): - raise ValueError( - 'Field names cannot start with an underscore: %r' % name) - if name in seen_names: - raise ValueError('Encountered duplicate field name: %r' % name) - seen_names.add(name) - - # Create and fill-in the class template - numfields = len(field_names) - # tuple repr without parens or quotes - argtxt = repr(field_names).replace("'", "")[1:-1] - reprtxt = ', '.join('%s=%%r' % name for name in field_names) - dicttxt = ', '.join('%r: t[%d]' % (name, pos) - for pos, name in enumerate(field_names)) - template = '''class %(typename)s(tuple): - '%(typename)s(%(argtxt)s)' \n - __slots__ = () \n - _fields = %(field_names)r \n - def __new__(cls, %(argtxt)s): - return tuple.__new__(cls, (%(argtxt)s)) \n - @classmethod - def _make(cls, iterable, new=tuple.__new__, len=len): - 'Make a new %(typename)s object from a sequence or iterable' - result = new(cls, iterable) - if len(result) != %(numfields)d: - raise TypeError('Expected %(numfields)d arguments, got %%d' %% len(result)) - return result \n - def __repr__(self): - return '%(typename)s(%(reprtxt)s)' %% self \n - def _asdict(t): - 'Return a new dict which maps field names to their values' - return {%(dicttxt)s} \n - def _replace(self, **kwds): - 'Return a new %(typename)s object replacing specified fields with new values' - result = self._make(map(kwds.pop, %(field_names)r, self)) - if kwds: - raise ValueError('Got unexpected field names: %%r' %% kwds.keys()) - return result \n\n''' % locals() - for i, name in enumerate(field_names): - template += ' %s = property(itemgetter(%d))\n' % (name, i) - if verbose: - print template - - # Execute the template string in a temporary namespace - namespace = dict(itemgetter=_itemgetter) - try: - exec template in namespace - except SyntaxError, e: - raise SyntaxError(e.message + ':\n' + template) - result = namespace[typename] - - # For pickling to work, the __module__ variable needs to be set to the - # frame where the named tuple is created. Bypass this step in enviroments - # where sys._getframe is not defined (Jython for example). - if hasattr(_sys, '_getframe'): - result.__module__ = _sys._getframe(1).f_globals['__name__'] - - return result - - -if __name__ == '__main__': - # verify that instances can be pickled - from cPickle import loads, dumps - Point = namedtuple('Point', 'x, y', True) - p = Point(x=10, y=20) - assert p == loads(dumps(p)) - - # test and demonstrate ability to override methods - class Point(namedtuple('Point', 'x y')): - - @property - def hypot(self): - return (self.x ** 2 + self.y ** 2) ** 0.5 - - def __str__(self): - return 'Point: x=%6.3f y=%6.3f hypot=%6.3f' % (self.x, self.y, - self.hypot) - - for p in Point(3, 4), Point(14, 5), Point(9. / 7, 6): - print(p) - - class Point(namedtuple('Point', 'x y')): - '''Point class with optimized _make() and _replace() - without error-checking - ''' - _make = classmethod(tuple.__new__) - - def _replace(self, _map=map, **kwds): - return self._make(_map(kwds.get, ('x', 'y'), self)) - - print(Point(11, 22)._replace(x=100)) - - import doctest - TestResults = namedtuple('TestResults', 'failed attempted') - print(TestResults(*doctest.testmod())) diff --git a/wafo/tests/test_misc.py b/wafo/tests/test_misc.py index 9ad18bc..37ccc25 100644 --- a/wafo/tests/test_misc.py +++ b/wafo/tests/test_misc.py @@ -475,64 +475,64 @@ def test_tranproc(): class TestPiecewise(TestCase): def test_condition_is_single_bool_list(self): - assert_raises(ValueError, piecewise, [0, 0], [True, False], [1]) + assert_raises(ValueError, piecewise, [True, False], [1], [0, 0]) def test_condition_is_list_of_single_bool_list(self): - x = piecewise([0, 0], [[True, False]], [1]) + x = piecewise([[True, False]], [1], [0, 0]) assert_array_equal(x, [1, 0]) def test_conditions_is_list_of_single_bool_array(self): - x = piecewise([0, 0], [np.array([True, False])], [1]) + x = piecewise([np.array([True, False])], [1], [0, 0]) assert_array_equal(x, [1, 0]) def test_condition_is_single_int_array(self): - assert_raises(ValueError, piecewise, [0, 0], np.array([1, 0]), [1]) + assert_raises(ValueError, piecewise, np.array([1, 0]), [1], [0, 0]) def test_condition_is_list_of_single_int_array(self): - x = piecewise([0, 0], [np.array([1, 0])], [1]) + x = piecewise([np.array([1, 0])], [1], [0, 0]) assert_array_equal(x, [1, 0]) def test_simple(self): - x = piecewise([0, 0], [[False, True]], [lambda x:-1]) + x = piecewise([[False, True]], [lambda x:-1], [0, 0]) assert_array_equal(x, [0, -1]) - x = piecewise([1, 2], [[True, False], [False, True]], [3, 4]) + x = piecewise([[True, False], [False, True]], [3, 4], [1, 2]) assert_array_equal(x, [3, 4]) def test_default(self): # No value specified for x[1], should be 0 - x = piecewise([1, 2], [[True, False]], [2]) + x = piecewise([[True, False]], [2], [1, 2],) assert_array_equal(x, [2, 0]) # Should set x[1] to 3 - x = piecewise([1, 2], [[True, False]], [2, 3]) + x = piecewise([[True, False]], [2, 3], [1, 2]) assert_array_equal(x, [2, 3]) def test_0d(self): x = np.array(3) - y = piecewise(x, [x > 3], [4, 0]) + y = piecewise([x > 3], [4, 0], x) assert_(y.ndim == 0) assert_(y == 0) x = 5 - y = piecewise(x, [[True], [False]], [1, 0]) + y = piecewise([[True], [False]], [1, 0], x) assert_(y.ndim == 0) assert_(y == 1) def test_abs_function(self): x = np.linspace(-2.5, 2.5, 6) - vals = piecewise((x,), [x < 0, x >= 0], [lambda x: -x, lambda x: x]) + vals = piecewise([x < 0, x >= 0], [lambda x: -x, lambda x: x], (x,)) assert_array_equal(vals, [2.5, 1.5, 0.5, 0.5, 1.5, 2.5]) def test_abs_function_with_scalar(self): x = np.array(-2.5) - vals = piecewise((x,), [x < 0, x >= 0], [lambda x: -x, lambda x: x]) + vals = piecewise([x < 0, x >= 0], [lambda x: -x, lambda x: x], (x,)) assert_(vals == 2.5) def test_otherwise_condition(self): x = np.linspace(-2.5, 2.5, 6) - vals = piecewise((x,), [x < 0, ], [lambda x: -x, lambda x: x]) + vals = piecewise([x < 0, ], [lambda x: -x, lambda x: x], (x,)) assert_array_equal(vals, [2.5, 1.5, 0.5, 0.5, 1.5, 2.5]) def test_passing_further_args_to_fun(self): @@ -542,24 +542,24 @@ class TestPiecewise(TestCase): def fun1(x, y, scale=1.): return x*y/scale x = np.linspace(-2.5, 2.5, 6) - vals = piecewise((x,), [x < 0, ], [fun0, fun1], args=(2.,), scale=2.) + vals = piecewise([x < 0, ], [fun0, fun1], (x,), args=(2.,), scale=2.) assert_array_equal(vals, [2.5, 1.5, 0.5, 0.5, 1.5, 2.5]) def test_step_function(self): x = np.linspace(-2.5, 2.5, 6) - vals = piecewise(x, [x < 0, x >= 0], [-1, 1]) + vals = piecewise([x < 0, x >= 0], [-1, 1], x) assert_array_equal(vals, [-1., -1., -1., 1., 1., 1.]) def test_step_function_with_scalar(self): x = 1 - vals = piecewise(x, [x < 0, x >= 0], [-1, 1]) + vals = piecewise([x < 0, x >= 0], [-1, 1], x) assert_(vals == 1) def test_function_with_two_args(self): x = np.linspace(-2, 2, 5) X, Y = np.meshgrid(x, x) vals = piecewise( - (X, Y), [X * Y < 0, ], [lambda x, y: -x * y, lambda x, y: x * y]) + [X * Y < 0, ], [lambda x, y: -x * y, lambda x, y: x * y], (X, Y)) assert_array_equal(vals, [[4., 2., -0., 2., 4.], [2., 1., -0., 1., 2.], [-0., -0., 0., 0., 0.], @@ -569,8 +569,8 @@ class TestPiecewise(TestCase): def test_fill_value_and_function_with_two_args(self): x = np.linspace(-2, 2, 5) X, Y = np.meshgrid(x, x) - vals = piecewise((X, Y), [X * Y < -0.5, X * Y > 0.5], - [lambda x, y: -x * y, lambda x, y: x * y], + vals = piecewise([X * Y < -0.5, X * Y > 0.5], + [lambda x, y: -x * y, lambda x, y: x * y], (X, Y), fill_value=np.nan) nan = np.nan assert_array_equal(vals, [[4., 2., nan, 2., 4.], @@ -582,8 +582,8 @@ class TestPiecewise(TestCase): def test_fill_value2_and_function_with_two_args(self): x = np.linspace(-2, 2, 5) X, Y = np.meshgrid(x, x) - vals = piecewise((X, Y), [X * Y < -0.5, X * Y > 0.5], - [lambda x, y: -x * y, lambda x, y: x * y, np.nan]) + vals = piecewise([X * Y < -0.5, X * Y > 0.5], + [lambda x, y: -x * y, lambda x, y: x * y, np.nan], (X, Y)) nan = np.nan assert_array_equal(vals, [[4., 2., nan, 2., 4.], [2., 1., nan, 1., 2.],