# Source code for distarray.mpionly_utils
# encoding: utf-8
# ---------------------------------------------------------------------------
# Copyright (C) 2008-2014, IPython Development Team and Enthought, Inc.
# Distributed under the terms of the BSD License. See COPYING.rst.
# ---------------------------------------------------------------------------
"""
Utilities for running Distarray in MPI mode.
"""
from __future__ import absolute_import
import types
from mpi4py import MPI as mpi
from distarray.utils import uid
from distarray.localapi.proxyize import Proxy
# World rank of the client process; initial_comm_setup() puts this rank in
# its own group and every other world rank in the engines group.
client_rank = 0
def get_comm_world():
    """Return the MPI world communicator (``mpi.COMM_WORLD``)."""
    world_comm = mpi.COMM_WORLD
    return world_comm
def get_world_rank():
    """Return the calling process's rank in the world communicator."""
    world = get_comm_world()
    return world.rank
def push_function(context, key, func, targets=None):
    """Ship `func` to the targets and store it there under `key`.

    The function is not sent directly. A plain Python function is broken
    into its code object, name, defaults and closure and rebuilt on the
    receiving side with ``__main__``'s namespace as its globals; a builtin
    function or method is sent as-is.

    Parameters
    ----------
    context : object
        Must provide ``targets`` (the default target list) and
        ``apply(func, args=..., targets=...)``, which is used to run the
        rebuild-and-store step.
    key : object
        Name handed to ``distarray.utils.set_from_dotted_name`` together
        with the rebuilt function (presumably a dotted-name string --
        confirm against that helper).
    func : callable
        Python function or builtin function/method to push.
    targets : sequence, optional
        Targets that should receive the function.  Falls back to
        ``context.targets`` when omitted or empty.
    """
    targets = targets or context.targets

    if isinstance(func, types.BuiltinFunctionType):
        func_data = ('builtin_function_or_method', func)
    else:
        func_data = ('function', func.__code__, func.__name__,
                     func.__defaults__, func.__closure__)

    def reassemble_and_store_func(key_dummy_container, func_data):
        # Imports are local so this function is self-contained wherever
        # context.apply ends up running it.
        import types
        from importlib import import_module
        from distarray.utils import set_from_dotted_name
        # The key travels inside a 1-tuple.
        # NOTE(review): presumably so context.apply does not interpret a
        # bare key string itself -- confirm.
        key = key_dummy_container[0]
        main = import_module('__main__')
        if func_data[0] == 'function':
            code, name, defaults, closure = func_data[1:]
            func = types.FunctionType(code=code, globals=main.__dict__,
                                      name=name, argdefs=defaults,
                                      closure=closure)
        elif func_data[0] == 'builtin_function_or_method':
            func = func_data[1]
        else:
            # Previously fell through to an UnboundLocalError on `func`.
            raise ValueError("Unknown function payload tag: %r"
                             % (func_data[0],))
        set_from_dotted_name(key, func)

    # Honor the caller's `targets`; the resolved value used to be computed
    # and then ignored in favor of context.targets.
    context.apply(reassemble_and_store_func, args=((key,), func_data),
                  targets=targets)
def _set_on_main(name, obj):
    """Build and return ``Proxy(name, obj, '__main__')``.

    The intent is to expose `obj` on the ``__main__`` module under the
    alias `name`, i.e. the equivalent of::

        __main__.name = obj
    """
    module_name = '__main__'
    return Proxy(name, obj, module_name)
def make_targets_comm(targets):
    """Create a communicator covering the requested engine targets.

    Target ``i`` maps to world rank ``i + 1``; world rank 0 is never part of
    the mapping, so there are ``world.size - 1`` engines.

    Parameters
    ----------
    targets : sequence of int or None
        Engine targets to include.  ``None`` or an empty sequence selects
        every engine, i.e. ``range(world.size - 1)``.

    Returns
    -------
    The result of ``_set_on_main(comm_name, targets_comm)``, where
    ``comm_name`` is a uid generated on world rank 0 and broadcast to all
    ranks.

    Raises
    ------
    ValueError
        If more targets are requested than there are engines.
    KeyError
        If a target is not in ``range(world.size - 1)``.
    """
    world = get_comm_world()
    n_engines = world.size - 1

    # Resolve the default before measuring, so None never reaches len().
    targets = targets or list(range(n_engines))

    # Compare against the engine count (not world.size): that is the number
    # the error message reports and the number of entries in the map below.
    if len(targets) > n_engines:
        raise ValueError("The number of engines (%s) is less than the number"
                         " of targets you want (%s)." % (n_engines,
                                                         len(targets)))

    # get a universal name for the out comm: rank 0 generates it and every
    # rank takes part in the broadcast (mpi4py's default root is 0).
    comm_name = uid() if world.rank == 0 else ''
    comm_name = world.bcast(comm_name)

    # create a mapping from the targets to world ranks
    target_to_rank_map = dict(zip(range(n_engines), range(1, world.size)))

    # map the targets to the world ranks
    mapped_targets = [target_to_rank_map[t] for t in targets]

    # create the targets comm
    targets_group = world.group.Incl(mapped_targets)
    targets_comm = world.Create(targets_group)
    return _set_on_main(comm_name, targets_comm)
def initial_comm_setup():
    """Set up the client and engine intracomms and the intercomm.

    The world communicator is split into a client group (the process whose
    world rank equals ``client_rank``) and an engines group (all other
    ranks).  The resulting intracomm is registered via ``set_base_comm`` and
    the intercommunicator bridging the two groups is returned.
    """
    world = get_comm_world()
    my_rank = world.rank

    if my_rank == client_rank:
        # Client: color 0, key 0; the remote leader is world rank 1.
        color, key, remote_leader = 0, 0, 1
    else:
        # Engines: color 1, ordered by world rank; remote leader is rank 0.
        color, key, remote_leader = 1, my_rank, 0

    # create a comm that is split into client and engines.
    split_world = world.Split(color, key)

    from distarray.localapi.mpiutils import set_base_comm
    set_base_comm(split_world)

    # create the intercomm; the local leader is rank 0 of the split comm
    # and `world` serves as the peer communicator.
    return split_world.Create_intercomm(0, world, remote_leader)
def is_solo_mpi_process():
    """Return True when the world communicator holds exactly one process."""
    world_size = get_comm_world().size
    return world_size == 1