Commit 647fe253 authored by Samuel GAIST's avatar Samuel GAIST

[data_loaders] Add reset in the same idea as DataSource

parent bc8cb7f7
......@@ -210,6 +210,14 @@ class DataLoader(object):
self.data_index_start = -1 # Lower index across all inputs
self.data_index_end = -1 # Bigger index across all inputs
def reset(self):
"""Reset all the data sources"""
for infos in self.infos.values():
data_source = infos.get("data_source")
if data_source:
data_source.reset()
def add(self, input_name, data_source):
self.infos[input_name] = dict(
data_source=data_source,
......
......@@ -630,6 +630,25 @@ class DataLoaderTest(DataLoaderBaseTest):
view = data_loader.view("input2", 2)
self.assertTrue(view is None)
def test_reset(self):
# Setup
input_name = "input1"
self.writeData(input_name, [(0, 0), (1, 1), (2, 2)], 1000)
data_loader = DataLoader("channel1")
cached_file = CachedDataSource()
cached_file.setup(self.filenames[input_name], prefix)
data_loader.add(input_name, cached_file)
_, _, _ = data_loader[0]
cached_source = data_loader.infos[input_name]["data_source"]
self.assertIsNotNone(cached_source.current_file_index)
self.assertIsNotNone(cached_source.current_file)
data_loader.reset()
self.assertIsNone(cached_source.current_file_index)
self.assertIsNone(cached_source.current_file)
# ----------------------------------------------------------
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment