Commit df15c5e8 authored by Philip ABBET's avatar Philip ABBET

Update putvein/3 (api change: beat.backend.python v1.4.1)

parent 3558cd46
......@@ -23,240 +23,438 @@
###############################################################################
import numpy as np
import bob.io.base
import bob.io.image
import bob.db.putvein
import bob.ip.color
import os
#----------------------------------------------------------
def get_client_end_index(objs, client_id, client_start_index,
start_index, end_index):
client_end_index = client_start_index
while client_end_index + 1 <= end_index:
obj = objs[client_end_index + 1 - start_index]
if isinstance(obj, tuple):
obj = obj[1]
if obj.get_client_id() != client_id:
return client_end_index
client_end_index += 1
return end_index
#----------------------------------------------------------
def get_model_end_index(objs, model_id, model_start_index,
start_index, end_index):
model_end_index = model_start_index
while model_end_index + 1 <= end_index:
id = objs[model_end_index + 1 - start_index][0]
if id != model_id:
return model_end_index
model_end_index += 1
return end_index
#----------------------------------------------------------
class View:
"""Outputs:
- image: "{{ system_user.username }}/array_2d_uint8/1"
- client_id: "{{ system_user.username }}/uint64/1"
Several "image" are associated with a given "client_id"
def setup(self,
root_folder,
outputs,
parameters,
force_start_index=None,
--------- --------- --------- --------- --------- ---------
| image | | image | | image | | image | | image | | image |
--------- --------- --------- --------- --------- ---------
----------------------------- -----------------------------
| client_id | | client_id |
----------------------------- -----------------------------
"""
def setup(self, root_folder, outputs, parameters, force_start_index=None,
force_end_index=None):
# Initialisations
self.root_folder = os.path.join(root_folder, '')
self.outputs = outputs
self.parameters = parameters
# Open the database and load the objects to provide via the outputs
self.db = bob.db.putvein.Database()
self.objs = sorted(self.db.objects(protocol=self.parameters['protocol'],
purposes=self.parameters.get('purpose', None),
groups=[self.parameters['group']],
kinds=[self.parameters['kind']]
),
kinds=[self.parameters['kind']]),
key=lambda x: x.client_id)
self.next_index = 0
self.force_start_index = force_start_index
self.force_end_index = force_end_index
# Determine the range of indices that must be provided
self.start_index = force_start_index if force_start_index is not None else 0
self.end_index = force_end_index if force_end_index is not None else len(self.objs) - 1
# Retrieve only 'useful' data
# End index
if self.force_end_index is not None:
self.objs = self.objs[:self.force_end_index+1]
# Start index
if self.force_start_index is not None:
self.objs = self.objs[self.force_start_index:]
self.next_index = self.force_start_index
else:
self.force_start_index = 0
self.objs = self.objs[self.start_index : self.end_index + 1]
self.previous_client_id = -1
self.next_index = self.start_index
return True
def done(self):
return (self.next_index-self.force_start_index >= len(self.objs))
def done(self, last_data_index):
return last_data_index >= self.end_index
def next(self):
obj = self.objs[self.next_index-self.force_start_index]
# Please pay attention - we can use CLIENT ID as long as we don't use
# same client both hand data in the same protocol. Currently this
# doesn't happen.
if self.outputs['client_id'].isConnected():
obj = self.objs[self.next_index - self.start_index]
# Output: client_id (only provide data when the client_id change)
if self.outputs['client_id'].isConnected() and \
self.outputs['client_id'].last_written_data_index < self.next_index:
# Please pay attention - we can use the object CLIENT ID as long as we
# don't use both hands data of the same client in the same protocol.
client_id = obj.get_client_id()
if client_id != self.previous_client_id:
# Search for the end index
end_index = self.next_index
while end_index + 1 - self.force_start_index < len(self.objs):
end_index += 1
obj2 = self.objs[end_index - self.force_start_index]
if obj2.get_client_id() != client_id:
break
client_end_index = get_client_end_index(self.objs, client_id,
self.next_index,
self.start_index,
self.end_index)
self.outputs['client_id'].write({'value':
np.uint64(client_id)},
end_index)
self.previous_client_id = client_id
self.outputs['client_id'].write(
{
'value': np.uint64(client_id)
},
client_end_index
)
# Output: image (provide data at each iteration)
if self.outputs['image'].isConnected():
"""
The image returned by the ``bob.db.putvein`` is RGB (with shape
(3, 768, 1024)). This method converts image to a greyscale
(3, 768, 1024)). This method converts image to a grayscale
(shape (768, 1024)) and then rotates image by 270 deg so that
images can be used with ``bob.bio.vein`` algorythms designed for
images can be used with ``bob.bio.vein`` algorithms designed for
the ``bob.db.biowave_v1`` database.
Output images dimentions - (1024, 768).
Output images dimensions: (1024, 768).
"""
color_image = obj.load(self.root_folder)
grayscale_image = bob.ip.color.rgb_to_gray(color_image)
grayscale_image = np.rot90(grayscale_image, k=3)
data = {
'value': grayscale_image
}
self.outputs['image'].write(data, self.next_index)
self.outputs['image'].write(
{
'value': grayscale_image
},
self.next_index
)
self.next_index += 1
# Determine the next data index that must be provided
self.next_index = 1 + min([ x.last_written_data_index for x in self.outputs
if x.isConnected() ]
)
return True
#----------------------------------------------------------
class TemplateView:
# Reasoning: Each client may have a number of models in certain databases.
# So, each model receives an unique identifier. Those identifiers are
# linked to the client identifier and contain a number of images to
# generated the model from.
def setup(self,
root_folder,
outputs,
parameters,
force_start_index=None,
"""Outputs:
- image: "{{ system_user.username }}/array_2d_uint8/1"
- client_id: "{{ system_user.username }}/uint64/1"
- model_id: "{{ system_user.username }}/text/1"
Several "image" are associated with a given "client_id".
Several "client_id" are associated with a given "model_id".
--------- --------- --------- --------- --------- ---------
| image | | image | | image | | image | | image | | image |
--------- --------- --------- --------- --------- ---------
----------------------------- -----------------------------
| client_id | | client_id |
----------------------------- -----------------------------
-----------------------------------------------------------
| model_id |
-----------------------------------------------------------
Note: for this particular database, there is only one "client_id"
per "model_id".
"""
def setup(self, root_folder, outputs, parameters, force_start_index=None,
force_end_index=None):
# Initialisations
self.root_folder = os.path.join(root_folder, '')
self.outputs = outputs
self.parameters = parameters
# Open the database and load the objects to provide via the outputs
self.db = bob.db.putvein.Database()
self.template_ids = \
sorted(self.db.model_ids(protocol=self.parameters['protocol'],
groups=[self.parameters['group']],
kinds=[self.parameters['kind']]
))
model_ids = sorted(self.db.model_ids(protocol=self.parameters['protocol'],
groups=[self.parameters['group']],
kinds=[self.parameters['kind']]),
key=lambda x: int(x))
self.objs = []
self.objs = None
for model_id in model_ids:
objs = sorted(self.db.objects(protocol=self.parameters['protocol'],
purposes=self.parameters.get('purpose', None),
groups=[self.parameters['group']],
kinds=[self.parameters['kind']],
model_ids=[model_id]),
key=lambda x: x.client_id)
self.current_template_index = 0
self.current_obj_index = 0
self.next_index = 0
self.objs.extend([ (model_id, obj) for obj in objs ])
self.force_start_index = force_start_index
self.force_end_index = force_end_index
# We don't know how many objects we will have, so we can't operate with
# self.force_end_index.
if self.force_start_index is None:
self.force_start_index = 0
# Determine the range of indices that must be provided
self.start_index = force_start_index if force_start_index is not None else 0
self.end_index = force_end_index if force_end_index is not None else len(self.objs) - 1
# Example taxen from atnt/3 - idea to iterate through ``next`` method
# to get indexes right
while self.next_index < self.force_start_index:
self.next()
self.objs = self.objs[self.start_index : self.end_index + 1]
self.next_index = self.start_index
return True
def done(self):
# return (self.next_index-self.force_start_index >= len(self.objs))
return ((self.current_template_index >= len(self.template_ids)) or
(self.force_end_index is not None and
self.force_end_index < self.next_index))
def done(self, last_data_index):
return last_data_index >= self.end_index
def next(self):
if self.objs is None:
# probe for the specific objects concerning a given client
self.objs = \
sorted(self.db.objects(protocol=self.parameters['protocol'],
purposes=self.parameters.get('purpose',
None),
groups=[self.parameters['group']],
kinds=[self.parameters['kind']],
model_ids=[self.template_ids[self.current_template_index]]),
key=lambda x: x.id)
if (self.force_start_index <= self.next_index and
(self.force_end_index is None or
self.force_end_index >= self.next_index)):
if self.outputs['model_id'].isConnected():
# for this database the
# ``self.template_ids[self.current_template_index]`` is the
# ``model_id``:
model_id = self.template_ids[self.current_template_index]
model_id = model_id.encode('utf-8')
self.outputs['model_id'].write({'text':
model_id},
self.next_index+len(self.objs)-1)
if self.outputs['client_id'].isConnected():
client_id = self.objs[0].get_client_id()
self.outputs['client_id'].write({'value':
np.uint64(client_id)},
self.next_index+len(self.objs)-1)
obj = self.objs[self.current_obj_index]
(model_id, obj) = self.objs[self.next_index - self.start_index]
# Output: model_id (only provide data when the model_id change)
if self.outputs['model_id'].isConnected() and \
self.outputs['model_id'].last_written_data_index < self.next_index:
model_end_index = get_model_end_index(self.objs, model_id,
self.next_index,
self.start_index,
self.end_index)
self.outputs['model_id'].write(
{
'text': model_id.encode('utf-8')
},
model_end_index
)
# Output: client_id (only provide data when the client_id change)
if self.outputs['client_id'].isConnected() and \
self.outputs['client_id'].last_written_data_index < self.next_index:
# Please pay attention - we can use the object CLIENT ID as long as we
# don't use both hands data of the same client in the same protocol.
client_id = obj.get_client_id()
client_end_index = get_client_end_index(self.objs, client_id,
self.next_index,
self.start_index,
self.end_index)
self.outputs['client_id'].write(
{
'value': np.uint64(client_id)
},
client_end_index
)
# Output: image (provide data at each iteration)
if self.outputs['image'].isConnected():
# This line is taken from the "atnt/3" database's view. It seams
# that this could cause problems, if model has N images, but
# ``self.force_end_index`` /= N*M, where M -- Natural number:
if (self.force_start_index <= self.next_index and
(self.force_end_index is None or
self.force_end_index >= self.next_index)):
# No need to test if output is connected:
# if self.outputs['image'].isConnected():
"""
The image returned by the ``bob.db.putvein`` is RGB (with shape
(3, 768, 1024)). This method converts image to a greyscale
(shape (768, 1024)) and then rotates image by 270 deg so that
images can be used with ``bob.bio.vein`` algorythms designed
for the ``bob.db.biowave_v1`` database.
Output images dimentions - (1024, 768).
"""
color_image = obj.load(self.root_folder)
grayscale_image = bob.ip.color.rgb_to_gray(color_image)
grayscale_image = np.rot90(grayscale_image, k=3)
data = {
'value': grayscale_image
}
self.outputs['image'].write(data, self.next_index)
"""
The image returned by the ``bob.db.putvein`` is RGB (with shape
(3, 768, 1024)). This method converts image to a grayscale
(shape (768, 1024)) and then rotates image by 270 deg so that
images can be used with ``bob.bio.vein`` algorithms designed for
the ``bob.db.biowave_v1`` database.
Output images dimensions: (1024, 768).
"""
color_image = obj.load(self.root_folder)
grayscale_image = bob.ip.color.rgb_to_gray(color_image)
grayscale_image = np.rot90(grayscale_image, k=3)
self.next_index += 1
self.current_obj_index += 1
self.outputs['image'].write(
{
'value': grayscale_image
},
self.next_index
)
else:
self.next_index += len(self.objs)
self.current_obj_index = len(self.objs)
if self.current_obj_index == len(self.objs):
self.objs = None
self.current_obj_index = 0
self.current_template_index += 1
# Determine the next data index that must be provided
self.next_index = 1 + min([ x.last_written_data_index for x in self.outputs
if x.isConnected() ]
)
return True
#----------------------------------------------------------
# Test the behavior of the views (on fake data)
if __name__ == '__main__':
# Install a mock load method for the images
def mock_load(obj, root_folder):
return np.ndarray((3, 768, 1024), dtype=np.uint8)
bob.db.putvein.models.File.load = mock_load
# Mock output class
class MockOutput:
def __init__(self, name, connected):
self.name = name
self.connected = connected
self.last_written_data_index = -1
self.written_data = []
def write(self, data, end_data_index):
self.written_data.append(( self.last_written_data_index + 1, end_data_index, data ))
self.last_written_data_index = end_data_index
def isConnected(self):
return self.connected
# Tester utility class
from beat.backend.python.outputs import OutputList
import itertools
class Tester:
def __init__(self, name, view_class, outputs_declaration, parameters):
self.name = name
self.view_class = view_class
self.outputs_declaration = {}
self.parameters = parameters
self.determine_increments(outputs_declaration)
for L in range(0, len(self.outputs_declaration) + 1):
for subset in itertools.combinations(self.outputs_declaration.keys(), L):
self.run(subset)
def determine_increments(self, outputs_declaration):
outputs = OutputList()
for name in outputs_declaration:
outputs.add(MockOutput(name, True))
view = self.view_class()
view.setup('', outputs, self.parameters)
view.next()
print "View '%s', increments found:" % self.name
for output in outputs:
self.outputs_declaration[output.name] = output.last_written_data_index + 1
print ' - %s: %d' % (output.name, output.last_written_data_index + 1)
def run(self, connected_outputs):
if len(connected_outputs) == 0:
return
print "Testing '%s', with %d output(s): %s" % (self.name, len(connected_outputs),
', '.join(connected_outputs))
connected_outputs = dict([ x for x in self.outputs_declaration.items()
if x[0] in connected_outputs ])
not_connected_outputs = dict([ x for x in self.outputs_declaration.items()
if x[0] not in connected_outputs ])
outputs = OutputList()
for name in self.outputs_declaration.keys():
outputs.add(MockOutput(name, name in connected_outputs))
parameters = dict()
view = self.view_class()
view.setup('', outputs, self.parameters)
biggest_index_increment = max(connected_outputs.values())
biggest_index_output = [ x[0] for x in connected_outputs.items()
if x[1] == biggest_index_increment ][0]
lowest_index_increment = min(connected_outputs.values())
lowest_index_output = [ x[0] for x in connected_outputs.items()
if x[1] == lowest_index_increment ][0]
next_expected_indices = {}
for name, increment in connected_outputs.items():
next_expected_indices[name] = 0
current_index = 0
while not(view.done(outputs[lowest_index_output].last_written_data_index)):
view.next()
for name in connected_outputs.keys():
assert(outputs[name].written_data[-1][0] == next_expected_indices[name])
assert(outputs[name].written_data[-1][1] == next_expected_indices[name] + connected_outputs[name] - 1)
for name in not_connected_outputs.keys():
assert(len(outputs[name].written_data) == 0)
current_index += lowest_index_increment
for name in connected_outputs.keys():
if current_index == next_expected_indices[name] + connected_outputs[name]:
next_expected_indices[name] += connected_outputs[name]
for name in connected_outputs.keys():
assert(len(outputs[name].written_data) == current_index / connected_outputs[name])
"""
# test bob.db.putvein object compatibility with BEAT
import bob.db.putvein
db = bob.db.putvein.Database()
files = db.objects(protocol="RL_4", kinds="wrist", groups="dev", purposes="probe")
objects = sorted(files, key=lambda x: x.client_id)
for nr in range(len(files)):
print("{}-{}, sorted - {}-{}".format(files[nr].client_id, files[nr].nr, objects[nr].client_id, objects[nr].nr))
# The actual tests
Tester('View', View,
[
'client_id',
'image',
],
dict(
protocol = 'LR_4',
kind = 'wrist',
group = 'dev',
purpose = 'probe',
)
)
print(files == objects)
"""
Tester('TemplateView', TemplateView,
[
'model_id',
'client_id',
'image',
],
dict(
protocol = 'LR_4',
kind = 'wrist',
group = 'dev',
purpose = 'enroll',
)
)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment