Commit 2e851b4f authored by Amir MOHAMMADI

Merge branch 'array_slice' into 'master'

Add a helper function to easily slice into a list in array jobs

See merge request !28
parents 9e3291b8 9481459a
Pipeline #30786 passed with stages
in 6 minutes and 40 seconds
from ..tools import get_array_job_slice
import os
class SGE_EnvWrapper:
    """Context manager that temporarily sets the SGE array-job variables.

    On entry the current values of ``SGE_TASK_ID``, ``SGE_TASK_LAST``,
    ``SGE_TASK_FIRST`` and ``SGE_TASK_STEPSIZE`` in ``os.environ`` are backed
    up and replaced by the requested ones. On exit the backed-up values are
    restored; variables that did not exist before are removed again.

    Parameters
    ----------
    SGE_TASK_ID, SGE_TASK_LAST, SGE_TASK_FIRST, SGE_TASK_STEPSIZE
        Values to export while the context is active. Each one is converted
        with ``str`` before being stored.
    **kwargs
        Forwarded unchanged to the parent class initializer.
    """

    def __init__(
        self,
        SGE_TASK_ID=1,
        SGE_TASK_LAST=1,
        SGE_TASK_FIRST=1,
        SGE_TASK_STEPSIZE=1,
        **kwargs
    ):
        super().__init__(**kwargs)
        # os.environ only accepts strings, so convert once up front
        self.variables = {
            "SGE_TASK_ID": str(SGE_TASK_ID),
            "SGE_TASK_LAST": str(SGE_TASK_LAST),
            "SGE_TASK_FIRST": str(SGE_TASK_FIRST),
            "SGE_TASK_STEPSIZE": str(SGE_TASK_STEPSIZE),
        }
        # populated by __enter__; stays None until the context is entered
        self.old_variables = None

    def __enter__(self):
        """Back up the current values, export the requested ones, return self."""
        # backup current variables (None marks "was not set")
        self.old_variables = {name: os.environ.get(name) for name in self.variables}
        # set the requested variables
        for name, value in self.variables.items():
            os.environ[name] = value
        return self

    def __exit__(self, *args):
        """Restore the environment to the state captured in ``__enter__``."""
        # restore old variables
        for name, value in self.old_variables.items():
            if value is None:
                # pop() rather than del: the variable may already have been
                # removed inside the ``with`` body, and a KeyError here would
                # leave the remaining variables unrestored.
                os.environ.pop(name, None)
            else:
                os.environ[name] = value

    def set(self, name, value):
        """Change one of the wrapped variables while the context is active.

        Parameters
        ----------
        name : str
            Name of the variable; must be one of the wrapped variables.
        value
            New value; converted with ``str`` before being exported.
        """
        assert name in self.old_variables
        os.environ[name] = str(value)
def test_get_array_job_slice():
    """Check the slice returned for a 10-element list under several SGE setups.

    The steps are applied cumulatively inside a single ``SGE_EnvWrapper``
    context: each entry optionally changes one environment variable and then
    states the slice expected from ``get_array_job_slice(10)``.
    """
    # (variable to change or None, new value, expected slice)
    scenario = (
        (None, None, slice(0, 10)),
        ("SGE_TASK_LAST", 5, slice(0, 2)),
        ("SGE_TASK_ID", 2, slice(2, 4)),
        ("SGE_TASK_ID", 5, slice(8, 10)),
    )
    with SGE_EnvWrapper() as env:
        for variable, new_value, expected in scenario:
            if variable is not None:
                env.set(variable, new_value)
            result = get_array_job_slice(10)
            assert result == expected
......@@ -11,6 +11,7 @@ import os
import re
import hashlib
import random
import math
# sqlalchemy migration; copied from Bob
......@@ -334,3 +335,38 @@ def qdel(jobid, context='grid'):
from .setshell import sexec
sexec(context, scmd, error_on_nonzero=False)
def get_array_job_slice(total_length):
    """A helper function that lets you chunk a list in an SGE array job.

    Use this function like ``a = a[get_array_job_slice(len(a))]`` to only
    process a chunk of ``a``. The list is divided into ``SGE_TASK_LAST``
    consecutive chunks of ``ceil(total_length / SGE_TASK_LAST)`` elements and
    the chunk belonging to the current ``SGE_TASK_ID`` is returned. Trailing
    chunks may be shorter or empty.

    Parameters
    ----------
    total_length : int
        The length of the list that you are trying to slice

    Returns
    -------
    slice
        A slice to be used. ``slice(None)`` (i.e. everything) is returned when
        ``SGE_TASK_ID`` is unset or is not an integer.

    Raises
    ------
    NotImplementedError
        If "SGE_TASK_FIRST" and "SGE_TASK_STEPSIZE" are not 1.
    KeyError
        If ``SGE_TASK_ID`` is an integer but one of ``SGE_TASK_FIRST``,
        ``SGE_TASK_STEPSIZE`` or ``SGE_TASK_LAST`` is missing.
    """
    sge_task_id = os.environ.get("SGE_TASK_ID")

    try:
        sge_task_id = int(sge_task_id)
    except (TypeError, ValueError):
        # TypeError: variable is unset (None); ValueError: not a number.
        # Either way there is no task index to chunk by, so select everything.
        return slice(None)

    if os.environ["SGE_TASK_FIRST"] != '1' or os.environ["SGE_TASK_STEPSIZE"] != '1':
        raise NotImplementedError("Values other than 1 for SGE_TASK_FIRST and SGE_TASK_STEPSIZE is not supported!")

    # SGE task ids are 1-based here; convert to a 0-based chunk index
    job_id = sge_task_id - 1

    number_of_parallel_jobs = int(os.environ["SGE_TASK_LAST"])
    # Exact integer ceiling division. Going through float division loses
    # precision for very large lengths and could drop trailing elements.
    number_of_objects_per_job = -(-total_length // number_of_parallel_jobs)

    # clamp both ends so that late tasks get a short (or empty) chunk
    start = min(job_id * number_of_objects_per_job, total_length)
    end = min((job_id + 1) * number_of_objects_per_job, total_length)

    return slice(start, end)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment