Source code for mpi_map.misc

#
# This file is part of mpi_map.
#
# mpi_map is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# mpi_map is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with mpi_map.  If not, see <http://www.gnu.org/licenses/>.
#
# DTU UQ Library
# Copyright (C) 2014 The Technical University of Denmark
# Scientific Computing Section
# Department of Applied Mathematics and Computer Science
#
# Author: Daniele Bigoni
#

__all__ = ['split_data', 'mpi_map_code', 'mpi_map_method', 'barrier', 'MPI_Pool']

import os
import sys
import time
import marshal, types
import dill
import inspect
import itertools
import distutils.spawn
from mpi4py import MPI
import mpi_map

def split_data(x, procs):
    r""" Split the input ``x`` into ``procs`` chunks of (almost) equal size.

    Returns:
      (:class:`tuple`) -- the list of chunks and the list of cumulative
        offsets delimiting each chunk in ``x``
    """
    # The first len(x) % procs chunks get one extra item each
    ns = [len(x) // procs] * procs
    for i in range(len(x) % procs):
        ns[i] += 1
    # Turn the chunk sizes into cumulative offsets [0, n0, n0+n1, ...]
    for i in range(1, procs):
        ns[i] += ns[i-1]
    ns.insert(0, 0)
    split_x = [ x[ns[i]:ns[i+1]] for i in range(procs) ]
    return (split_x, ns)
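# Illustrative doctest-style sketch (not part of the original module):
#
#   >>> split_data([1, 2, 3, 4, 5], 2)
#   ([[1, 2, 3], [4, 5]], [0, 3, 5])
#
# The first element holds the per-process chunks; the second holds the
# cumulative offsets delimiting each chunk in the original input.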

def mpi_map_code(f, x, params, procs, obj=None):
    """ This function applies the function ``f`` to the ``x`` inputs on ``procs`` processors.

    Args:
      f (function): function
      x (:class:`list` or :class:`ndarray<numpy.ndarray>`): input
      params (tuple): parameters to be passed to the function (pickable)
      procs (int): number of processors to be used
      obj (object): object where ``f`` is defined (optional)

    Returns:
      (:class:`list` [``nprocs``]) -- (ordered) outputs from all the processes
    """
    sys.setrecursionlimit(10000)
    func_code = marshal.dumps(f.__code__)
    if obj is not None:
        obj_dill = dill.dumps(obj)
    else:
        obj_dill = None
    try:
        path = os.environ['VIRTUAL_ENV'] + '/bin/mpi_eval.py'
    except KeyError:
        path = distutils.spawn.find_executable('mpi_eval.py')
    if len(x) > 0:
        cwd = os.getcwd()
        procs = min(procs, len(x))
        comm = MPI.COMM_SELF.Spawn(sys.executable,
                                   args=[path], maxprocs=procs)
        # Broadcast function and parameters
        comm.bcast((cwd, obj_dill, func_code, params), root=MPI.ROOT)
        # Split the input data
        split_x, ns = split_data(x, procs)
        # Scatter the data
        comm.scatter(split_x, root=MPI.ROOT)
        # Avoid busy waiting
        mpi_map.barrier(MPI.COMM_WORLD)
        # Gather the results
        fval = comm.gather(None, root=MPI.ROOT)
        comm.Disconnect()
        # Check for exceptions
        fail = False
        for v in fval:
            if isinstance(v, tuple) and isinstance(v[0], Exception):
                print(v[1])
                fail = True
        if fail:
            raise RuntimeError("Some of the MPI processes failed")
        if isinstance(fval[0], list):
            fval = list(itertools.chain(*fval))
    else:
        fval = []
    return fval

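# Hedged usage sketch for ``mpi_map_code`` (not part of the original module).
# It assumes a working MPI installation, the ``mpi_eval.py`` worker script
# installed on the path (e.g. in $VIRTUAL_ENV/bin), and that the worker calls
# ``f`` on its chunk of ``x`` together with the broadcast ``params``:
#
#   def scale(x, a):
#       return [xi * a for xi in x]
#
#   fval = mpi_map_code(scale, list(range(100)), (2.,), 4)
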
def mpi_map_method(fname, x, params, procs, obj):
    """ This function applies the method with name ``fname`` of object ``obj`` to the ``x`` inputs on ``procs`` processors.

    Args:
      fname (str): name of the method defined in ``obj``
      x (:class:`list` or :class:`ndarray<numpy.ndarray>`): input
      params (tuple): parameters to be passed to the method (pickable)
      procs (int): number of processors to be used
      obj (object): object where the method ``fname`` is defined

    Returns:
      (:class:`list` [``nprocs``]) -- (ordered) outputs from all the processes
    """
    sys.setrecursionlimit(10000)
    obj_dill = dill.dumps(obj)
    try:
        path = os.environ['VIRTUAL_ENV'] + '/bin/mpi_eval_method.py'
    except KeyError:
        path = distutils.spawn.find_executable('mpi_eval_method.py')
    if len(x) > 0:
        cwd = os.getcwd()
        procs = min(procs, len(x))
        comm = MPI.COMM_SELF.Spawn(sys.executable,
                                   args=[path], maxprocs=procs)
        # Broadcast object, method name and parameters
        comm.bcast((cwd, obj_dill, fname, params), root=MPI.ROOT)
        # Split the input data
        split_x, ns = split_data(x, procs)
        # Scatter the data
        comm.scatter(split_x, root=MPI.ROOT)
        # Avoid busy waiting
        mpi_map.barrier(MPI.COMM_WORLD)
        # Gather the results
        fval = comm.gather(None, root=MPI.ROOT)
        comm.Disconnect()
        # Check for exceptions
        fail = False
        for v in fval:
            if isinstance(v, tuple) and isinstance(v[0], Exception):
                print(v[1])
                fail = True
        if fail:
            raise RuntimeError("Some of the MPI processes failed")
        if isinstance(fval[0], list):
            fval = list(itertools.chain(*fval))
    else:
        fval = []
    return fval

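# Hedged usage sketch for ``mpi_map_method`` (not part of the original
# module). Same MPI assumptions as above, with the ``mpi_eval_method.py``
# worker script on the path; ``obj`` must be serializable with ``dill``.
# ``Operator.scale`` here is a hypothetical method for illustration:
#
#   class Operator(object):
#       def __init__(self, a):
#           self.a = a
#       def scale(self, x, n=1):
#           return [xi * self.a * n for xi in x]
#
#   fval = mpi_map_method("scale", list(range(100)), (2,), 4, Operator(3.))
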
def barrier(comm, tag=0, sleep=0.01):
    """ Function used to avoid busy-waiting.

    As suggested by Lisandro Dalcin at:

    * http://code.google.com/p/mpi4py/issues/detail?id=4 and
    * https://groups.google.com/forum/?fromgroups=#!topic/mpi4py/nArVuMXyyZI
    """
    size = comm.Get_size()
    if size == 1:
        return
    rank = comm.Get_rank()
    mask = 1
    # Dissemination barrier: in each round every rank sends to the rank
    # ``mask`` positions ahead and receives from the rank ``mask`` positions
    # behind, polling with ``Iprobe`` and sleeping between polls instead of
    # busy-waiting in a blocking receive.
    while mask < size:
        dst = (rank + mask) % size
        src = (rank - mask + size) % size
        req = comm.isend(None, dst, tag)
        while not comm.Iprobe(src, tag):
            time.sleep(sleep)
        comm.recv(None, src, tag)
        req.Wait()
        mask <<= 1

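# Illustrative trace (not from the source): with ``size = 5`` the loop runs
# for mask = 1, 2, 4, i.e. ceil(log2(size)) = 3 rounds; rank 0 sends to
# ranks 1, 2, 4 and receives from ranks 4, 3, 1 before the barrier completes.
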
class MPI_Pool(object):
    r""" Returns (but does not start) a pool of ``nprocs`` processes

    Args:
      nprocs (int): number of processes

    Usage example::

        import numpy as np
        import numpy.random as npr
        from TransportMaps import get_mpi_pool, mpi_eval

        class Operator(object):
            def __init__(self, a):
                self.a = a
            def sum(self, x, n=1):
                out = x
                for i in range(n):
                    out += self.a
                return out

        op = Operator(2.)
        x = npr.randn(100,5)
        n = 2

        pool = get_mpi_pool(3)
        pool.start()
        try:
            xsum = mpi_eval("sum", op, x, (n,), mpi_pool=pool)
        finally:
            pool.stop()
    """
    def __init__(self, nprocs):
        self.nprocs = nprocs
        self.comm = None

    def start(self):
        r""" Start the pool of processes
        """
        if self.comm is None:
            sys.setrecursionlimit(10000)
            try:
                path = os.environ['VIRTUAL_ENV'] + '/bin/mpi_pool.py'
            except KeyError:
                path = distutils.spawn.find_executable('mpi_pool.py')
            cwd = os.getcwd()
            self.comm = MPI.COMM_SELF.Spawn(sys.executable,
                                            args=[path], maxprocs=self.nprocs)
            # Broadcast cwd
            self.comm.bcast(cwd, root=MPI.ROOT)

    def stop(self):
        r""" Stop the pool of processes
        """
        if self.comm is not None:
            # Stop children
            self.comm.bcast(("STOP", None, None), root=MPI.ROOT)
            # Disconnect
            self.comm.Disconnect()
            self.comm = None

    def eval_method(self, fname, x, params, obj):
        r""" Submit a job to the pool.

        Execute the method ``fname`` of the object ``obj`` with scattered
        input ``x`` and additional parameters ``params``.

        Args:
          fname (str): name of the function in ``obj`` to be executed
          x (:class:`list` or :class:`ndarray<numpy.ndarray>`): input
          params (tuple): additional parameters
          obj (object): object where to find the function ``fname``

        Returns:
          (:class:`list` [``nprocs``]) -- (ordered) outputs from all the processes
        """
        if len(x) > 0:
            obj_dill = dill.dumps(obj)
            # Broadcast function and parameters
            self.comm.bcast((obj_dill, fname, params), root=MPI.ROOT)
            # Split the input data
            split_x, ns = split_data(x, self.nprocs)
            # Scatter the data
            self.comm.scatter(split_x, root=MPI.ROOT)
            # Avoid busy waiting
            mpi_map.barrier(MPI.COMM_WORLD)
            # Gather the results
            fval = self.comm.gather(None, root=MPI.ROOT)
            # Check for exceptions
            fail = False
            for v in fval:
                if isinstance(v, tuple) and isinstance(v[0], Exception):
                    print(v[1])
                    fail = True
            if fail:
                self.stop()
                raise RuntimeError("Some of the MPI processes failed")
            if isinstance(fval[0], list):
                fval = list(itertools.chain(*fval))
        else:
            fval = []
        return fval

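# Hedged direct-usage sketch for ``MPI_Pool`` (not part of the original
# module). It assumes a working MPI installation and the ``mpi_pool.py``
# worker script on the path; ``op``, ``x`` and ``n`` are as in the class
# docstring above:
#
#   pool = MPI_Pool(3)
#   pool.start()
#   try:
#       xsum = pool.eval_method("sum", x, (n,), op)
#   finally:
#       pool.stop()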