Commit d3db6d93 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Fixed Heterogeneous protocols issue #2

parent d4bad6be
......@@ -229,11 +229,19 @@ def add_protocols(session):
if purpose == 'train':
q = q.filter(and_(File.light == 'controlled', File.pose == 'frontal'))
# now get the right modality
if modality is not None:
q = q.filter(File.modality == modality)
else:
q = q.filter(File.modality == target_modality)
# if HETEROGENEOUS get 2 modalities for the WORLD set
if "rgb2nir" in p.name or "rgb2depth" in p.name:
if modality is not None:
q = q.filter(File.modality.in_([modality,"rgb"]))
else:
q = q.filter(File.modality.in_(["rgb", target_modality]))
else:
if modality is not None:
q = q.filter(File.modality.in_([modality]))
else:
q = q.filter(File.modality.in_([target_modality]))
# for enroll, we have controlled frontal images, for the first recording for each device
if purpose == 'enroll':
q = q.filter(and_(File.light == 'controlled', File.pose == 'frontal', File.recording.in_(recordings_enroll)))
......
......@@ -7,6 +7,7 @@
"""A few checks on the protocols of the FARGO public database
"""
import os, sys
import bob.db.base
import bob.db.fargo
def db_available(test):
......@@ -90,8 +91,9 @@ def test_heterogeneous():
probe_modalities = ["nir", "depth"]
for p, m in zip(protocols, probe_modalities):
assert len(db.objects(protocol=p)) == 3000
assert len(db.objects(protocol=p, groups="world")) == 1000
assert len(db.objects(protocol=p)) == 4000
assert len(db.objects(protocol=p, groups="world")) == 2000
assert len(db.objects(protocol=p, groups="world", modality=m)) == 1000
for g in groups:
assert len(db.objects(protocol=p, groups="dev")) == 1000
......@@ -113,10 +115,11 @@ def test_heterogeneous():
"uo-rgb2nir", "uo-rgb2depth"]
probe_modalities = ["nir", "depth",
"nir", "depth"]
for p, m in zip(protocols, probe_modalities):
assert len(db.objects(protocol=p)) == 4000
assert len(db.objects(protocol=p, groups="world")) == 1000
assert len(db.objects(protocol=p)) == 5000
assert len(db.objects(protocol=p, groups="world")) == 2000
assert len(db.objects(protocol=p, groups="world", modality=m)) == 1000
for g in groups:
assert len(db.objects(protocol=p, groups="dev")) == 1500
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment