Skip to content
Snippets Groups Projects
Commit 5187681a authored by Samuel GAIST's avatar Samuel GAIST
Browse files

[data] Make CachedDataSource pickable

This will allow to pass an instance of it in
a queue when using the multiprocessing module.
parent 647fe253
No related branches found
No related tags found
1 merge request!71Implement multiprocessing support for CachedDataSource
......@@ -275,6 +275,12 @@ class DataSource(object):
# ----------------------------------------------------------
# helper to store file information
# required to be out of the CachedDataSource for pickling reasons
FileInfos = namedtuple(
"FileInfos", ["file_index", "start_index", "end_index", "offset", "size"]
)
class CachedDataSource(DataSource):
"""Utility class to load data from a file in the cache"""
......@@ -290,6 +296,29 @@ class CachedDataSource(DataSource):
self.current_file_index = None
self.unpack = True
def __getstate__(self):
# do not pass open files when being pickled/copied
state = self.__dict__.copy()
if state["current_file"] is not None:
del state["current_file"]
state["__had_open_file__"] = True
return state
def __setstate__(self, state):
# restore the state after being pickled/copied
had_open_file_before_pickle = state.pop("__had_open_file__", False)
self.__dict__.update(state)
if had_open_file_before_pickle:
try:
path = self.filenames[self.current_file_index]
self.current_file = open(path, "rb")
except Exception as e:
raise IOError("Could not read `%s': %s" % (path, e))
def _readHeader(self, file):
"""Read the header of the provided file"""
......@@ -428,11 +457,6 @@ class CachedDataSource(DataSource):
check_consistency(self.filenames, data_checksum_filenames)
# Load all the needed infos from all the files
FileInfos = namedtuple(
"FileInfos", ["file_index", "start_index", "end_index", "offset", "size"]
)
for file_index, current_filename in enumerate(self.filenames):
try:
f = open(current_filename, "rb")
......
......@@ -37,6 +37,7 @@
import unittest
import os
import glob
import pickle # nosec
import tempfile
import shutil
......@@ -360,6 +361,29 @@ class TestCachedDataSource(TestCachedDataBase):
self.assertIsNone(cached_source.current_file_index)
self.assertIsNone(cached_source.current_file)
def test_picklability(self):
self.writeData("user/single_integer/1", 0, 9)
cached_source = CachedDataSource()
cached_source.setup(self.filename, prefix)
# test pickle before accessing a file
cached_source2 = pickle.loads(pickle.dumps(cached_source)) # nosec
data, start, end = cached_source[0]
data2, start2, end2 = cached_source2[0]
self.assertEqual(data.value, data2.value)
self.assertEqual(start, start2)
self.assertEqual(end, end2)
# access one file and try again
data, start, end = cached_source[0]
cached_source2 = pickle.loads(pickle.dumps(cached_source)) # nosec
data2, start2, end2 = cached_source2[0]
self.assertEqual(data.value, data2.value)
self.assertEqual(start, start2)
self.assertEqual(end, end2)
# ----------------------------------------------------------
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment