From 5187681a4dafe5a411d344ccba930d3348f28e67 Mon Sep 17 00:00:00 2001
From: Samuel Gaist <samuel.gaist@idiap.ch>
Date: Fri, 15 May 2020 17:26:30 +0200
Subject: [PATCH] [data] Make CachedDataSource pickable

This will allow to pass an instance of it in
a queue when using the multiprocessing module.
---
 beat/backend/python/data.py           | 34 +++++++++++++++++++++++----
 beat/backend/python/test/test_data.py | 24 +++++++++++++++++++
 2 files changed, 53 insertions(+), 5 deletions(-)

diff --git a/beat/backend/python/data.py b/beat/backend/python/data.py
index c126d55..80f7547 100644
--- a/beat/backend/python/data.py
+++ b/beat/backend/python/data.py
@@ -275,6 +275,12 @@ class DataSource(object):
 
 # ----------------------------------------------------------
 
+# helper to store file information
+# required to be out of the CachedDataSource for pickling reasons
+FileInfos = namedtuple(
+    "FileInfos", ["file_index", "start_index", "end_index", "offset", "size"]
+)
+
 
 class CachedDataSource(DataSource):
     """Utility class to load data from a file in the cache"""
@@ -290,6 +296,29 @@ class CachedDataSource(DataSource):
         self.current_file_index = None
         self.unpack = True
 
+    def __getstate__(self):
+        # do not pass open files when being pickled/copied
+        state = self.__dict__.copy()
+
+        if state["current_file"] is not None:
+            del state["current_file"]
+            state["__had_open_file__"] = True
+
+        return state
+
+    def __setstate__(self, state):
+        # restore the state after being pickled/copied
+        had_open_file_before_pickle = state.pop("__had_open_file__", False)
+
+        self.__dict__.update(state)
+
+        if had_open_file_before_pickle:
+            try:
+                path = self.filenames[self.current_file_index]
+                self.current_file = open(path, "rb")
+            except Exception as e:
+                raise IOError("Could not read `%s': %s" % (path, e))
+
     def _readHeader(self, file):
         """Read the header of the provided file"""
 
@@ -428,11 +457,6 @@ class CachedDataSource(DataSource):
 
         check_consistency(self.filenames, data_checksum_filenames)
 
-        # Load all the needed infos from all the files
-        FileInfos = namedtuple(
-            "FileInfos", ["file_index", "start_index", "end_index", "offset", "size"]
-        )
-
         for file_index, current_filename in enumerate(self.filenames):
             try:
                 f = open(current_filename, "rb")
diff --git a/beat/backend/python/test/test_data.py b/beat/backend/python/test/test_data.py
index 227814a..4bd14bf 100644
--- a/beat/backend/python/test/test_data.py
+++ b/beat/backend/python/test/test_data.py
@@ -37,6 +37,7 @@
 import unittest
 import os
 import glob
+import pickle  # nosec
 import tempfile
 import shutil
 
@@ -360,6 +361,29 @@ class TestCachedDataSource(TestCachedDataBase):
         self.assertIsNone(cached_source.current_file_index)
         self.assertIsNone(cached_source.current_file)
 
+    def test_picklability(self):
+        self.writeData("user/single_integer/1", 0, 9)
+
+        cached_source = CachedDataSource()
+        cached_source.setup(self.filename, prefix)
+
+        # test pickle before accessing a file
+        cached_source2 = pickle.loads(pickle.dumps(cached_source))  # nosec
+        data, start, end = cached_source[0]
+        data2, start2, end2 = cached_source2[0]
+        self.assertEqual(data.value, data2.value)
+        self.assertEqual(start, start2)
+        self.assertEqual(end, end2)
+
+        # access one file and try again
+        data, start, end = cached_source[0]
+        cached_source2 = pickle.loads(pickle.dumps(cached_source))  # nosec
+
+        data2, start2, end2 = cached_source2[0]
+        self.assertEqual(data.value, data2.value)
+        self.assertEqual(start, start2)
+        self.assertEqual(end, end2)
+
 
 # ----------------------------------------------------------
 
-- 
GitLab