From 5187681a4dafe5a411d344ccba930d3348f28e67 Mon Sep 17 00:00:00 2001 From: Samuel Gaist <samuel.gaist@idiap.ch> Date: Fri, 15 May 2020 17:26:30 +0200 Subject: [PATCH] [data] Make CachedDataSource picklable This will allow passing an instance of it in a queue when using the multiprocessing module. --- beat/backend/python/data.py | 34 +++++++++++++++++++++++---- beat/backend/python/test/test_data.py | 24 +++++++++++++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/beat/backend/python/data.py b/beat/backend/python/data.py index c126d55..80f7547 100644 --- a/beat/backend/python/data.py +++ b/beat/backend/python/data.py @@ -275,6 +275,12 @@ class DataSource(object): # ---------------------------------------------------------- +# helper to store file information +# required to be out of the CachedDataSource for pickling reasons +FileInfos = namedtuple( + "FileInfos", ["file_index", "start_index", "end_index", "offset", "size"] +) + class CachedDataSource(DataSource): """Utility class to load data from a file in the cache""" @@ -290,6 +296,29 @@ class CachedDataSource(DataSource): self.current_file_index = None self.unpack = True + def __getstate__(self): + # do not pass open files when being pickled/copied + state = self.__dict__.copy() + + if state["current_file"] is not None: + del state["current_file"] + state["__had_open_file__"] = True + + return state + + def __setstate__(self, state): + # restore the state after being pickled/copied + had_open_file_before_pickle = state.pop("__had_open_file__", False) + + self.__dict__.update(state) + + if had_open_file_before_pickle: + try: + path = self.filenames[self.current_file_index] + self.current_file = open(path, "rb") + except Exception as e: + raise IOError("Could not read `%s': %s" % (path, e)) + def _readHeader(self, file): """Read the header of the provided file""" @@ -428,11 +457,6 @@ class CachedDataSource(DataSource): check_consistency(self.filenames, data_checksum_filenames) - # Load all 
the needed infos from all the files - FileInfos = namedtuple( - "FileInfos", ["file_index", "start_index", "end_index", "offset", "size"] - ) - for file_index, current_filename in enumerate(self.filenames): try: f = open(current_filename, "rb") diff --git a/beat/backend/python/test/test_data.py b/beat/backend/python/test/test_data.py index 227814a..4bd14bf 100644 --- a/beat/backend/python/test/test_data.py +++ b/beat/backend/python/test/test_data.py @@ -37,6 +37,7 @@ import unittest import os import glob +import pickle # nosec import tempfile import shutil @@ -360,6 +361,29 @@ class TestCachedDataSource(TestCachedDataBase): self.assertIsNone(cached_source.current_file_index) self.assertIsNone(cached_source.current_file) + def test_picklability(self): + self.writeData("user/single_integer/1", 0, 9) + + cached_source = CachedDataSource() + cached_source.setup(self.filename, prefix) + + # test pickle before accessing a file + cached_source2 = pickle.loads(pickle.dumps(cached_source)) # nosec + data, start, end = cached_source[0] + data2, start2, end2 = cached_source2[0] + self.assertEqual(data.value, data2.value) + self.assertEqual(start, start2) + self.assertEqual(end, end2) + + # access one file and try again + data, start, end = cached_source[0] + cached_source2 = pickle.loads(pickle.dumps(cached_source)) # nosec + + data2, start2, end2 = cached_source2[0] + self.assertEqual(data.value, data2.value) + self.assertEqual(start, start2) + self.assertEqual(end, end2) + # ---------------------------------------------------------- -- GitLab