docs for muutils v0.8.5
View Source on GitHub

muutils.parallel


  1import multiprocessing
  2import functools
  3from typing import (
  4    Any,
  5    Callable,
  6    Iterable,
  7    Literal,
  8    Optional,
  9    Tuple,
 10    TypeVar,
 11    Dict,
 12    List,
 13    Union,
 14    Protocol,
 15)
 16
 17# for no tqdm fallback
 18from muutils.spinner import SpinnerContext
 19from muutils.validate_type import get_fn_allowed_kwargs
 20
 21
 22InputType = TypeVar("InputType")
 23OutputType = TypeVar("OutputType")
 24# typevars for our iterable and map
 25
 26
 27class ProgressBarFunction(Protocol):
 28    "a protocol for a progress bar function"
 29
 30    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable: ...
 31
 32
 33ProgressBarOption = Literal["tqdm", "spinner", "none", None]
 34# type for the progress bar option
 35
 36
 37DEFAULT_PBAR_FN: ProgressBarOption
 38# default progress bar function
 39
 40try:
 41    # use tqdm if it's available
 42    import tqdm  # type: ignore[import-untyped]
 43
 44    DEFAULT_PBAR_FN = "tqdm"
 45
 46except ImportError:
 47    # use progress bar as fallback
 48    DEFAULT_PBAR_FN = "spinner"
 49
 50
 51def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
 52    "spinner wrapper"
 53    spinnercontext_allowed_kwargs: set[str] = get_fn_allowed_kwargs(
 54        SpinnerContext.__init__
 55    )
 56    mapped_kwargs: dict = {
 57        k: v for k, v in kwargs.items() if k in spinnercontext_allowed_kwargs
 58    }
 59    if "desc" in kwargs and "message" not in mapped_kwargs:
 60        mapped_kwargs["message"] = kwargs["desc"]
 61
 62    if "message" not in mapped_kwargs and "total" in kwargs:
 63        mapped_kwargs["message"] = f"Processing {kwargs['total']} items"
 64
 65    with SpinnerContext(**mapped_kwargs):
 66        output = list(x)
 67
 68    return output
 69
 70
 71def map_kwargs_for_tqdm(kwargs: dict) -> dict:
 72    "map kwargs for tqdm, cant wrap because the pbar dissapears?"
 73    tqdm_allowed_kwargs: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)
 74    mapped_kwargs: dict = {k: v for k, v in kwargs.items() if k in tqdm_allowed_kwargs}
 75
 76    if "desc" not in kwargs:
 77        if "message" in kwargs:
 78            mapped_kwargs["desc"] = kwargs["message"]
 79
 80        elif "total" in kwargs:
 81            mapped_kwargs["desc"] = f"Processing {kwargs.get('total')} items"
 82    return mapped_kwargs
 83
 84
 85def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
 86    "fallback to no progress bar"
 87    return x
 88
 89
 90def set_up_progress_bar_fn(
 91    pbar: Union[ProgressBarFunction, ProgressBarOption],
 92    pbar_kwargs: Optional[Dict[str, Any]] = None,
 93    **extra_kwargs,
 94) -> Tuple[ProgressBarFunction, dict]:
 95    """set up the progress bar function and its kwargs
 96
 97    # Parameters:
 98     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
 99       progress bar function or option. if a function, we return as-is. if a string, we figure out which progress bar to use
100     - `pbar_kwargs : Optional[Dict[str, Any]]`
101       kwargs passed to the progress bar function (default to `None`)
102       (defaults to `None`)
103
104    # Returns:
105     - `Tuple[ProgressBarFunction, dict]`
106         a tuple of the progress bar function and its kwargs
107
108    # Raises:
109     - `ValueError` : if `pbar` is not one of the valid options
110    """
111    pbar_fn: ProgressBarFunction
112
113    if pbar_kwargs is None:
114        pbar_kwargs = dict()
115
116    pbar_kwargs = {**extra_kwargs, **pbar_kwargs}
117
118    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
119    if (pbar is None) or (pbar == "none") or pbar_kwargs.get("disable", False):
120        pbar_fn = no_progress_fn_wrap  # type: ignore[assignment]
121
122    # if `pbar` is a different string, figure out which progress bar to use
123    elif isinstance(pbar, str):
124        if pbar == "tqdm":
125            pbar_fn = tqdm.tqdm
126            pbar_kwargs = map_kwargs_for_tqdm(pbar_kwargs)
127        elif pbar == "spinner":
128            pbar_fn = functools.partial(spinner_fn_wrap, **pbar_kwargs)
129            pbar_kwargs = dict()
130        else:
131            raise ValueError(
132                f"`pbar` must be either 'tqdm' or 'spinner' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
133            )
134    else:
135        # the default value is a callable which will resolve to tqdm if available or spinner as a fallback. we pass kwargs to this
136        pbar_fn = pbar
137
138    return pbar_fn, pbar_kwargs
139
140
# TODO: if `parallel` is a negative int, use `multiprocessing.cpu_count() + parallel` to determine the number of processes
def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the number of processes is capped at `len(iterable)`, so with `parallel=True` the maximum is
    `min(len(iterable), multiprocessing.cpu_count())`. if the cap leaves a single process, the
    work runs serially with plain `map` and no pool is created.

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
       function passed to either `map` or `Pool.imap` / `Pool.imap_unordered`
     - `iterable : Iterable[InputType]`
       inputs for `func`. iterables without `__len__` (e.g. generators) are first collected into a list
     - `parallel : bool | int`
       whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
       kwargs passed to the progress bar function; `total` defaults to the number of inputs
       (defaults to `None`)
     - `chunksize : Optional[int]`
       chunksize for the pool's map; if `None`, uses `max(1, n_inputs // num_processes)`
       (defaults to `None`)
     - `keep_ordered : bool`
       if `True` use `imap` (output order matches input order), otherwise `imap_unordered`
       (defaults to `True`)
     - `use_multiprocess : bool`
       use the `multiprocess` package instead of `multiprocessing` -- mostly useful to pickle lambdas
       (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       progress bar function or option, resolved by `set_up_progress_bar_fn`
       (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
       a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available
    """
    # `len()` is needed below, so collect unsized iterables (generators, map objects, ...) up front
    if not hasattr(iterable, "__len__"):
        iterable = list(iterable)

    # number of inputs in iterable
    n_inputs: int = len(iterable)  # type: ignore[arg-type]
    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use; `total` lets the bar measure progress against the input count
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes requested
    num_processes: int
    if isinstance(parallel, bool):
        # checked before `int`, since `bool` is a subclass of `int`
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # validate `use_multiprocess` against what the caller *requested*, before the
    # single-process downgrade below -- otherwise a one-element input (or a 1-CPU
    # machine) with `parallel=True, use_multiprocess=True` would wrongly raise
    mp = multiprocessing
    if use_multiprocess:
        if not parallel:
            raise ValueError("`use_multiprocess=True` requires `parallel=True`")

        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    # make sure we don't have more processes than iterable, and don't bother with a pool if there's only one process
    num_processes = min(num_processes, n_inputs)
    if num_processes == 1:
        return list(pbar_fn(map(func, iterable), **pbar_kwargs_processed))

    # figure out a smart chunksize if one is not given
    chunksize_int: int = (
        max(1, n_inputs // num_processes) if chunksize is None else chunksize
    )

    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    pool = mp.Pool(num_processes)
    try:
        # `imap` preserves input order; `imap_unordered` yields results as they complete
        pool_map = pool.imap if keep_ordered else pool.imap_unordered
        output: List[OutputType] = list(
            pbar_fn(
                pool_map(func, iterable, chunksize=chunksize_int),
                **pbar_kwargs_processed,
            )
        )
    except BaseException:
        # on any failure (including KeyboardInterrupt) stop the workers instead of leaking them
        pool.terminate()
        raise
    else:
        # normal shutdown: no more tasks will be submitted
        pool.close()
    finally:
        # reap the worker processes in both cases
        pool.join()

    # return the output as a list
    return output

class ProgressBarFunction(Protocol):
    """structural type for progress bar callables

    anything that can be called as `fn(iterable, **kwargs)` and returns an
    iterable (for example `tqdm.tqdm`) satisfies this protocol
    """

    def __call__(self, iterable: Iterable, **kwargs: Any) -> Iterable:
        """wrap `iterable`, yielding its items while reporting progress"""
        ...

a protocol for a progress bar function

ProgressBarFunction(*args, **kwargs)
# NOTE: this appears to be CPython's `typing` module internals (not muutils code),
# shown here as the constructor of `ProgressBarFunction`. It raises TypeError when
# `type(self)` is itself a protocol, returns early if the class already has a
# different `__init__`, and otherwise finds the first `__init__` in the MRO that is
# not this function (falling back to `object.__init__`), installs it on the class,
# and calls it with the given arguments.
1767def _no_init_or_replace_init(self, *args, **kwargs):
1768    cls = type(self)
1769
1770    if cls._is_protocol:
1771        raise TypeError('Protocols cannot be instantiated')
1772
1773    # Already using a custom `__init__`. No need to calculate correct
1774    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
1775    if cls.__init__ is not _no_init_or_replace_init:
1776        return
1777
1778    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
1779    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
1780    # searches for a proper new `__init__` in the MRO. The new `__init__`
1781    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
1782    # instantiation of the protocol subclass will thus use the new
1783    # `__init__` and no longer call `_no_init_or_replace_init`.
1784    for base in cls.__mro__:
1785        init = base.__dict__.get('__init__', _no_init_or_replace_init)
1786        if init is not _no_init_or_replace_init:
1787            cls.__init__ = init
1788            break
1789    else:
1790        # should not happen
1791        cls.__init__ = object.__init__
1792
1793    cls.__init__(self, *args, **kwargs)
# accepted string/None choices for the `pbar` argument
ProgressBarOption = typing.Literal["tqdm", "spinner", "none", None]

# value the default progress bar resolved to when these docs were generated
# (it is "tqdm" only when tqdm was importable)
DEFAULT_PBAR_FN: ProgressBarOption = "tqdm"
def spinner_fn_wrap(x: Iterable, **kwargs) -> List:
    """consume `x` completely while a `SpinnerContext` is shown, returning the items as a list

    only kwargs accepted by `SpinnerContext.__init__` are forwarded. if no accepted
    `message` was given, it is taken from `desc`, or failing that, built from `total`
    """
    accepted: set[str] = get_fn_allowed_kwargs(SpinnerContext.__init__)
    spinner_kwargs: dict = dict()
    for key, value in kwargs.items():
        if key in accepted:
            spinner_kwargs[key] = value

    if "message" not in spinner_kwargs:
        if "desc" in kwargs:
            # tqdm-style `desc` doubles as the spinner message
            spinner_kwargs["message"] = kwargs["desc"]
        elif "total" in kwargs:
            spinner_kwargs["message"] = f"Processing {kwargs['total']} items"

    with SpinnerContext(**spinner_kwargs):
        collected = list(x)

    return collected

spinner wrapper: consumes the iterable into a list while a spinner is displayed

def map_kwargs_for_tqdm(kwargs: dict) -> dict:
    """filter `kwargs` down to those accepted by `tqdm.tqdm`, filling in `desc` when absent

    the result is passed straight to `tqdm.tqdm` instead of wrapping tqdm the way the
    spinner is wrapped -- the original note suggests wrapping made the bar disappear
    (NOTE(review): unconfirmed)

    `desc` falls back to `message`, then to a "Processing N items" string built from `total`
    """
    accepted: set[str] = get_fn_allowed_kwargs(tqdm.tqdm.__init__)
    tqdm_kwargs: dict = dict(
        (name, val) for name, val in kwargs.items() if name in accepted
    )

    # an explicit `desc` always wins
    if "desc" in kwargs:
        return tqdm_kwargs

    if "message" in kwargs:
        tqdm_kwargs["desc"] = kwargs["message"]
    elif "total" in kwargs:
        tqdm_kwargs["desc"] = f"Processing {kwargs.get('total')} items"

    return tqdm_kwargs

map kwargs for tqdm; tqdm is not wrapped like the spinner because the progress bar seems to disappear when wrapped

def no_progress_fn_wrap(x: Iterable, **kwargs) -> Iterable:
    """identity stand-in for a progress bar: hands back `x` untouched

    `kwargs` are accepted only so the signature matches `ProgressBarFunction`; they are ignored
    """
    del kwargs  # intentionally unused
    unwrapped: Iterable = x
    return unwrapped

fallback to no progress bar

def set_up_progress_bar_fn(
    pbar: Union[ProgressBarFunction, ProgressBarOption],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    **extra_kwargs,
) -> Tuple[ProgressBarFunction, dict]:
    """set up the progress bar function and its kwargs

    # Parameters:
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       progress bar function or option. if a callable, we return it as-is. if a string, we figure out which progress bar to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
       kwargs passed to the progress bar function
       (defaults to `None`)
     - `**extra_kwargs`
       default kwargs (e.g. `total`) merged underneath `pbar_kwargs`; keys in `pbar_kwargs` take precedence

    # Returns:
     - `Tuple[ProgressBarFunction, dict]`
         a tuple of the progress bar function and its kwargs. for `"spinner"` the kwargs
         are already bound into the returned function, so the returned dict is empty

    # Raises:
     - `ValueError` : if `pbar` is not one of the valid options (a string other than
       `"tqdm"`, `"spinner"`, `"none"`, or a non-callable object)
     - `ImportError` : if `pbar="tqdm"` but `tqdm` is not installed
    """
    # build a fresh dict so the caller's `pbar_kwargs` is never mutated
    merged_kwargs: Dict[str, Any] = {**extra_kwargs, **(pbar_kwargs or dict())}

    # dont use a progress bar if `pbar` is None or "none", or if `disable` is set to True in `pbar_kwargs`
    if (pbar is None) or (pbar == "none") or merged_kwargs.get("disable", False):
        return no_progress_fn_wrap, merged_kwargs  # type: ignore[return-value]

    # if `pbar` is a different string, figure out which progress bar to use
    if isinstance(pbar, str):
        if pbar == "tqdm":
            # import here instead of relying on the module-level optional import, which
            # leaves `tqdm` undefined (a confusing NameError) when tqdm is not installed
            try:
                import tqdm as tqdm_module  # type: ignore[import-untyped]
            except ImportError as e:
                raise ImportError(
                    "`pbar='tqdm'` requires the `tqdm` package -- install it with `pip install tqdm`, or pass `pbar='spinner'`"
                ) from e
            return tqdm_module.tqdm, map_kwargs_for_tqdm(merged_kwargs)

        if pbar == "spinner":
            # kwargs are bound now, so the caller gets an empty kwargs dict back
            return functools.partial(spinner_fn_wrap, **merged_kwargs), dict()

        raise ValueError(
            f"`pbar` must be either 'tqdm', 'spinner', or 'none' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
        )

    # anything else must be a user-supplied progress bar callable; it receives the merged kwargs
    if not callable(pbar):
        raise ValueError(
            f"`pbar` must be either 'tqdm', 'spinner', or 'none' if `str`, or a valid callable, got {type(pbar) = } {pbar = }"
        )
    return pbar, merged_kwargs

set up the progress bar function and its kwargs

Parameters:

  • pbar : Union[ProgressBarFunction, ProgressBarOption] progress bar function or option. if a function, we return as-is. if a string, we figure out which progress bar to use
  • pbar_kwargs : Optional[Dict[str, Any]] kwargs passed to the progress bar function (defaults to None)

Returns:

  • Tuple[ProgressBarFunction, dict] a tuple of the progress bar function and its kwargs

Raises:

  • ValueError : if pbar is not one of the valid options
def run_maybe_parallel(
    func: Callable[[InputType], OutputType],
    iterable: Iterable[InputType],
    parallel: Union[bool, int],
    pbar_kwargs: Optional[Dict[str, Any]] = None,
    chunksize: Optional[int] = None,
    keep_ordered: bool = True,
    use_multiprocess: bool = False,
    pbar: Union[ProgressBarFunction, ProgressBarOption] = DEFAULT_PBAR_FN,
) -> List[OutputType]:
    """a function to make it easier to sometimes parallelize an operation

    - if `parallel` is `False`, then the function will run in serial, running `map(func, iterable)`
    - if `parallel` is `True`, then the function will run in parallel with the maximum number of processes
    - if `parallel` is an `int`, it must be greater than 1, and the function will run in parallel with the number of processes specified by `parallel`

    the number of processes is capped at `len(iterable)`, so with `parallel=True` the maximum is
    `min(len(iterable), multiprocessing.cpu_count())`. if the cap leaves a single process, the
    work runs serially with plain `map` and no pool is created.

    # Parameters:
     - `func : Callable[[InputType], OutputType]`
       function passed to either `map` or `Pool.imap` / `Pool.imap_unordered`
     - `iterable : Iterable[InputType]`
       inputs for `func`. iterables without `__len__` (e.g. generators) are first collected into a list
     - `parallel : bool | int`
       whether to run in parallel, and how many processes to use
     - `pbar_kwargs : Optional[Dict[str, Any]]`
       kwargs passed to the progress bar function; `total` defaults to the number of inputs
       (defaults to `None`)
     - `chunksize : Optional[int]`
       chunksize for the pool's map; if `None`, uses `max(1, n_inputs // num_processes)`
       (defaults to `None`)
     - `keep_ordered : bool`
       if `True` use `imap` (output order matches input order), otherwise `imap_unordered`
       (defaults to `True`)
     - `use_multiprocess : bool`
       use the `multiprocess` package instead of `multiprocessing` -- mostly useful to pickle lambdas
       (defaults to `False`)
     - `pbar : Union[ProgressBarFunction, ProgressBarOption]`
       progress bar function or option, resolved by `set_up_progress_bar_fn`
       (defaults to `DEFAULT_PBAR_FN`)

    # Returns:
     - `List[OutputType]`
       a list of the output of `func` for each element in `iterable`

    # Raises:
     - `ValueError` : if `parallel` is not a boolean or an integer greater than 1
     - `ValueError` : if `use_multiprocess=True` and `parallel=False`
     - `ImportError` : if `use_multiprocess=True` and `multiprocess` is not available
    """
    # `len()` is needed below, so collect unsized iterables (generators, map objects, ...) up front
    if not hasattr(iterable, "__len__"):
        iterable = list(iterable)

    # number of inputs in iterable
    n_inputs: int = len(iterable)  # type: ignore[arg-type]
    if n_inputs == 0:
        # Return immediately if there is no input
        return list()

    # which progress bar to use; `total` lets the bar measure progress against the input count
    pbar_fn: ProgressBarFunction
    pbar_kwargs_processed: dict
    pbar_fn, pbar_kwargs_processed = set_up_progress_bar_fn(
        pbar=pbar,
        pbar_kwargs=pbar_kwargs,
        # extra kwargs
        total=n_inputs,
    )

    # number of processes requested
    num_processes: int
    if isinstance(parallel, bool):
        # checked before `int`, since `bool` is a subclass of `int`
        num_processes = multiprocessing.cpu_count() if parallel else 1
    elif isinstance(parallel, int):
        if parallel < 2:
            raise ValueError(
                f"`parallel` must be a boolean, or be an integer greater than 1, got {type(parallel) = } {parallel = }"
            )
        num_processes = parallel
    else:
        raise ValueError(
            f"The 'parallel' parameter must be a boolean or an integer, got {type(parallel) = } {parallel = }"
        )

    # validate `use_multiprocess` against what the caller *requested*, before the
    # single-process downgrade below -- otherwise a one-element input (or a 1-CPU
    # machine) with `parallel=True, use_multiprocess=True` would wrongly raise
    mp = multiprocessing
    if use_multiprocess:
        if not parallel:
            raise ValueError("`use_multiprocess=True` requires `parallel=True`")

        try:
            import multiprocess  # type: ignore[import-untyped]
        except ImportError as e:
            raise ImportError(
                "`use_multiprocess=True` requires the `multiprocess` package -- this is mostly useful when you need to pickle a lambda. install muutils with `pip install muutils[multiprocess]` or just do `pip install multiprocess`"
            ) from e

        mp = multiprocess

    # make sure we don't have more processes than iterable, and don't bother with a pool if there's only one process
    num_processes = min(num_processes, n_inputs)
    if num_processes == 1:
        return list(pbar_fn(map(func, iterable), **pbar_kwargs_processed))

    # figure out a smart chunksize if one is not given
    chunksize_int: int = (
        max(1, n_inputs // num_processes) if chunksize is None else chunksize
    )

    # use `mp.Pool` since we might want to use `multiprocess` instead of `multiprocessing`
    pool = mp.Pool(num_processes)
    try:
        # `imap` preserves input order; `imap_unordered` yields results as they complete
        pool_map = pool.imap if keep_ordered else pool.imap_unordered
        output: List[OutputType] = list(
            pbar_fn(
                pool_map(func, iterable, chunksize=chunksize_int),
                **pbar_kwargs_processed,
            )
        )
    except BaseException:
        # on any failure (including KeyboardInterrupt) stop the workers instead of leaking them
        pool.terminate()
        raise
    else:
        # normal shutdown: no more tasks will be submitted
        pool.close()
    finally:
        # reap the worker processes in both cases
        pool.join()

    # return the output as a list
    return output

a function to make it easier to sometimes parallelize an operation

  • if parallel is False, then the function will run in serial, running map(func, iterable)
  • if parallel is True, then the function will run in parallel with the maximum number of processes
  • if parallel is an int, it must be greater than 1, and the function will run in parallel with the number of processes specified by parallel

the maximum number of processes is given by the min(len(iterable), multiprocessing.cpu_count())

Parameters:

  • func : Callable[[InputType], OutputType] function passed to either map or Pool.imap
  • iterable : Iterable[InputType] iterable passed to either map or Pool.imap
  • parallel : bool | int whether to run in parallel, and how many processes to use
  • pbar_kwargs : Dict[str, Any] kwargs passed to the progress bar function

Returns:

  • List[OutputType] a list of the output of func for each element in iterable

Raises:

  • ValueError : if parallel is not a boolean or an integer greater than 1
  • ValueError : if use_multiprocess=True and parallel=False
  • ImportError : if use_multiprocess=True and multiprocess is not available