Commit 521b4b5f authored by Samuel GAIST's avatar Samuel GAIST

[data_loaders] Pre-commit cleanup

parent cb6453e9
......@@ -90,35 +90,36 @@ class DataView(object):
"""
def __init__(self, data_loader, data_indices):
self.infos = {}
self.data_indices = data_indices
self.nb_data_units = len(data_indices)
self.data_index_start = data_indices[0][0]
self.data_index_end = data_indices[-1][1]
self.infos = {}
self.data_indices = data_indices
self.nb_data_units = len(data_indices)
self.data_index_start = data_indices[0][0]
self.data_index_end = data_indices[-1][1]
for input_name, infos in data_loader.infos.items():
input_data_indices = []
current_start = self.data_index_start
for i in range(self.data_index_start, self.data_index_end + 1):
for indices in infos['data_indices']:
for indices in infos["data_indices"]:
if indices[1] == i:
input_data_indices.append( (current_start, i) )
input_data_indices.append((current_start, i))
current_start = i + 1
break
if (len(input_data_indices) == 0) or (input_data_indices[-1][1] != self.data_index_end):
input_data_indices.append( (current_start, self.data_index_end) )
if (len(input_data_indices) == 0) or (
input_data_indices[-1][1] != self.data_index_end
):
input_data_indices.append((current_start, self.data_index_end))
self.infos[input_name] = dict(
data_source = infos['data_source'],
data_indices = input_data_indices,
data = None,
start_index = -1,
end_index = -1,
data_source=infos["data_source"],
data_indices=input_data_indices,
data=None,
start_index=-1,
end_index=-1,
)
def count(self, input_name=None):
"""Returns the number of available data indexes for the given input
name. If none given the number of available data units.
......@@ -134,30 +135,30 @@ class DataView(object):
"""
if input_name is not None:
try:
return len(self.infos[input_name]['data_indices'])
except:
return len(self.infos[input_name]["data_indices"])
except Exception:
return None
else:
return self.nb_data_units
def __getitem__(self, index):
if index < 0:
return (None, None, None)
try:
indices = self.data_indices[index]
except:
except Exception:
return (None, None, None)
result = {}
for input_name, infos in self.infos.items():
if (indices[0] < infos['start_index']) or (infos['end_index'] < indices[0]):
(infos['data'], infos['start_index'], infos['end_index']) = \
infos['data_source'].getAtDataIndex(indices[0])
if (indices[0] < infos["start_index"]) or (infos["end_index"] < indices[0]):
(infos["data"], infos["start_index"], infos["end_index"]) = infos[
"data_source"
].getAtDataIndex(indices[0])
result[input_name] = infos['data']
result[input_name] = infos["data"]
return (result, indices[0], indices[1])
......@@ -204,35 +205,34 @@ class DataLoader(object):
"""
def __init__(self, channel):
self.channel = str(channel)
self.infos = {}
self.channel = str(channel)
self.infos = {}
self.mixed_data_indices = None
self.nb_data_units = 0
self.data_index_start = -1 # Lower index across all inputs
self.data_index_end = -1 # Bigger index across all inputs
self.nb_data_units = 0
self.data_index_start = -1 # Lower index across all inputs
self.data_index_end = -1 # Bigger index across all inputs
def add(self, input_name, data_source):
self.infos[input_name] = dict(
data_source = data_source,
data_indices = data_source.data_indices(),
data = None,
start_index = -1,
end_index = -1,
data_source=data_source,
data_indices=data_source.data_indices(),
data=None,
start_index=-1,
end_index=-1,
)
self.mixed_data_indices = mixDataIndices([ x['data_indices'] for x in self.infos.values() ])
self.mixed_data_indices = mixDataIndices(
[x["data_indices"] for x in self.infos.values()]
)
self.nb_data_units = len(self.mixed_data_indices)
self.data_index_start = self.mixed_data_indices[0][0]
self.data_index_end = self.mixed_data_indices[-1][1]
def input_names(self):
"""Returns the name of all inputs associated to this data loader"""
return self.infos.keys()
def count(self, input_name=None):
"""Returns the number of available data indexes for the given input
name. If none given the number of available data units.
......@@ -249,13 +249,12 @@ class DataLoader(object):
if input_name is not None:
try:
return len(self.infos[input_name]['data_indices'])
except:
return len(self.infos[input_name]["data_indices"])
except Exception:
return 0
else:
return self.nb_data_units
def view(self, input_name, index):
""" Returns the view associated with this data loader
......@@ -272,33 +271,36 @@ class DataLoader(object):
return None
try:
indices = self.infos[input_name]['data_indices'][index]
except:
indices = self.infos[input_name]["data_indices"][index]
except Exception:
return None
limited_data_indices = [ x for x in self.mixed_data_indices
if (indices[0] <= x[0]) and (x[1] <= indices[1]) ]
limited_data_indices = [
x
for x in self.mixed_data_indices
if (indices[0] <= x[0]) and (x[1] <= indices[1])
]
return DataView(self, limited_data_indices)
def __getitem__(self, index):
if index < 0:
return (None, None, None)
try:
indices = self.mixed_data_indices[index]
except:
except Exception:
return (None, None, None)
result = {}
for input_name, infos in self.infos.items():
if (indices[0] < infos['start_index']) or (infos['end_index'] < indices[0]):
(infos['data'], infos['start_index'], infos['end_index']) = \
infos['data_source'].getAtDataIndex(indices[0])
if (indices[0] < infos["start_index"]) or (infos["end_index"] < indices[0]):
(infos["data"], infos["start_index"], infos["end_index"]) = infos[
"data_source"
].getAtDataIndex(indices[0])
result[input_name] = infos['data']
result[input_name] = infos["data"]
return (result, indices[0], indices[1])
......@@ -354,7 +356,6 @@ class DataLoaderList(object):
self._loaders = []
self.main_loader = None
def add(self, data_loader):
"""Add a data loader to the list
......@@ -366,7 +367,6 @@ class DataLoaderList(object):
self._loaders.append(data_loader)
def __getitem__(self, name_or_index):
try:
if isinstance(name_or_index, six.string_types):
......@@ -374,30 +374,24 @@ class DataLoaderList(object):
elif isinstance(name_or_index, int):
return self._loaders[name_or_index]
except:
pass
return None
except Exception:
return None
def __iter__(self):
for i in range(len(self._loaders)):
yield self._loaders[i]
def __len__(self):
return len(self._loaders)
def loaderOf(self, input_name):
"""Returns the data loader matching the input name"""
try:
return [ k for k in self._loaders if input_name in k.input_names() ][0]
except:
return [k for k in self._loaders if input_name in k.input_names()][0]
except Exception:
return None
def secondaries(self):
"""Returns a list of all data loaders except the main one"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment