Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions gridmap/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from email.mime.image import MIMEImage
from io import open
from importlib import import_module
from multiprocessing import Pool
from multiprocessing import Pool, Value
from socket import gethostname, gethostbyname, getaddrinfo, getfqdn
from smtplib import (SMTPRecipientsRefused, SMTPHeloError, SMTPSenderRefused,
SMTPDataError)
Expand Down Expand Up @@ -730,15 +730,23 @@ def _execute(job):
job.execute()
return job.ret

def _init_pool_processes(the_val):
'''Initialize each process with a global shared variable.
'''
global shared_val
shared_val = the_val


def _process_jobs_locally(jobs, max_processes=1):
def _process_jobs_locally(jobs, max_processes=1, shared_val=None):
"""
Local execution using the package multiprocessing, if present

:param jobs: jobs to be executed
:type jobs: list of Job
:param max_processes: maximal number of processes
:type max_processes: int
:param shared_val: shared value for the jobs
:type shared_val: multiprocessing.Value, optional

:return: list of jobs, each with return in job.ret
:rtype: list of Job
Expand All @@ -751,7 +759,7 @@ def _process_jobs_locally(jobs, max_processes=1):
for job in jobs:
job.execute()
else:
pool = Pool(max_processes)
pool = Pool(processes=max_processes, initializer=_init_pool_processes, initargs=(shared_val,))
result = pool.map(_execute, jobs)
for ret_val, job in zip(result, jobs):
job.ret = ret_val
Expand Down Expand Up @@ -856,7 +864,7 @@ def _append_job_to_session(session, job, temp_dir=DEFAULT_TEMP_DIR, quiet=True):


def process_jobs(jobs, temp_dir=DEFAULT_TEMP_DIR, white_list=None, quiet=True,
max_processes=1, local=False, require_cluster=False):
max_processes=1, local=False, require_cluster=False, shared_val=None):
"""
Take a list of jobs and process them on the cluster.

Expand All @@ -879,6 +887,8 @@ def process_jobs(jobs, temp_dir=DEFAULT_TEMP_DIR, white_list=None, quiet=True,
:param require_cluster: Should we raise an exception if access to cluster
is not available?
:type require_cluster: bool
:param shared_val: A shared value for all of jobs
:type shared_val: multiprocessing.Value, optional

:returns: List of Job results
"""
Expand All @@ -904,7 +914,7 @@ def process_jobs(jobs, temp_dir=DEFAULT_TEMP_DIR, white_list=None, quiet=True,
# handling of inputs, outputs and heartbeats
monitor.check(sid, jobs)
else:
_process_jobs_locally(jobs, max_processes=max_processes)
_process_jobs_locally(jobs, max_processes=max_processes, shared_val=shared_val)

return [job.ret for job in jobs]

Expand Down Expand Up @@ -943,7 +953,7 @@ def grid_map(f, args_list, cleanup=True, mem_free="1G", name='gridmap_job',
interpreting_shell=None, copy_env=True, add_env=None, project=None,
validation_level=None, os_distribution=None, os_minor=None, gpu=0,
h_vmem=None, h_rt=None, resources=None, completion_mail=False,
require_cluster=False, par_env=DEFAULT_PAR_ENV):
require_cluster=False, par_env=DEFAULT_PAR_ENV, shared_val=None):
"""
Maps a function onto the cluster.

Expand Down Expand Up @@ -1016,6 +1026,8 @@ def grid_map(f, args_list, cleanup=True, mem_free="1G", name='gridmap_job',
:type os_distribution: str
:param os_minor: os minor version that need job to run on machine
:type os_minor: str
:param shared_val: A shared value for all the jobs
:type shared_val: multiprocessing.Value, optional

:returns: List of Job results
"""
Expand All @@ -1036,7 +1048,8 @@ def grid_map(f, args_list, cleanup=True, mem_free="1G", name='gridmap_job',
white_list=white_list,
quiet=quiet, local=local,
max_processes=max_processes,
require_cluster=require_cluster)
require_cluster=require_cluster,
shared_val=shared_val)

# send a completion mail (if requested and configured)
if completion_mail and SEND_ERROR_MAIL:
Expand Down