Source code for distarray.localapi.mpiutils

# encoding: utf-8
# ---------------------------------------------------------------------------
#  Copyright (C) 2008-2014, IPython Development Team and Enthought, Inc.
#  Distributed under the terms of the BSD License.  See COPYING.rst.
# ---------------------------------------------------------------------------
"""
Entry point for MPI.
"""

import numpy as np

from mpi4py import MPI
from distarray.error import InvalidCommSizeError, InvalidRankError

# Module-level default communicator.  Starts out as ``None`` and is only
# changed through ``set_base_comm``.
_BASE_COMM = None


def get_base_comm():
    """
    Return the communicator stored by the last ``set_base_comm`` call.

    Returns ``None`` if ``set_base_comm`` has never been called (or was last
    called with ``None``).
    """
    return _BASE_COMM


def set_base_comm(comm):
    """
    Store ``comm`` as the module-level base communicator.

    The object is stored as-is (no copy, no validation); subsequent
    ``get_base_comm`` calls return this same object.
    """
    global _BASE_COMM
    _BASE_COMM = comm
def get_comm_private():
    """
    Return a brand-new duplicate of ``MPI.COMM_WORLD``.

    A fresh ``Clone()`` is made on every call, so each caller gets its own
    communicator.  NOTE: per the MPI standard, duplication
    (``MPI_Comm_dup``) is collective over COMM_WORLD, so every process must
    make the matching call.  Nothing here frees the returned communicator;
    the caller owns it and should ``Free()`` it when finished.
    """
    world = MPI.COMM_WORLD
    return world.Clone()
def create_comm_of_size(size=4):
    """
    Create a subcommunicator of COMM_PRIVATE of given size.

    A private clone of ``MPI.COMM_WORLD`` is made, and a new communicator
    is created from its ranks ``0 .. size-1``.  Per ``MPI_Comm_create``
    semantics this is collective over COMM_WORLD, and processes outside the
    subgroup receive ``MPI.COMM_NULL``.

    All intermediate MPI handles created here (the private clone, its
    group, and the subgroup) are released before returning, so repeated
    calls do not leak communicators or groups.  The returned communicator
    is independent of the freed clone and is owned by the caller.

    Parameters
    ----------
    size : int
        Number of leading ranks to include in the new communicator.

    Returns
    -------
    MPI.Comm
        The new communicator (``MPI.COMM_NULL`` on excluded ranks).

    Raises
    ------
    InvalidCommSizeError
        If ``size`` exceeds the size of the private communicator.
    """
    comm_private = get_comm_private()
    group = comm_private.Get_group()
    try:
        comm_size = comm_private.Get_size()
        if size > comm_size:
            raise InvalidCommSizeError(
                "requested size (%i) is bigger than the comm size (%i)"
                % (size, comm_size))
        subgroup = group.Incl(list(range(size)))
        try:
            return comm_private.Create(subgroup)
        finally:
            # An empty rank list yields the predefined MPI_GROUP_EMPTY,
            # which must not be freed.
            if subgroup != MPI.GROUP_EMPTY:
                subgroup.Free()
    finally:
        # Groups are local objects; the clone was created above solely for
        # this call and would otherwise leak a communicator context.
        group.Free()
        comm_private.Free()
def create_comm_with_list(nodes, base_comm=None):
    """
    Create a subcommunicator of base_comm with a list of ranks.

    If base_comm is not specified (or is falsy, e.g. ``MPI.COMM_NULL``),
    defaults to COMM_PRIVATE, i.e. a fresh clone of ``MPI.COMM_WORLD``.

    Intermediate MPI handles created here (the group, the subgroup, and the
    private clone when one was made) are released before returning.  A
    caller-supplied ``base_comm`` is never freed.

    Parameters
    ----------
    nodes : sequence of int
        Ranks of ``base_comm`` to include, in the desired order.
    base_comm : MPI.Comm, optional
        Parent communicator.

    Returns
    -------
    MPI.Comm
        The new communicator.  Per ``MPI_Comm_create`` semantics, ranks not
        listed in ``nodes`` receive ``MPI.COMM_NULL``.

    Raises
    ------
    InvalidCommSizeError
        If more ranks are requested than ``base_comm`` has.
    InvalidRankError
        If any entry of ``nodes`` is not a valid rank of ``base_comm``.
    """
    # Only free the parent communicator if we created it ourselves.
    owns_base = not base_comm
    if owns_base:
        base_comm = get_comm_private()
    group = base_comm.Get_group()
    try:
        comm_size = base_comm.Get_size()
        size = len(nodes)
        if size > comm_size:
            raise InvalidCommSizeError(
                "requested size (%i) is bigger than the comm size (%i)"
                % (size, comm_size))
        for i in nodes:
            if i not in range(comm_size):
                # Wrap in a tuple so a tuple-valued rank still formats.
                raise InvalidRankError("rank is not valid: %r" % (i,))
        subgroup = group.Incl(nodes)
        try:
            return base_comm.Create(subgroup)
        finally:
            # An empty rank list yields the predefined MPI_GROUP_EMPTY,
            # which must not be freed.
            if subgroup != MPI.GROUP_EMPTY:
                subgroup.Free()
    finally:
        group.Free()
        if owns_base:
            base_comm.Free()
# Map NumPy dtypes to the corresponding MPI datatypes.  NumPy's 'f', 'd',
# 'i' and 'l' type codes are the C types float, double, int and long, so
# they must map to the C-language MPI datatypes.
mpi_dtypes = {
    np.dtype('f'): MPI.FLOAT,
    np.dtype('d'): MPI.DOUBLE,
    # C ``int`` is MPI.INT; MPI.INTEGER is the *Fortran* INTEGER datatype,
    # which is not guaranteed to match C int.
    np.dtype('i'): MPI.INT,
    np.dtype('l'): MPI.LONG,
}
def mpi_type_for_ndarray(a):
    """
    Return the MPI datatype matching the dtype of array ``a``.

    Parameters
    ----------
    a : numpy.ndarray
        Any object with a ``dtype`` attribute that is a key of
        ``mpi_dtypes``.

    Returns
    -------
    MPI.Datatype
        The MPI datatype registered for ``a.dtype``.

    Raises
    ------
    KeyError
        If ``a.dtype`` has no registered MPI datatype.  The message names
        the offending dtype and the supported ones.
    """
    try:
        return mpi_dtypes[a.dtype]
    except KeyError:
        supported = ", ".join(sorted(str(dt) for dt in mpi_dtypes))
        raise KeyError("no MPI datatype registered for dtype %r "
                       "(supported: %s)" % (a.dtype, supported))