# encoding: utf-8
# ---------------------------------------------------------------------------
# Copyright (C) 2008-2014, IPython Development Team and Enthought, Inc.
# Distributed under the terms of the BSD License. See COPYING.rst.
# ---------------------------------------------------------------------------
"""
Utilities.
"""
from functools import reduce
from importlib import import_module
from math import sqrt
import random
import uuid
from distarray import DISTARRAY_BASE_NAME
from distarray.externals.six import next
# Module-wide pseudo-random generator (created without an explicit seed)
# that nonce() -- and therefore uid() -- draws from. Its state can be
# captured with distarray_random_getstate() and restored with
# distarray_random_setstate(), after which the same nonces are replayed.
DISTARRAY_RANDOM = random.Random()
def distarray_random_setstate(state):
    """Restore the module-wide random number generator to `state`.

    Parameters
    ----------
    state : object
        A state object as returned by `distarray_random_getstate`
        (i.e. by ``random.Random.getstate``).
    """
    # The generator object is mutated in place and the module-level name
    # is never rebound, so no ``global`` declaration is required.
    DISTARRAY_RANDOM.setstate(state)
def distarray_random_getstate():
    """Return a snapshot of the module-wide random number generator.

    Returns
    -------
    object
        Opaque state from ``random.Random.getstate``; passing it to
        `distarray_random_setstate` rewinds the generator to this point.
    """
    rng = DISTARRAY_RANDOM
    return rng.getstate()
def nonce():
    """Return a pseudo-random 16-character lowercase hexadecimal string.

    128 bits are drawn from ``DISTARRAY_RANDOM`` and the leading 16 hex
    digits (the high 64 bits) are kept.
    """
    bits = DISTARRAY_RANDOM.getrandbits(8 * 16)
    # Identical text to uuid.UUID(int=bits).hex: the 128-bit value written
    # as 32 zero-padded lowercase hex digits.
    full_hex = '%032x' % bits
    return full_hex[:16]
def uid():
    """Return a fresh name for a distarray object.

    The name is ``DISTARRAY_BASE_NAME`` followed by a 16-character
    hexadecimal `nonce()`.
    """
    prefix = DISTARRAY_BASE_NAME
    return prefix + nonce()
def multi_for(iterables):
    """Iterate over the Cartesian product of `iterables`, like nested loops.

    ``multi_for([a, b, c])`` yields ``(x, y, z)`` in the same order as::

        for x in a:
            for y in b:
                for z in c:
                    ...

    Parameters
    ----------
    iterables : iterable of iterables
        The loop ranges, outermost first. The first one is consumed
        lazily; the others are read once into tuples, because they must
        be traversed again for every element of the loops around them.

    Yields
    ------
    tuple
        One element from each iterable. An empty `iterables` yields a
        single empty tuple.
    """
    from itertools import product

    iterables = list(iterables)
    if not iterables:
        yield ()
        return
    # Materialize the inner ranges once: a one-shot iterator (generator,
    # ``iter(...)``) would otherwise be exhausted after the first pass of
    # the enclosing loop, silently truncating the product.
    inner = [tuple(it) for it in iterables[1:]]
    for item in iterables[0]:
        for rest_tuple in product(*inner):
            yield (item,) + rest_tuple
def divisors_minmax(n, dmin, dmax):
    """Yield, in increasing order, the divisors of `n` in ``(dmin, dmax]``.

    The candidates ``dmin + 1, dmin + 2, ...`` are tested by trial
    division for as long as they do not exceed `dmax`.
    """
    candidate = dmin + 1
    while candidate <= dmax:
        if not n % candidate:
            yield candidate
        candidate += 1
def list_or_tuple(seq):
    """Return True if `seq` is a list or a tuple (subclasses included)."""
    return isinstance(seq, list) or isinstance(seq, tuple)
def flatten(seq, to_expand=list_or_tuple):
    """Flatten a nested sequence, lazily and depth-first.

    Parameters
    ----------
    seq : iterable
        The (possibly nested) sequence to flatten.
    to_expand : callable, optional
        Predicate deciding whether an element is itself a container to
        descend into. Defaults to `list_or_tuple`, so only lists and
        tuples are expanded; anything else is yielded unchanged.

    Yields
    ------
    The non-expanded elements, in left-to-right order.
    """
    for element in seq:
        if not to_expand(element):
            yield element
            continue
        # Descend into the nested container with the same predicate.
        for leaf in flatten(element, to_expand):
            yield leaf
def mult_partitions(n, s):
    """Compute the multiplicative partitions of n of size s.

    Returns every way of writing `n` as a product of exactly `s` factors,
    each partition given as a tuple of factors in non-decreasing order.
    Factors equal to 1 are included.

    >>> mult_partitions(52, 3)
    [(1, 1, 52), (1, 2, 26), (1, 4, 13), (2, 2, 13)]
    >>> mult_partitions(52, 2)
    [(1, 52), (2, 26), (4, 13)]
    >>> mult_partitions(52, 1)
    [(52,)]

    Raises
    ------
    ValueError
        If `s` is smaller than 1 (the recursion would never terminate).
    """
    if s < 1:
        raise ValueError("s must be a positive integer, got %r" % (s,))
    # Wrap each partition in a list before flattening: for s == 1 the
    # helper returns bare numbers, which are not iterable on their own.
    return [tuple(flatten([p])) for p in mult_partitions_recurs(n, s)]


def mult_partitions_recurs(n, s, pd=0):
    """Recursive helper for `mult_partitions`.

    Returns a list of nested pairs ``(d, rest)``, where `d` is the first
    (smallest) factor and `rest` is a partition of ``n // d`` into
    ``s - 1`` factors in the same nested form (a bare number when
    ``s == 1``). Only first factors greater than `pd` are considered.
    """
    if s == 1:
        return [n]
    fs = []
    for d in divisors_minmax(n, pd, int(sqrt(n))):
        # Floor division keeps the factors integers (true division makes
        # them floats on Python 3) and is exact because d divides n.
        # The sub-problem is searched above the *previous* divisor of n:
        # no divisor of n // d lies strictly between that divisor and d,
        # so this admits exactly the factors >= d, keeping tuples sorted.
        fs.extend([(d, f) for f in mult_partitions_recurs(n // d, s - 1, pd)])
        pd = d
    return fs
def mirror_sort(seq, ref_seq):
    """Sort `seq` into the order that `ref_seq` is in.

    ``seq[k]`` is placed at the position of the k-th smallest element of
    `ref_seq` (equal elements keep their original order). When `seq` is
    sorted, the result has the same relative ordering as `ref_seq`.

    >>> mirror_sort(range(5),[1,5,2,4,3])
    [0, 4, 1, 3, 2]

    Raises
    ------
    ValueError
        If the two sequences differ in length.
    """
    if len(seq) != len(ref_seq):
        raise ValueError("Sequences must have the same length")
    ref_values = list(ref_seq)
    # Indices of ref_seq ordered from smallest to largest value (stable).
    ranking = sorted(range(len(ref_values)), key=ref_values.__getitem__)
    result = [0] * len(ref_values)
    for slot, target in enumerate(ranking):
        result[target] = seq[slot]
    return result
def _raise_nie():
    """Unconditionally raise NotImplementedError for distributed arrays."""
    raise NotImplementedError(
        "This has not yet been implemented for distributed arrays")
def slice_intersection(s1, s2):
    """Compute a slice that represents the intersection of two slices.

    Currently only implemented for steps of size 1 (``None`` or ``1``).
    A ``None`` start or stop is treated as unbounded on that side, so
    ``slice(None, 5)`` and ``slice(2, None)`` intersect to
    ``slice(2, 5, 1)``. Bounds are compared as given; negative
    (end-relative) indices are not normalized. When the inputs do not
    overlap, the result is an empty slice (start >= stop).

    Parameters
    ----------
    s1, s2 : slice objects

    Returns
    -------
    slice object
        Always with ``step == 1``.

    Raises
    ------
    NotImplementedError
        If either slice has a step other than None or 1.
    """
    valid_steps = {None, 1}
    if s1.step not in valid_steps or s2.step not in valid_steps:
        msg = "Slice intersection only implemented for step=1."
        raise NotImplementedError(msg)

    def tighter(pick, a, b):
        # None means "unbounded", so the other bound (if any) wins; this
        # also avoids comparing None with an int (TypeError on Python 3).
        if a is None:
            return b
        if b is None:
            return a
        return pick(a, b)

    start = tighter(max, s1.start, s2.start)
    stop = tighter(min, s1.stop, s2.stop)
    return slice(start, stop, 1)
def has_exactly_one(iterable):
    """Return True if exactly one element of `iterable` is not None.

    Falsy values such as 0 or "" still count as present; the whole
    iterable is consumed.
    """
    present = sum(1 for element in iterable if element is not None)
    return present == 1
def all_equal(iterable):
    """Return True if all elements in `iterable` are equal.

    Every element is compared with ``==`` against the first one. An empty
    iterable is vacuously all-equal, so True is returned for it as well.
    """
    items = iter(iterable)
    missing = object()
    head = next(items, missing)
    if head is missing:
        return True
    for other in items:
        if not other == head:
            return False
    return True
class count_round_trips(object):
    """
    Context manager for counting the number of roundtrips between an IPython
    client and controller.

    The count is the growth of ``len(client.history)`` since this object
    was created. It is refreshed when the ``with`` block exits, or on an
    explicit `update_count` call.

    Usage:

    >>> with count_round_trips(client) as r:
    ...     send_42_messages()
    >>> r.count
    42
    """

    def __init__(self, client):
        self.client = client
        # Baseline history length; only growth beyond it is counted.
        self.orig_count = len(client.history)
        self.count = 0

    def update_count(self):
        """Set `count` to the number of history entries added since creation."""
        current = len(self.client.history)
        self.count = current - self.orig_count

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Implicitly returns None, so exceptions raised in the block propagate.
        self.update_count()
def remove_elements(to_remove, seq):
    """ Return a list, with the elements with specified indices removed.

    Parameters
    ----------
    to_remove: iterable
        Indices of elements in list to remove. It is read exactly once,
        so one-shot iterators such as generators are supported.
    seq: iterable
        Elements in the list.

    Returns
    -------
    List with the specified indices removed.
    """
    # Materialize the indices once: a membership test against a generator
    # consumes it and gives wrong answers for later indices, and a set
    # turns each test into an O(1) lookup instead of a linear scan.
    drop = set(to_remove)
    return [x for (idx, x) in enumerate(seq) if idx not in drop]
def get_from_dotted_name(dotted_name):
    """Resolve a dotted attribute path, starting from the ``__main__`` module.

    ``get_from_dotted_name('a.b.c')`` returns ``__main__.a.b.c``. An
    AttributeError propagates if any component is missing.
    """
    target = import_module('__main__')
    for attr in dotted_name.split('.'):
        target = getattr(target, attr)
    return target
def set_from_dotted_name(name, val):
    """Assign `val` to a dotted attribute path rooted at ``__main__``.

    Every component except the last is resolved with getattr; the last
    one is then set on that object, e.g. ``set_from_dotted_name('a.b', 1)``
    performs ``__main__.a.b = 1``.
    """
    parts = name.split('.')
    owner = import_module('__main__')
    for attr in parts[:-1]:
        owner = getattr(owner, attr)
    setattr(owner, parts[-1], val)