Commit 88f1d1ae authored by André Anjos's avatar André Anjos 💬

Merge branch '32_multiprocessing_support_for_cacheddatasource' into 'master'

Implement multiprocessing support for CachedDataSource

Closes #32

See merge request !71
parents 45dccc3f d89469c1
Pipeline #39912 passed with stages
in 8 minutes and 32 seconds
......@@ -209,6 +209,14 @@ class DataSource(object):
def close(self):
    """Release this data source by discarding all collected file infos."""
    self.infos = []
def reset(self):
    """Reset the state of the data source.

    Only the current state is cleared; the source stays usable without a
    new call to its setup method. Subclasses override this as needed.
    """
    return None
def __del__(self):
    """Release every held resource when this object is garbage-collected."""
    self.close()
......@@ -267,6 +275,12 @@ class DataSource(object):
# ----------------------------------------------------------
# helper to store file information
# required to be out of the CachedDataSource for pickling reasons
FileInfos = namedtuple(
    "FileInfos",
    ("file_index", "start_index", "end_index", "offset", "size"),
)
class CachedDataSource(DataSource):
"""Utility class to load data from a file in the cache"""
......@@ -282,6 +296,29 @@ class CachedDataSource(DataSource):
self.current_file_index = None
self.unpack = True
def __getstate__(self):
    """Pickling support: never hand an open file object to the pickler.

    The open handle is dropped from the copied state and a marker is left
    behind so __setstate__ can re-open the file on the other side.
    """
    state = dict(self.__dict__)
    if state["current_file"] is not None:
        del state["current_file"]
        state["__had_open_file__"] = True
    return state
def __setstate__(self, state):
    """Pickling support: restore the state after being pickled/copied.

    If a file was open when the object was pickled (see __getstate__),
    re-open the file pointed to by the current file index.

    Raises:
        IOError: if the cache file cannot be re-opened
    """
    had_open_file_before_pickle = state.pop("__had_open_file__", False)
    self.__dict__.update(state)

    if had_open_file_before_pickle:
        # Resolve the path before the try block: in the original code a
        # failure on this lookup left `path` unbound, so the except
        # handler raised a NameError instead of the intended IOError
        path = self.filenames[self.current_file_index]
        try:
            self.current_file = open(path, "rb")
        except Exception as e:
            raise IOError("Could not read `%s': %s" % (path, e))
def _readHeader(self, file):
"""Read the header of the provided file"""
......@@ -420,11 +457,6 @@ class CachedDataSource(DataSource):
check_consistency(self.filenames, data_checksum_filenames)
# Load all the needed infos from all the files
FileInfos = namedtuple(
"FileInfos", ["file_index", "start_index", "end_index", "offset", "size"]
)
for file_index, current_filename in enumerate(self.filenames):
try:
f = open(current_filename, "rb")
......@@ -467,11 +499,18 @@ class CachedDataSource(DataSource):
return True
def close(self):
    """Close this data source: release the open cache file via reset(),
    then let the base class discard the collected infos."""
    self.reset()
    super(CachedDataSource, self).close()
def reset(self):
    """Reset the current state.

    Closes any open cache file and forgets which file was being read;
    the loaded file infos are kept so the source needs no new setup.
    """
    open_file = self.current_file
    if open_file is not None:
        open_file.close()
        self.current_file = None
        self.current_file_index = None
def __getitem__(self, index):
"""Retrieve a block of data
......
......@@ -210,6 +210,14 @@ class DataLoader(object):
self.data_index_start = -1 # Lower index across all inputs
self.data_index_end = -1 # Bigger index across all inputs
def reset(self):
    """Reset every registered data source back to its initial state."""
    for input_infos in self.infos.values():
        source = input_infos.get("data_source")
        if source:
            source.reset()
def add(self, input_name, data_source):
self.infos[input_name] = dict(
data_source=data_source,
......@@ -302,6 +310,17 @@ class DataLoader(object):
return (result, indices[0], indices[1])
def __getstate__(self):
    """Pickling support: strip the cached data, which is not picklable,
    and reset the cached index range for every input."""
    state = self.__dict__.copy()
    # NOTE(review): state["infos"] still references the live per-input
    # dicts, so clearing them here also clears the cache on the original
    # object — presumably intentional, confirm against callers
    for input_infos in state["infos"].values():
        input_infos["data"] = None
        input_infos["start_index"] = -1
        input_infos["end_index"] = -1
    return state
# ----------------------------------------------------------
......
{
"schema_version": 2,
"language": "python",
"api_version": 2,
"type": "autonomous",
"splittable": false,
"groups": [
{
"inputs": {
"in1": {
"type": "user/single_integer/1"
}
},
"outputs": {
"out": {
"type": "user/single_integer/1"
}
}
},
{
"inputs": {
"in2": {
"type": "user/single_integer/1"
}
}
}
]
}
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###################################################################################
# #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# Redistribution and use in source and binary forms, with or without #
# modification, are permitted provided that the following conditions are met: #
# #
# 1. Redistributions of source code must retain the above copyright notice, this #
# list of conditions and the following disclaimer. #
# #
# 2. Redistributions in binary form must reproduce the above copyright notice, #
# this list of conditions and the following disclaimer in the documentation #
# and/or other materials provided with the distribution. #
# #
# 3. Neither the name of the copyright holder nor the names of its contributors #
# may be used to endorse or promote products derived from this software without #
# specific prior written permission. #
# #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #
# #
###################################################################################
import multiprocessing
def foo(queue_in, queue_out, index):
    """Worker entry point: take one task from the input queue, read the
    block at ``index`` from the pickled data loader and push a formatted
    greeting on the output queue."""
    text, data_loader = queue_in.get()
    block_data, _, _ = data_loader[index]
    integer_value = block_data["in1"].value
    message = "hello " + text + " {}".format(integer_value)
    queue_out.put(message)
    queue_in.task_done()
class Algorithm:
    """Autonomous test algorithm exercising data-loader multiprocessing.

    Each call to :py:meth:`process` ships the data loader to one worker
    process per block (forcing the loader to be pickled), waits for the
    workers, then writes every ``in1`` value shifted by the offset read
    from ``in2`` during :py:meth:`prepare`.
    """

    def prepare(self, data_loaders):
        """Cache the offset taken from the first block of ``in2``."""
        data_loader = data_loaders.loaderOf("in2")
        data, _, _ = data_loader[0]
        self.offset = data["in2"].value
        return True

    def process(self, data_loaders, outputs):
        """Fan out one worker per block, then write the shifted outputs."""
        data_loader = data_loaders.loaderOf("in1")

        # ensure loader has been used before sending it, so pickling
        # happens on a loader whose cache file has been opened
        for i in range(data_loader.count()):
            data, _, _ = data_loader[i]
            data["in1"].value

        num_thread = data_loader.count()
        queue_in = multiprocessing.JoinableQueue(num_thread)
        queue_out = []

        # Start one worker process per block
        jobs = []
        for i in range(num_thread):
            queue_out.append(multiprocessing.Queue())
            p = multiprocessing.Process(target=foo, args=(queue_in, queue_out[i], i))
            jobs.append(p)
            p.start()

        # Enqueue exactly one task per worker; each worker processes a
        # single task and then exits
        for task in range(num_thread):
            queue_in.put(("test {}".format(task), data_loader))

        # Wait for all the tasks to finish
        queue_in.join()

        # Reap the workers: the original code never joined them, leaving
        # zombie processes behind
        for job in jobs:
            job.join()

        for i in range(data_loader.count()):
            data, _, end = data_loader[i]
            outputs["out"].write({"value": data["in1"].value + self.offset}, end)

        return True
{
"schema_version": 2,
"language": "python",
"api_version": 2,
"type": "sequential",
"splittable": false,
"groups": [
{
"inputs": {
"in1": {
"type": "user/single_integer/1"
}
},
"outputs": {
"out": {
"type": "user/single_integer/1"
}
}
},
{
"inputs": {
"in2": {
"type": "user/single_integer/1"
}
}
}
]
}
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
###################################################################################
# #
# Copyright (c) 2019 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# Redistribution and use in source and binary forms, with or without #
# modification, are permitted provided that the following conditions are met: #
# #
# 1. Redistributions of source code must retain the above copyright notice, this #
# list of conditions and the following disclaimer. #
# #
# 2. Redistributions in binary form must reproduce the above copyright notice, #
# this list of conditions and the following disclaimer in the documentation #
# and/or other materials provided with the distribution. #
# #
# 3. Neither the name of the copyright holder nor the names of its contributors #
# may be used to endorse or promote products derived from this software without #
# specific prior written permission. #
# #
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND #
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED #
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE #
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE #
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL #
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR #
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER #
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, #
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE #
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #
# #
###################################################################################
import multiprocessing
def foo(queue_in, queue_out, index):
    """Worker entry point: take one task from the input queue, read the
    block at ``index`` from the pickled data loader and push a formatted
    greeting on the output queue."""
    text, data_loader = queue_in.get()
    block_data, _, _ = data_loader[index]
    integer_value = block_data["in2"].value
    message = "hello " + text + " {}".format(integer_value)
    queue_out.put(message)
    queue_in.task_done()
class Algorithm:
    """Sequential test algorithm exercising data-loader multiprocessing.

    Each call to :py:meth:`process` ships the ``in2`` data loader to one
    worker process per block (forcing the loader to be pickled), waits
    for the workers, then writes the current ``in1`` input value shifted
    by the offset read from ``in2`` during :py:meth:`prepare`.
    """

    def prepare(self, data_loaders):
        """Cache the offset taken from the first block of ``in2``."""
        data_loader = data_loaders.loaderOf("in2")
        data, _, _ = data_loader[0]
        self.offset = data["in2"].value
        return True

    def process(self, inputs, data_loaders, outputs):
        """Fan out one worker per ``in2`` block, then write the output."""
        data_loader = data_loaders.loaderOf("in2")

        # ensure loader has been used before sending it, so pickling
        # happens on a loader whose cache file has been opened
        for i in range(data_loader.count()):
            data, _, _ = data_loader[i]
            data["in2"].value

        num_thread = data_loader.count()
        queue_in = multiprocessing.JoinableQueue(num_thread)
        queue_out = []

        # Start one worker process per block
        jobs = []
        for i in range(num_thread):
            queue_out.append(multiprocessing.Queue())
            p = multiprocessing.Process(target=foo, args=(queue_in, queue_out[i], i))
            jobs.append(p)
            p.start()

        # Enqueue exactly one task per worker; each worker processes a
        # single task and then exits
        for task in range(num_thread):
            queue_in.put(("test {}".format(task), data_loader))

        # Wait for all the tasks to finish
        queue_in.join()

        # Reap the workers: the original code never joined them, leaving
        # zombie processes behind
        for job in jobs:
            job.join()

        outputs["out"].write({"value": inputs["in1"].data.value + self.offset})

        return True
......@@ -1078,6 +1078,36 @@ class TestSequentialAPI_Process(TestExecutionBase):
self.assertEqual(data_unit.end, 3)
self.assertEqual(data_unit.data.value, 2014)
def test_multiprocess(self):
    """End-to-end run of the sequential multiprocess test algorithm."""
    # eight unit-sized blocks on in1, two larger blocks on in2
    self.writeData("in1", [(index, index) for index in range(8)], 1000)
    self.writeData("in2", [(0, 1), (2, 3)], 2000)

    data_loaders, inputs, outputs, data_sink = self.create_io(
        {"group1": ["in1"], "group2": ["in2"]}
    )

    runner = Algorithm(prefix, "sequential/multiprocess/1").runner()
    self.assertTrue(runner.setup({"sync": "in2"}))
    self.assertTrue(runner.prepare(data_loaders=data_loaders))

    while inputs.hasMoreData():
        inputs.restricted_access = False
        inputs.next()
        inputs.restricted_access = True
        self.assertTrue(
            runner.process(
                inputs=inputs, data_loaders=data_loaders, outputs=outputs
            )
        )

    # one output unit was written per in1 block
    self.assertEqual(len(data_sink.written), 8)
# ----------------------------------------------------------
......@@ -1270,3 +1300,25 @@ class TestAutonomousAPI_Process(TestExecutionBase):
self.assertEqual(data_unit.start, 3)
self.assertEqual(data_unit.end, 3)
self.assertEqual(data_unit.data.value, 2014)
def test_multiprocess(self):
    """End-to-end run of the autonomous multiprocess test algorithm."""
    # eight unit-sized blocks on in1, two larger blocks on in2
    self.writeData("in1", [(index, index) for index in range(8)], 1000)
    self.writeData("in2", [(0, 1), (2, 3)], 2000)

    data_loaders, outputs, data_sink = self.create_io(
        {"group1": ["in1"], "group2": ["in2"]}
    )

    runner = Algorithm(prefix, "autonomous/multiprocess/1").runner()
    self.assertTrue(runner.setup({"sync": "in2"}))
    self.assertTrue(runner.prepare(data_loaders=data_loaders.secondaries()))
    self.assertTrue(runner.process(data_loaders=data_loaders, outputs=outputs))

    # one output unit was written per in1 block
    self.assertEqual(len(data_sink.written), 8)
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment