Skip to content
Snippets Groups Projects
Commit 1d4aca34 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Merge branch 'db-fixes' into 'master'

Small fixes the database tests so it flawlessly passes in [mac] builds

See merge request !93
parents bf6eb731 d883e0f9
No related branches found
No related tags found
1 merge request!93Small fixes the database tests so it flawlessly passes in [mac] builds
Pipeline #46785 passed
...@@ -21,13 +21,6 @@ from sklearn.pipeline import make_pipeline ...@@ -21,13 +21,6 @@ from sklearn.pipeline import make_pipeline
import os import os
cache_subdir = "datasets"
filename = "meds.tar.gz"
dataset_protocol_path = os.path.join(
os.path.expanduser("~"), "bob_data", cache_subdir, filename
)
class MEDSDatabase(CSVDatasetZTNorm): class MEDSDatabase(CSVDatasetZTNorm):
""" """
The MEDS II database was developed by NIST to support and assists their biometrics evaluation program. The MEDS II database was developed by NIST to support and assists their biometrics evaluation program.
...@@ -103,17 +96,14 @@ class MEDSDatabase(CSVDatasetZTNorm): ...@@ -103,17 +96,14 @@ class MEDSDatabase(CSVDatasetZTNorm):
def __init__(self, protocol): def __init__(self, protocol):
# Downloading model if not exists # Downloading model if not exists
urls = [ urls = MEDSDatabase.urls()
"https://www.idiap.ch/software/bob/databases/latest/meds.tar.gz", filename = get_file("meds.tar.gz", urls)
"http://www.idiap.ch/software/bob/databases/latest/meds.tar.gz",
]
get_file(filename, urls)
self.annotation_type = "eyes-center" self.annotation_type = "eyes-center"
self.fixed_positions = None self.fixed_positions = None
database = CSVDataset( database = CSVDataset(
dataset_protocol_path, filename,
protocol, protocol,
csv_to_sample_loader=make_pipeline( csv_to_sample_loader=make_pipeline(
CSVToSampleLoader( CSVToSampleLoader(
...@@ -128,3 +118,10 @@ class MEDSDatabase(CSVDatasetZTNorm): ...@@ -128,3 +118,10 @@ class MEDSDatabase(CSVDatasetZTNorm):
) )
super().__init__(database) super().__init__(database)
@staticmethod
def urls():
return [
"https://www.idiap.ch/software/bob/databases/latest/meds.tar.gz",
"http://www.idiap.ch/software/bob/databases/latest/meds.tar.gz",
]
...@@ -51,10 +51,7 @@ class MobioDatabase(CSVDatasetZTNorm): ...@@ -51,10 +51,7 @@ class MobioDatabase(CSVDatasetZTNorm):
def __init__(self, protocol): def __init__(self, protocol):
# Downloading model if not exists # Downloading model if not exists
urls = [ urls = MobioDatabase.urls()
"https://www.idiap.ch/software/bob/databases/latest/mobio.tar.gz",
"http://www.idiap.ch/software/bob/databases/latest/mobio.tar.gz",
]
filename = get_file("mobio.tar.gz", urls) filename = get_file("mobio.tar.gz", urls)
self.annotation_type = "eyes-center" self.annotation_type = "eyes-center"
...@@ -77,9 +74,6 @@ class MobioDatabase(CSVDatasetZTNorm): ...@@ -77,9 +74,6 @@ class MobioDatabase(CSVDatasetZTNorm):
super().__init__(database) super().__init__(database)
# def zprobes(self, proportion=0.20):
# return super().zprobes(proportion=proportion)
@staticmethod @staticmethod
def protocols(): def protocols():
# TODO: Until we have (if we have) a function that dumps the protocols, let's use this one. # TODO: Until we have (if we have) a function that dumps the protocols, let's use this one.
...@@ -94,3 +88,10 @@ class MobioDatabase(CSVDatasetZTNorm): ...@@ -94,3 +88,10 @@ class MobioDatabase(CSVDatasetZTNorm):
"mobile0-male", "mobile0-male",
"mobile1-female", "mobile1-female",
] ]
@staticmethod
def urls():
return [
"https://www.idiap.ch/software/bob/databases/latest/mobio.tar.gz",
"http://www.idiap.ch/software/bob/databases/latest/mobio.tar.gz",
]
...@@ -60,10 +60,7 @@ class MorphDatabase(CSVDatasetZTNorm): ...@@ -60,10 +60,7 @@ class MorphDatabase(CSVDatasetZTNorm):
def __init__(self, protocol): def __init__(self, protocol):
# Downloading model if not exists # Downloading model if not exists
urls = [ urls = MorphDatabase.urls()
"https://www.idiap.ch/software/bob/databases/latest/morph.tar.gz",
"http://www.idiap.ch/software/bob/databases/latest/morph.tar.gz",
]
filename = get_file("morph.tar.gz", urls) filename = get_file("morph.tar.gz", urls)
self.annotation_type = "eyes-center" self.annotation_type = "eyes-center"
...@@ -85,3 +82,10 @@ class MorphDatabase(CSVDatasetZTNorm): ...@@ -85,3 +82,10 @@ class MorphDatabase(CSVDatasetZTNorm):
) )
super().__init__(database) super().__init__(database)
@staticmethod
def urls():
return [
"https://www.idiap.ch/software/bob/databases/latest/morph.tar.gz",
"http://www.idiap.ch/software/bob/databases/latest/morph.tar.gz",
]
...@@ -23,10 +23,7 @@ class MultipieDatabase(CSVDataset): ...@@ -23,10 +23,7 @@ class MultipieDatabase(CSVDataset):
def __init__(self, protocol): def __init__(self, protocol):
# Downloading model if not exists # Downloading model if not exists
urls = [ urls = MultipieDatabase.urls()
"https://www.idiap.ch/software/bob/databases/latest/multipie.tar.gz",
"http://www.idiap.ch/software/bob/databases/latest/multipie.tar.gz",
]
filename = get_file("multipie.tar.gz", urls) filename = get_file("multipie.tar.gz", urls)
self.annotation_type = ["eyes-center", "left-profile", "right-profile"] self.annotation_type = ["eyes-center", "left-profile", "right-profile"]
...@@ -72,3 +69,10 @@ class MultipieDatabase(CSVDataset): ...@@ -72,3 +69,10 @@ class MultipieDatabase(CSVDataset):
"P081", "P081",
"P090", "P090",
] ]
@staticmethod
def urls():
return [
"https://www.idiap.ch/software/bob/databases/latest/multipie.tar.gz",
"http://www.idiap.ch/software/bob/databases/latest/multipie.tar.gz",
]
...@@ -27,6 +27,7 @@ from bob.bio.base.test.test_database_implementations import ( ...@@ -27,6 +27,7 @@ from bob.bio.base.test.test_database_implementations import (
check_database_zt, check_database_zt,
) )
import bob.core import bob.core
from bob.extension.download import get_file
logger = bob.core.log.setup("bob.bio.face") logger = bob.core.log.setup("bob.bio.face")
...@@ -159,6 +160,16 @@ def test_lfw(): ...@@ -159,6 +160,16 @@ def test_lfw():
def test_mobio(): def test_mobio():
from bob.bio.face.database import MobioDatabase from bob.bio.face.database import MobioDatabase
# Getting the absolute path
urls = MobioDatabase.urls()
filename = get_file("mobio.tar.gz", urls)
# Removing the file before the test
try:
os.remove(filename)
except:
pass
protocols = MobioDatabase.protocols() protocols = MobioDatabase.protocols()
for p in protocols: for p in protocols:
database = MobioDatabase(protocol=p) database = MobioDatabase(protocol=p)
...@@ -188,6 +199,16 @@ def test_mobio(): ...@@ -188,6 +199,16 @@ def test_mobio():
def test_multipie(): def test_multipie():
from bob.bio.face.database import MultipieDatabase from bob.bio.face.database import MultipieDatabase
# Getting the absolute path
urls = MultipieDatabase.urls()
filename = get_file("multipie.tar.gz", urls)
# Removing the file before the test
try:
os.remove(filename)
except:
pass
protocols = MultipieDatabase.protocols() protocols = MultipieDatabase.protocols()
for p in protocols: for p in protocols:
...@@ -324,17 +345,54 @@ def test_fargo(): ...@@ -324,17 +345,54 @@ def test_fargo():
def test_meds(): def test_meds():
from bob.bio.face.database import MEDSDatabase from bob.bio.face.database import MEDSDatabase
# Getting the absolute path
urls = MEDSDatabase.urls()
filename = get_file("meds.tar.gz", urls)
# Removing the file before the test
try:
os.remove(filename)
except:
pass
database = MEDSDatabase("verification_fold1") database = MEDSDatabase("verification_fold1")
assert len(database.background_model_samples()) == 234 assert len(database.background_model_samples()) == 234
assert len(database.references()) == 111 assert len(database.references()) == 111
assert len(database.probes()) == 313 assert len(database.probes()) == 313
assert len(database.references(group="dev"))
assert len(database.zprobes()) == 80 assert len(database.zprobes()) == 80
assert len(database.treferences()) == 80 assert len(database.treferences()) == 80
assert len(database.references(group="eval")) == 112 assert len(database.references(group="eval")) == 112
assert len(database.probes(group="eval")) == 309 assert len(database.probes(group="eval")) == 309
def test_morph():
from bob.bio.face.database import MorphDatabase
# Getting the absolute path
urls = MorphDatabase.urls()
filename = get_file("morph.tar.gz", urls)
# Removing the file before the test
try:
os.remove(filename)
except:
pass
database = MorphDatabase("verification_fold1")
assert len(database.background_model_samples()) == 226
assert len(database.references()) == 6738
assert len(database.probes()) == 6557
assert len(database.zprobes()) == 66
assert len(database.treferences()) == 69
assert len(database.references(group="eval")) == 6742
assert len(database.probes(group="eval")) == 6553
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment