You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
55 lines
1.6 KiB
Python
55 lines
1.6 KiB
Python
"""Functions copy-pasted from newer versions of NumPy.

"""
|
|
from __future__ import division, print_function, absolute_import
|
|
|
|
import warnings
|
|
|
|
import numpy as np
|
|
|
|
from scipy._lib._version import NumpyVersion
|
|
|
|
if NumpyVersion(np.__version__) > '1.7.0.dev':
    # Recent enough numpy: use its implementation directly.  (Per the
    # fallback's docstring, the version in earlier numpy returns None.)
    _assert_warns = np.testing.assert_warns
else:
    def _assert_warns(warning_class, func, *args, **kw):
        r"""
        Fail unless the given callable throws the specified warning.

        This definition is copypasted from numpy 1.9.0.dev.
        The version in earlier numpy returns None.

        Only the first warning recorded during the call is inspected, and
        its category must be exactly `warning_class` (an identity check, so
        a subclass of `warning_class` does not match).

        Parameters
        ----------
        warning_class : class
            The class defining the warning that `func` is expected to throw.
        func : callable
            The callable to test.
        *args : Arguments
            Arguments passed to `func`.
        **kw : Kwargs
            Keyword arguments passed to `func`.

        Returns
        -------
        The value returned by `func`.

        Raises
        ------
        AssertionError
            If no warning is recorded while calling `func`, or if the first
            recorded warning's category is not `warning_class`.

        """
        with warnings.catch_warnings(record=True) as caught:
            # 'always' so that no active filter hides the warning from the
            # recording list.
            warnings.simplefilter('always')
            result = func(*args, **kw)

            # Callables such as functools.partial objects have no __name__;
            # fall back to repr() so the intended AssertionError is raised
            # instead of an unrelated AttributeError.
            func_name = getattr(func, '__name__', repr(func))

            if not caught:
                raise AssertionError("No warning raised when calling %s"
                                     % func_name)
            if caught[0].category is not warning_class:
                raise AssertionError("First warning for %s is not a "
                                     "%s( is %s)"
                                     % (func_name, warning_class, caught[0]))
        return result
|
|
|
|
|
|
if NumpyVersion(np.__version__) >= '1.6.0':
    # numpy provides count_nonzero from this version on; use it directly.
    count_nonzero = np.count_nonzero
else:
    def count_nonzero(a):
        """
        Count the number of non-zero elements in `a`.

        Fallback used when numpy does not provide ``count_nonzero``.

        Parameters
        ----------
        a : array_like
            Input data.  Plain sequences and scalars are converted to an
            array first; ndarrays and ndarray subclasses are used as-is.

        Returns
        -------
        The sum of the element-wise ``a != 0`` comparison, i.e. the number
        of elements that compare unequal to zero.

        """
        # asanyarray lets lists/tuples/scalars work (a bare ``a != 0`` on a
        # list yields a plain bool with no ``.sum()``) while passing arrays
        # and their subclasses through unchanged.
        return (np.asanyarray(a) != 0).sum()
|