Commit 14a58698 authored by Tiago de Freitas Pereira

Make the databases work transparently with either tarballs or csv files

parent 31e99d38
Pipeline #46251 passed with stage
in 5 minutes and 41 seconds
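In practice, the same constructor call now accepts either an extracted protocol directory or the corresponding tarball. A minimal usage sketch (the import path and the "atnt" names are assumptions for illustration, not taken from this commit):

    from bob.bio.base.database import CSVDatasetDevEval  # import path assumed

    # Both forms are expected to behave identically after this change:
    dataset = CSVDatasetDevEval("atnt_protocols", "protocol_dev_eval")         # directory of csv/lst files
    dataset = CSVDatasetDevEval("atnt_protocols.tar.gz", "protocol_dev_eval")  # the same protocols, packed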
@@ -13,7 +13,7 @@ import numpy as np
 import itertools
 import logging
 import bob.db.base
+from bob.extension.download import find_element_in_tarball
 from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import Database

 logger = logging.getLogger(__name__)
@@ -156,14 +156,13 @@ class CSVToSampleLoader(CSVBaseSampleLoader):
         if not "path" in header:
             raise ValueError("The field `path` is not available in your dataset.")

-    def __call__(self, filename):
-        with open(filename) as cf:
-            reader = csv.reader(cf)
-            header = next(reader)
+    def __call__(self, f):
+        f.seek(0)
+        reader = csv.reader(f)
+        header = next(reader)

-            self.check_header(header)
-            return [self.convert_row_to_sample(row, header) for row in reader]
+        self.check_header(header)
+        return [self.convert_row_to_sample(row, header) for row in reader]

     def convert_row_to_sample(self, row, header):
         path = row[0]
@@ -192,17 +191,16 @@ class LSTToSampleLoader(CSVBaseSampleLoader):
         :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
     """

-    def __call__(self, filename):
-        with open(filename) as cf:
-            reader = csv.reader(cf, delimiter=" ")
-            samples = []
-            for row in reader:
-                if row[0][0] == "#":
-                    continue
-                samples.append(self.convert_row_to_sample(row))
+    def __call__(self, f):
+        f.seek(0)
+        reader = csv.reader(f, delimiter=" ")
+        samples = []
+        for row in reader:
+            if row[0][0] == "#":
+                continue
+            samples.append(self.convert_row_to_sample(row))

-            return samples
+        return samples

     def convert_row_to_sample(self, row, header=None):
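Both sample loaders now receive an already-open, seekable stream instead of a filename: `open()` on a plain csv/lst file and the tarball lookup below both satisfy that contract, hence the `f.seek(0)` rewind. A minimal sketch of the contract, standard library only (`load_rows` is a hypothetical helper, not part of the package):

    import csv
    import io

    def load_rows(f):
        f.seek(0)  # the stream may already have been read once; rewind before parsing
        return list(csv.reader(f))

    # An in-memory stream stands in for either a file opened from disk or a
    # member extracted from a tarball; both are seekable text streams:
    rows = load_rows(io.StringIO("path,reference_id\ndata/s1_1,s1\n"))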
@@ -333,57 +331,61 @@ class CSVDatasetDevEval(Database):
         if not os.path.exists(dataset_protocol_path):
             raise ValueError(f"The path `{dataset_protocol_path}` was not found")

-        # TODO: Unzip file if dataset path is a zip
-        protocol_path = os.path.join(dataset_protocol_path, protocol_name)
-        if not os.path.exists(protocol_path):
-            raise ValueError(f"The protocol `{protocol_name}` was not found")

         def path_discovery(option1, option2):
-            return option1 if os.path.exists(option1) else option2
+            # If the input is a directory
+            if os.path.isdir(dataset_protocol_path):
+                option1 = os.path.join(dataset_protocol_path, option1)
+                option2 = os.path.join(dataset_protocol_path, option2)
+                if os.path.exists(option1):
+                    return open(option1)
+                else:
+                    return open(option2) if os.path.exists(option2) else None
+
+            # If it's not a directory, it's a tarball
+            op1 = find_element_in_tarball(dataset_protocol_path, option1)
+            return (
+                op1
+                if op1
+                else find_element_in_tarball(dataset_protocol_path, option2)
+            )

         # Here we are handling the legacy list files
         train_csv = path_discovery(
-            os.path.join(protocol_path, "norm", "train_world.lst"),
-            os.path.join(protocol_path, "norm", "train_world.csv"),
+            os.path.join(protocol_name, "norm", "train_world.lst"),
+            os.path.join(protocol_name, "norm", "train_world.csv"),
         )

         dev_enroll_csv = path_discovery(
-            os.path.join(protocol_path, "dev", "for_models.lst"),
-            os.path.join(protocol_path, "dev", "for_models.csv"),
+            os.path.join(protocol_name, "dev", "for_models.lst"),
+            os.path.join(protocol_name, "dev", "for_models.csv"),
         )

         legacy_probe = "for_scores.lst" if self.is_sparse else "for_probes.lst"
         dev_probe_csv = path_discovery(
-            os.path.join(protocol_path, "dev", legacy_probe),
-            os.path.join(protocol_path, "dev", "for_probes.csv"),
+            os.path.join(protocol_name, "dev", legacy_probe),
+            os.path.join(protocol_name, "dev", "for_probes.csv"),
        )

         eval_enroll_csv = path_discovery(
-            os.path.join(protocol_path, "eval", "for_models.lst"),
-            os.path.join(protocol_path, "eval", "for_models.csv"),
+            os.path.join(protocol_name, "eval", "for_models.lst"),
+            os.path.join(protocol_name, "eval", "for_models.csv"),
         )

         eval_probe_csv = path_discovery(
-            os.path.join(protocol_path, "eval", legacy_probe),
-            os.path.join(protocol_path, "eval", "for_probes.csv"),
+            os.path.join(protocol_name, "eval", legacy_probe),
+            os.path.join(protocol_name, "eval", "for_probes.csv"),
         )

         # The minimum required is to have `dev_enroll_csv` and `dev_probe_csv`
-        train_csv = train_csv if os.path.exists(train_csv) else None

         # Eval
-        eval_enroll_csv = (
-            eval_enroll_csv if os.path.exists(eval_enroll_csv) else None
-        )
-        eval_probe_csv = eval_probe_csv if os.path.exists(eval_probe_csv) else None

         # Dev
-        if not os.path.exists(dev_enroll_csv):
+        if dev_enroll_csv is None:
             raise ValueError(
                 f"The file `{dev_enroll_csv}` is required and it was not found"
             )
-        if not os.path.exists(dev_probe_csv):
+        if dev_probe_csv is None:
             raise ValueError(
                 f"The file `{dev_probe_csv}` is required and it was not found"
             )
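For reference, the discovery logic above can be sketched with the standard library alone, with `tarfile` standing in for `bob.extension.download.find_element_in_tarball` (`open_from_dir_or_tarball` is a hypothetical helper, for illustration only):

    import io
    import os
    import tarfile

    def open_from_dir_or_tarball(root, relpath):
        # Directory case: prefer the first candidate, return None if neither exists
        if os.path.isdir(root):
            candidate = os.path.join(root, relpath)
            return open(candidate) if os.path.exists(candidate) else None
        # Tarball case: return the matching member as a seekable text stream
        with tarfile.open(root) as tar:
            for member in tar.getmembers():
                if member.isfile() and member.name.endswith(relpath):
                    return io.StringIO(tar.extractfile(member).read().decode())
        return None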
@@ -612,7 +614,7 @@ class CSVDatasetCrossValidation:
         self.random_state = random_state
         self.cache = get_dict_cache()
         self.csv_to_sample_loader = csv_to_sample_loader
-        self.csv_file_name = csv_file_name
+        self.csv_file_name = open(csv_file_name)
         self.samples_for_enrollment = samples_for_enrollment
         self.test_size = test_size
@@ -105,46 +105,50 @@ def test_csv_file_list_dev_eval():
         )
     )

-    dataset = CSVDatasetDevEval(
-        example_dir,
-        "protocol_dev_eval",
-        csv_to_sample_loader=CSVToSampleLoader(
-            data_loader=bob.io.base.load,
-            metadata_loader=AnnotationsLoader(
-                annotation_directory=annotation_directory,
-                annotation_extension=".pos",
-                annotation_type="eyecenter",
-            ),
-            dataset_original_directory="",
-            extension="",
-        ),
-    )
-    assert len(dataset.background_model_samples()) == 8
-    assert check_all_true(dataset.background_model_samples(), DelayedSample)
+    def run(filename):
+        dataset = CSVDatasetDevEval(
+            filename,
+            "protocol_dev_eval",
+            csv_to_sample_loader=CSVToSampleLoader(
+                data_loader=bob.io.base.load,
+                metadata_loader=AnnotationsLoader(
+                    annotation_directory=annotation_directory,
+                    annotation_extension=".pos",
+                    annotation_type="eyecenter",
+                ),
+                dataset_original_directory="",
+                extension="",
+            ),
+        )
+        assert len(dataset.background_model_samples()) == 8
+        assert check_all_true(dataset.background_model_samples(), DelayedSample)

-    assert len(dataset.references()) == 2
-    assert check_all_true(dataset.references(), SampleSet)
+        assert len(dataset.references()) == 2
+        assert check_all_true(dataset.references(), SampleSet)

-    assert len(dataset.probes()) == 8
-    assert check_all_true(dataset.references(), SampleSet)
+        assert len(dataset.probes()) == 8
+        assert check_all_true(dataset.references(), SampleSet)

-    assert len(dataset.references(group="eval")) == 6
-    assert check_all_true(dataset.references(group="eval"), SampleSet)
+        assert len(dataset.references(group="eval")) == 6
+        assert check_all_true(dataset.references(group="eval"), SampleSet)

-    assert len(dataset.probes(group="eval")) == 13
-    assert check_all_true(dataset.probes(group="eval"), SampleSet)
+        assert len(dataset.probes(group="eval")) == 13
+        assert check_all_true(dataset.probes(group="eval"), SampleSet)

-    assert len(dataset.all_samples(groups=None)) == 47
-    assert check_all_true(dataset.all_samples(groups=None), DelayedSample)
+        assert len(dataset.all_samples(groups=None)) == 47
+        assert check_all_true(dataset.all_samples(groups=None), DelayedSample)

-    # Check the annotations
-    for s in dataset.all_samples(groups=None):
-        assert isinstance(s.annotations, dict)
+        # Check the annotations
+        for s in dataset.all_samples(groups=None):
+            assert isinstance(s.annotations, dict)

-    assert len(dataset.reference_ids(group="dev")) == 2
-    assert len(dataset.reference_ids(group="eval")) == 6
+        assert len(dataset.reference_ids(group="dev")) == 2
+        assert len(dataset.reference_ids(group="eval")) == 6

-    assert len(dataset.groups()) == 3
+        assert len(dataset.groups()) == 3
+
+    run(example_dir)
+    run(example_dir + ".tar.gz")


 def test_csv_file_list_dev_eval_sparse():