# encoding: utf-8
# ---------------------------------------------------------------------------
# Copyright (C) 2008-2014, IPython Development Team and Enthought, Inc.
# Distributed under the terms of the BSD License. See COPYING.rst.
# ---------------------------------------------------------------------------
"""
Functions used for tests.
"""
import unittest
import importlib
import tempfile
import os
import types
from uuid import uuid4
from functools import wraps
import numpy as np
from distarray.externals import six
from distarray.externals import protocol_validator
from distarray.globalapi.context import Context, ContextCreationError
from distarray.error import InvalidCommSizeError
from distarray.localapi.mpiutils import MPI, create_comm_of_size
[docs]def raise_typeerror(fn):
"""Decorator for protocol validator functions.
These functions return (success, err_msg), but sometimes we would rather
have an exception.
"""
@wraps(fn)
def wrapper(*args, **kwargs):
good, msg = fn(*args, **kwargs)
if not good:
raise TypeError(msg)
else:
return (good, msg)
return wrapper
validate_dim_dict = raise_typeerror(protocol_validator.validate_dim_dict)
validate_dim_data = raise_typeerror(protocol_validator.validate_dim_data)
validate_distbuffer = raise_typeerror(protocol_validator.validate)
[docs]def temp_filepath(extension=''):
"""Return a randomly generated filename.
This filename is appended to the directory path returned by
`tempfile.gettempdir()` and has `extension` appended to it.
"""
tempdir = tempfile.gettempdir()
filename = str(uuid4())[:8] + extension
return os.path.join(tempdir, filename)
[docs]def import_or_skip(name):
"""Try importing `name`, raise SkipTest on failure.
Parameters
----------
name : str
Module name to try to import.
Returns
-------
module : module object
Module object imported by importlib.
Raises
------
unittest.SkipTest
If the attempted import raises an ImportError.
Examples
--------
>>> h5py = import_or_skip('h5py')
>>> h5py.get_config()
<h5py.h5.H5PYConfig at 0x103dd5a78>
"""
try:
return importlib.import_module(name)
except ImportError:
errmsg = '%s not found... skipping.' % name
raise unittest.SkipTest(errmsg)
[docs]def comm_null_passes(fn):
"""Decorator. If `self.comm` is COMM_NULL, pass.
This allows our tests to pass on processes that have nothing to do.
"""
@wraps(fn)
def wrapper(self, *args, **kwargs):
if hasattr(self, 'comm') and (self.comm == MPI.COMM_NULL):
pass
else:
return fn(self, *args, **kwargs)
return wrapper
[docs]class CommNullPasser(type):
"""Metaclass.
Applies the `comm_null_passes` decorator to every method on a generated
class.
"""
def __new__(cls, name, bases, attrs):
for attr_name, attr_value in six.iteritems(attrs):
if isinstance(attr_value, types.FunctionType):
attrs[attr_name] = comm_null_passes(attr_value)
return super(CommNullPasser, cls).__new__(cls, name, bases, attrs)
@six.add_metaclass(CommNullPasser)
[docs]class ParallelTestCase(unittest.TestCase):
"""Base test class for fully distributed and client-less test cases.
Overload the `comm_size` class attribute to change the default number of
processes required.
Attributes
----------
comm_size : int, default=4
Indicates how many MPI processes are required for this test to
run. If fewer than `comm_size` are available, the test will be
skipped.
"""
comm_size = 4
@classmethod
[docs] def setUpClass(cls):
try:
cls.comm = create_comm_of_size(cls.comm_size)
except InvalidCommSizeError:
msg = "Must run with comm size >= {}."
raise unittest.SkipTest(msg.format(cls.comm_size))
@classmethod
[docs] def tearDownClass(cls):
if cls.comm != MPI.COMM_NULL:
cls.comm.Free()
[docs]class BaseContextTestCase(unittest.TestCase):
"""Base test class for test cases that use a Context.
Overload the `ntargets` class attribute to change the default number of
engines required. A `cls.context` object will be created with
`targets=range(cls.ntargets)`. Tests will be skipped if there are too
few targets.
Attributes
----------
ntargets : int or 'any', default=4
If an int, indicates how many engines are required for this test to
run. If the string 'any', indicates that any number of engines may
be used with this test.
"""
ntargets = 4
@classmethod
[docs] def setUpClass(cls):
super(BaseContextTestCase, cls).setUpClass()
# skip if there aren't enough engines
try:
if cls.ntargets == 'any':
cls.context = cls.make_context()
cls.ntargets = len(cls.context.targets)
else:
try:
cls.context = cls.make_context(targets=list(range(cls.ntargets)))
except ValueError:
msg = ("Not enough targets available for this test. (%s) "
"required" % (cls.ntargets))
raise unittest.SkipTest(msg)
except ContextCreationError as e:
raise unittest.SkipTest(e.message)
@classmethod
[docs] def tearDownClass(cls):
try:
cls.context.close()
except RuntimeError:
pass
[docs]class MPIContextTestCase(BaseContextTestCase):
@classmethod
[docs] def make_context(cls, targets=None):
return Context(kind='MPI', targets=targets)
[docs]class IPythonContextTestCase(BaseContextTestCase):
@classmethod
[docs] def make_context(cls, targets=None):
try:
return Context(kind='IPython', targets=targets)
except EnvironmentError:
msg = "You must have an ipcluster running to run this test case."
raise unittest.SkipTest(msg)
@classmethod
[docs] def setUpClass(cls):
super(IPythonContextTestCase, cls).setUpClass()
cls.client = cls.context.client
[docs]class DefaultContextTestCase(BaseContextTestCase):
@classmethod
[docs] def make_context(cls, targets=None):
try:
return Context(targets=targets)
except EnvironmentError:
msg = "You must have an ipcluster running to run this test case."
raise unittest.SkipTest(msg)
[docs]def check_targets(required, available):
"""If available < required, raise a SkipTest with a nice error message."""
if available < required:
msg = ("This test requires at least {} engines to run; "
"only {} available.")
msg = msg.format(required, available)
raise unittest.SkipTest(msg)
def _assert_localarray_metadata_equal(l0, l1, check_dtype=False):
np.testing.assert_equal(l0.dist, l1.dist)
np.testing.assert_equal(l0.global_shape, l1.global_shape)
np.testing.assert_equal(l0.ndim, l1.ndim)
np.testing.assert_equal(l0.global_size, l1.global_size)
np.testing.assert_equal(l0.comm_size, l1.comm_size)
np.testing.assert_equal(l0.comm_rank, l1.comm_rank)
np.testing.assert_equal(l0.cart_coords, l1.cart_coords)
np.testing.assert_equal(l0.grid_shape, l1.grid_shape)
np.testing.assert_equal(l0.local_shape, l1.local_shape)
np.testing.assert_equal(l0.local_size, l1.local_size)
np.testing.assert_equal(l0.ndarray.shape, l1.ndarray.shape)
if check_dtype:
np.testing.assert_equal(l0.ndarray.dtype, l1.ndarray.dtype)
[docs]def assert_localarrays_allclose(l0, l1, check_dtype=False, rtol=1e-07, atol=0):
"""Call np.testing.assert_allclose on `l0` and `l1`.
Also, check that LocalArray properties are equal.
"""
_assert_localarray_metadata_equal(l0, l1, check_dtype=check_dtype)
np.testing.assert_allclose(l0.ndarray, l1.ndarray, rtol=rtol, atol=atol)
[docs]def assert_localarrays_equal(l0, l1, check_dtype=False):
"""Call np.testing.assert_equal on `l0` and `l1`.
Also, check that LocalArray properties are equal.
"""
_assert_localarray_metadata_equal(l0, l1, check_dtype=check_dtype)
np.testing.assert_array_equal(l0.ndarray, l1.ndarray)