Skip to content
Snippets Groups Projects
Commit fa97b2cd authored by Philip ABBET's avatar Philip ABBET
Browse files

Add the 'inputs.DataLoaderGroup' class

parent 5007485b
Branches
Tags
1 merge request!17Merge development branch 1.5.x
......@@ -35,6 +35,8 @@ from functools import reduce
import six
import zmq
from .data import mixDataIndices
#----------------------------------------------------------
......@@ -493,6 +495,147 @@ class InputGroup:
#----------------------------------------------------------
class DataView(object):
def __init__(self, data_loader_group, data_indices):
self.infos = {}
self.data_indices = data_indices
self.nb_data_units = len(data_indices)
self.data_index = data_indices[0][0]
self.data_index_end = data_indices[-1][1]
for input_name, infos in data_loader_group.infos.items():
input_data_indices = []
current_start = self.data_index
for i in range(self.data_index, self.data_index_end + 1):
for indices in infos['data_indices']:
if indices[1] == i:
input_data_indices.append( (current_start, i) )
current_start = i + 1
break
if (len(input_data_indices) == 0) or (input_data_indices[-1][1] != self.data_index_end):
input_data_indices.append( (current_start, self.data_index_end) )
self.infos[input_name] = dict(
cached_file = infos['cached_file'],
data_indices = input_data_indices,
data = None,
start_index = -1,
end_index = -1,
)
def count(self, input_name=None):
if input_name is not None:
try:
return len(self.infos[input_name]['data_indices'])
except:
return None
else:
return self.nb_data_units
def __getitem__(self, index):
if index < 0:
return (None, None, None)
try:
indices = self.data_indices[index]
except:
return (None, None, None)
result = {}
for input_name, infos in self.infos.items():
if (indices[0] < infos['start_index']) or (infos['end_index'] < indices[0]):
(infos['data'], infos['start_index'], infos['end_index']) = \
infos['cached_file'].getAtDataIndex(indices[0])
result[input_name] = infos['data']
return (result, indices[0], indices[1])
#----------------------------------------------------------
class DataLoaderGroup(object):
def __init__(self, channel):
self.channel = str(channel)
self.infos = {}
self.mixed_data_indices = None
self.nb_data_units = 0
self.data_index = -1 # Lower index across all inputs
self.data_index_end = -1 # Bigger index across all inputs
def add(self, input_name, cached_file):
self.infos[input_name] = dict(
cached_file = cached_file,
data_indices = cached_file.data_indices(),
data = None,
start_index = -1,
end_index = -1,
)
self.mixed_data_indices = mixDataIndices([ x['data_indices'] for x in self.infos.values() ])
self.nb_data_units = len(self.mixed_data_indices)
self.data_index = self.mixed_data_indices[0][0]
self.data_index_end = self.mixed_data_indices[-1][1]
def count(self, input_name=None):
if input_name is not None:
try:
return len(self.infos[input_name]['data_indices'])
except:
return 0
else:
return self.nb_data_units
def view(self, input_name, index):
if index < 0:
return None
try:
indices = self.infos[input_name]['data_indices'][index]
except:
return None
limited_data_indices = [ x for x in self.mixed_data_indices
if (indices[0] <= x[0]) and (x[1] <= indices[1]) ]
return DataView(self, limited_data_indices)
def __getitem__(self, index):
if index < 0:
return (None, None, None)
try:
indices = self.mixed_data_indices[index]
except:
return (None, None, None)
result = {}
for input_name, infos in self.infos.items():
if (indices[0] < infos['start_index']) or (infos['end_index'] < indices[0]):
(infos['data'], infos['start_index'], infos['end_index']) = \
infos['cached_file'].getAtDataIndex(indices[0])
result[input_name] = infos['data']
return (result, indices[0], indices[1])
#----------------------------------------------------------
class InputList:
"""Represents the list of inputs of a processing block
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment