Skip to content
Snippets Groups Projects
Commit 1c9834fc authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Merge branch 'group-probes-by-reference-id' into 'master'

[CSVDataset] Expose option to group the probes in SampleSets by reference id

See merge request !262
parents e5c67e7d 892d204f
No related branches found
No related tags found
1 merge request!262[CSVDataset] Expose option to group the probes in SampleSets by reference id
Pipeline #53249 passed
......@@ -235,6 +235,7 @@ class CSVDataset(Database):
csv_to_sample_loader=None,
is_sparse=False,
allow_scoring_with_all_biometric_references=False,
group_probes_by_reference_id=False,
**kwargs,
):
super().__init__(
......@@ -245,6 +246,7 @@ class CSVDataset(Database):
)
self.dataset_protocol_path = dataset_protocol_path
self.is_sparse = is_sparse
self.group_probes_by_reference_id = group_probes_by_reference_id
if csv_to_sample_loader is None:
csv_to_sample_loader = CSVToSampleLoaderBiometrics(
data_loader=bob.io.base.load,
......@@ -411,7 +413,7 @@ class CSVDataset(Database):
return self._get_samplesets(
group=group,
cache_key=cache_key,
group_by_reference_id=False,
group_by_reference_id=self.group_probes_by_reference_id,
fetching_probes=True,
is_sparse=self.is_sparse,
)
......@@ -563,7 +565,7 @@ class CSVDatasetZTNorm(CSVDataset):
samplesets = self._get_samplesets(
group=group,
cache_key=cache_key,
group_by_reference_id=False,
group_by_reference_id=self.group_probes_by_reference_id,
fetching_probes=True,
is_sparse=False,
)
......@@ -645,6 +647,7 @@ class CSVDatasetCrossValidation(Database):
samples_for_enrollment=1,
csv_to_sample_loader=None,
allow_scoring_with_all_biometric_references=True,
group_probes_by_reference_id=False,
**kwargs,
):
super().__init__(
......@@ -673,6 +676,7 @@ class CSVDatasetCrossValidation(Database):
self.csv_file_name = open(csv_file_name)
self.samples_for_enrollment = samples_for_enrollment
self.test_size = test_size
self.group_probes_by_reference_id = group_probes_by_reference_id
if self.test_size < 0 and self.test_size > 1:
raise ValueError(
......@@ -721,7 +725,7 @@ class CSVDatasetCrossValidation(Database):
self.cache["dev_probe_csv"] += convert_samples_to_samplesets(
samples[self.samples_for_enrollment :],
group_by_reference_id=False,
group_by_reference_id=self.group_probes_by_reference_id,
references=reference_ids[n_samples_for_training:],
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment