Commit a6d4f5f9 authored by Yannick DAYER's avatar Yannick DAYER
Browse files

Added tests for cross_validation all_samples

Fixed a cache issue.
prevent eval when not present.
parent 9fc33fe4
Pipeline #46021 passed with stage
in 7 minutes and 49 seconds
......@@ -10,7 +10,9 @@ import functools
from abc import ABCMeta, abstractmethod
import numpy as np
import itertools
import logging
logger = logging.getLogger(__name__)
class CSVBaseSampleLoader(metaclass=ABCMeta):
"""
......@@ -340,10 +342,13 @@ class CSVDatasetDevEval:
# Get enroll and probe samples
groups = ["dev", "eval"] if not groups else groups
if "eval" in groups and (not self.eval_enroll_csv or not self.eval_probe_csv):
logger.info("'eval' requested, but dataset has no 'eval' group.")
groups.remove("eval")
for group in groups:
for purpose in ("enroll", "probe"):
label = f"{group}_{purpose}_csv"
samples.append(self.csv_to_sample_loader(self.__dict__[label]))
samples = samples + self.csv_to_sample_loader(self.__dict__[label])
return samples
......@@ -489,11 +494,13 @@ class CSVDatasetCrossValidation:
samples = self.background_model_samples()
# Get enroll and probe samples
groups = ["dev", "eval"] if not groups else groups
groups = ["dev"] if not groups else groups
if "eval" in groups:
logger.info("'eval' requested but there is no 'eval' group defined.")
groups.remove("eval")
for group in groups:
for purpose in ("enroll", "probe"):
label = f"{group}_{purpose}_csv"
samples.append(self.csv_to_sample_loader(self.__dict__[label]))
samples = samples+ [s for s_set in self.references(group) for s in s_set]
samples = samples+ [s for s_set in self.probes(group) for s in s_set]
return samples
......
......@@ -107,6 +107,9 @@ def test_csv_file_list_dev_eval():
assert len(dataset.probes(group="eval")) == 13
assert check_all_true(dataset.probes(group="eval"), SampleSet)
assert len(dataset.all_samples(groups=None)) == 49
assert check_all_true(dataset.all_samples(groups=None), DelayedSample)
def test_csv_file_list_atnt():
......@@ -114,6 +117,32 @@ def test_csv_file_list_atnt():
assert len(dataset.background_model_samples()) == 200
assert len(dataset.references()) == 20
assert len(dataset.probes()) == 100
assert len(dataset.all_samples(groups=["dev"])) == 400
assert len(dataset.all_samples(groups=None)) == 400
def data_loader(path):
import bob.io.image
return bob.io.base.load(path)
def test_csv_cross_validation_atnt():
dataset = CSVDatasetCrossValidation(
csv_file_name=atnt_protocol_path_cross_validation,
random_state=0,
test_size=0.8,
csv_to_sample_loader=CSVToSampleLoader(
data_loader=data_loader,
dataset_original_directory=atnt_database_directory(),
extension=".pgm",
),
)
assert len(dataset.background_model_samples()) == 80
assert len(dataset.references("dev")) == 32
assert len(dataset.probes("dev")) == 288
assert len(dataset.all_samples(groups=None)) == 400
def run_experiment(dataset):
......@@ -131,12 +160,6 @@ def run_experiment(dataset):
)
def data_loader(path):
import bob.io.image
return bob.io.base.load(path)
def test_atnt_experiment():
dataset = CSVDatasetDevEval(
......@@ -160,7 +183,7 @@ def test_atnt_experiment_cross_validation():
total_identities = 40
samples_for_enrollment = 1
def run_cross_validataion_experiment(test_size=0.9):
def run_cross_validation_experiment(test_size=0.9):
dataset = CSVDatasetCrossValidation(
csv_file_name=atnt_protocol_path_cross_validation,
random_state=0,
......@@ -179,9 +202,9 @@ def test_atnt_experiment_cross_validation():
* (samples_per_identity - samples_for_enrollment)
)
run_cross_validataion_experiment(test_size=0.9)
run_cross_validataion_experiment(test_size=0.8)
run_cross_validataion_experiment(test_size=0.5)
run_cross_validation_experiment(test_size=0.9)
run_cross_validation_experiment(test_size=0.8)
run_cross_validation_experiment(test_size=0.5)
####
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment