diff --git a/gridmap/job.py b/gridmap/job.py index 85e8c50..f4986e5 100644 --- a/gridmap/job.py +++ b/gridmap/job.py @@ -50,7 +50,7 @@ from email.mime.image import MIMEImage from io import open from importlib import import_module -from multiprocessing import Pool +from multiprocessing import Pool, Value from socket import gethostname, gethostbyname, getaddrinfo, getfqdn from smtplib import (SMTPRecipientsRefused, SMTPHeloError, SMTPSenderRefused, SMTPDataError) @@ -730,8 +730,14 @@ def _execute(job): job.execute() return job.ret +def _init_pool_processes(the_val): + '''Initialize each process with a global shared variable. + ''' + global shared_val + shared_val = the_val + -def _process_jobs_locally(jobs, max_processes=1): +def _process_jobs_locally(jobs, max_processes=1, shared_val=None): """ Local execution using the package multiprocessing, if present @@ -739,6 +745,8 @@ def _process_jobs_locally(jobs, max_processes=1): :type jobs: list of Job :param max_processes: maximal number of processes :type max_processes: int + :param shared_val: shared value for the jobs + :type shared_val: multiprocessing.Value, optional :return: list of jobs, each with return in job.ret :rtype: list of Job @@ -751,7 +759,7 @@ def _process_jobs_locally(jobs, max_processes=1): for job in jobs: job.execute() else: - pool = Pool(max_processes) + pool = Pool(processes=max_processes, initializer=_init_pool_processes, initargs=(shared_val,)) result = pool.map(_execute, jobs) for ret_val, job in zip(result, jobs): job.ret = ret_val @@ -856,7 +864,7 @@ def _append_job_to_session(session, job, temp_dir=DEFAULT_TEMP_DIR, quiet=True): def process_jobs(jobs, temp_dir=DEFAULT_TEMP_DIR, white_list=None, quiet=True, - max_processes=1, local=False, require_cluster=False): + max_processes=1, local=False, require_cluster=False, shared_val=None): """ Take a list of jobs and process them on the cluster. @@ -879,6 +887,8 @@ def process_jobs(jobs, temp_dir=DEFAULT_TEMP_DIR, white_list=None, quiet=True, :param require_cluster: Should we raise an exception if access to cluster is not available? :type require_cluster: bool + :param shared_val: A shared value for all of jobs + :type shared_val: multiprocessing.Value, optional :returns: List of Job results """ @@ -904,7 +914,7 @@ def process_jobs(jobs, temp_dir=DEFAULT_TEMP_DIR, white_list=None, quiet=True, # handling of inputs, outputs and heartbeats monitor.check(sid, jobs) else: - _process_jobs_locally(jobs, max_processes=max_processes) + _process_jobs_locally(jobs, max_processes=max_processes, shared_val=shared_val) return [job.ret for job in jobs] @@ -943,7 +953,7 @@ def grid_map(f, args_list, cleanup=True, mem_free="1G", name='gridmap_job', interpreting_shell=None, copy_env=True, add_env=None, project=None, validation_level=None, os_distribution=None, os_minor=None, gpu=0, h_vmem=None, h_rt=None, resources=None, completion_mail=False, - require_cluster=False, par_env=DEFAULT_PAR_ENV): + require_cluster=False, par_env=DEFAULT_PAR_ENV, shared_val=None): """ Maps a function onto the cluster. @@ -1016,6 +1026,8 @@ def grid_map(f, args_list, cleanup=True, mem_free="1G", name='gridmap_job', :type os_distribution: str :param os_minor: os minor version that need job to run on machine :type os_minor: str + :param shared_val: A shared value for all the jobs + :type shared_val: multiprocessing.Value, optional :returns: List of Job results """ @@ -1036,7 +1048,8 @@ def grid_map(f, args_list, cleanup=True, mem_free="1G", name='gridmap_job', white_list=white_list, quiet=quiet, local=local, max_processes=max_processes, - require_cluster=require_cluster) + require_cluster=require_cluster, + shared_val=shared_val) # send a completion mail (if requested and configured) if completion_mail and SEND_ERROR_MAIL: