Source code for macrostat.util.batchprocessing

"""
Batch processing functionality
"""

__author__ = ["Karl Naumann-Woleske"]
__credits__ = ["Karl Naumann-Woleske"]
__license__ = "MIT"
__version__ = "0.1.0"
__maintainer__ = ["Karl Naumann-Woleske"]

import logging
import traceback
from contextlib import contextmanager

import torch.multiprocessing as mp
from torch.multiprocessing import Pool

logger = logging.getLogger(__name__)


@contextmanager
def pool_context(*args, **kwargs):
    """Yield a process pool that is always torn down on exit.

    All positional and keyword arguments are forwarded unchanged to
    ``Pool``. Whether the ``with`` body completes normally or raises, the
    pool is terminated and joined before control leaves the context.
    """
    workers = Pool(*args, **kwargs)
    try:
        yield workers
    finally:
        # terminate() stops the worker processes; join() then waits for
        # them to exit so nothing is left behind.
        logger.debug("Cleaning up process pool")
        workers.terminate()
        workers.join()
        logger.debug("Process pool cleanup completed")
def timeseries_worker(task: tuple):
    """Run one simulation described by ``task`` and return its output.

    This is the default worker used by ``parallel_processor``.

    Parameters
    ----------
    task : tuple
        ``(name, model, *args)``: ``name`` identifies the simulation,
        ``model`` is the object whose ``simulate`` method is called, and
        ``*args`` are passed positionally to ``model.simulate``.

    Returns
    -------
    tuple
        ``(name, *args, output)`` where ``output`` is the result of
        ``model.variables.to_pandas()`` after the simulation has run.

    Raises
    ------
    Exception
        Any exception raised while running the task is logged (message
        and traceback) and then re-raised unchanged.
    """
    try:
        simulation = task[1]
        sim_args = task[2:]
        # The return value of simulate() is not used; the results are read
        # from model.variables afterwards.
        simulation.simulate(*sim_args)
        output = simulation.variables.to_pandas()
        return (task[0], *sim_args, output)
    except Exception as e:
        logger.error(f"Worker failed for task {task[0]}: {str(e)}")
        logger.error(traceback.format_exc())
        raise
def parallel_processor(
    tasks: list = None,
    worker: callable = timeseries_worker,
    cpu_count: int = 1,
):
    """Run all tasks in parallel using a ``torch.multiprocessing`` pool.

    The multiprocessing start method is forced to ``"spawn"`` and the
    tensor sharing strategy is set to ``"file_system"`` before the pool is
    created. Both are process-wide settings.

    Parameters
    ----------
    tasks : list, optional
        Tasks to distribute. Each element is passed as the single argument
        to ``worker``. Any iterable is accepted; it is materialised into a
        list. ``None`` is treated as "no tasks".
    worker : callable
        Function applied to each task. Because the ``"spawn"`` start
        method is used, the worker and the tasks have to be picklable.
    cpu_count : int
        Upper bound on the number of worker processes. The pool never
        uses more processes than there are tasks.

    Returns
    -------
    list
        The value returned by ``worker`` for each task, in task order.

    Raises
    ------
    ValueError
        If there are no tasks to process.
    Exception
        Any error raised while creating the pool or by a worker is logged
        (message and traceback) and re-raised.
    """
    # A None default avoids sharing one mutable list object between calls;
    # list() additionally lets callers pass generators or other iterables.
    tasks = [] if tasks is None else list(tasks)

    # Validate before touching any global multiprocessing state so that an
    # invalid call has no side effects.
    if not tasks:
        raise ValueError("No tasks to process.")

    # Set multiprocessing start method to spawn (best effort)
    try:
        mp.set_start_method("spawn", force=True)
    except RuntimeError:
        logger.debug(
            "Could not set multiprocessing start method to 'spawn'",
            exc_info=True,
        )

    # Set sharing strategy
    mp.set_sharing_strategy("file_system")

    process_count = min(cpu_count, len(tasks))
    logger.debug(f"Creating process pool with {process_count} workers")

    try:
        with pool_context(processes=process_count) as pool:
            logger.debug("Process pool created successfully")
            results = pool.map(worker, tasks)
            logger.debug("Parallel processing completed")
            return results
    except Exception as e:
        logger.error(f"Error in process pool: {str(e)}")
        logger.error(traceback.format_exc())
        raise