diff --git a/bob/pad/face/database/aggregated_db.py b/bob/pad/face/database/aggregated_db.py index f1acda3abd9e5c95c7acd09238d1965e865402af..cc584490fcdd20d16f2caaa6c1c89122b8c8a67c 100644 --- a/bob/pad/face/database/aggregated_db.py +++ b/bob/pad/face/database/aggregated_db.py @@ -1,8 +1,8 @@ #!/usr/bin/env python2 # -*- coding: utf-8 -*- -#============================================================================== -from bob.pad.face.database import VideoPadFile # Used in ReplayPadFile class +# ============================================================================= +from bob.pad.face.database import VideoPadFile from bob.pad.base.database import PadDatabase @@ -19,8 +19,7 @@ from bob.bio.video import FrameSelector, FrameContainer import numpy as np - -#============================================================================== +# ============================================================================= class AggregatedDbPadFile(VideoPadFile): """ A high level implementation of the File class for the Aggregated Database @@ -76,7 +75,7 @@ class AggregatedDbPadFile(VideoPadFile): attack_type=attack_type, file_id=file_id) - #========================================================================== + # ========================================================================= def encode_file_id(self, f, n=2000): """ Return a modified version of the ``f.id`` ensuring uniqueness of the ids @@ -134,7 +133,7 @@ class AggregatedDbPadFile(VideoPadFile): return file_id - #========================================================================== + # ========================================================================= def encode_file_path(self, f): """ Append the name of the database to the end of the file path separated @@ -187,7 +186,7 @@ class AggregatedDbPadFile(VideoPadFile): return file_path - #========================================================================== + # ========================================================================= def load(self, directory=None, extension='.mov'): """ Overridden version of the load method defined in the ``VideoPadFile``. @@ -269,7 +268,7 @@ class AggregatedDbPadFile(VideoPadFile): return video_data # video data -#============================================================================== +# ============================================================================= class AggregatedDbPadDatabase(PadDatabase): """ A high level implementation of the Database class for the Aggregated Database @@ -374,7 +373,8 @@ class AggregatedDbPadDatabase(PadDatabase): # A list of available protocols: self.available_protocols = [ 'grandtest', 'photo-photo-video', 'video-video-photo', - 'grandtest-mobio', 'grandtest-train-eval', "grandtest-train-eval-<num_train_samples>"] + 'grandtest-mobio', 'grandtest-train-eval', + 'grandtest-train-eval-<num_train_samples>'] # Always use super to call parent class methods. super(AggregatedDbPadDatabase, self).__init__( @@ -384,7 +384,7 @@ class AggregatedDbPadDatabase(PadDatabase): original_extension=original_extension, **kwargs) - #========================================================================== + # ========================================================================= def get_mobio_files_given_single_group(self, groups=None, purposes=None): """ Get a list of files for the MOBIO database. All files are bona-fide @@ -442,7 +442,7 @@ class AggregatedDbPadDatabase(PadDatabase): return mobio_files - #========================================================================== + # ========================================================================= def uniform_select_list_elements(self, data, n_samples): """ Uniformly select N elements from the input data list. @@ -467,15 +467,16 @@ class AggregatedDbPadDatabase(PadDatabase): else: - uniform_step = len(data)/np.float(n_samples+1) + uniform_step = len(data) / np.float(n_samples + 1) - idxs = [int(np.round(uniform_step*(x+1))) for x in range(n_samples)] + idxs = [int(np.round(uniform_step * (x + 1))) + for x in range(n_samples)] selected_data = [data[idx] for idx in idxs] return selected_data - #========================================================================== + # ========================================================================= def get_files_given_single_group(self, groups=None, protocol=None, @@ -580,7 +581,8 @@ class AggregatedDbPadDatabase(PadDatabase): if protocol == 'photo-photo-video': - if groups == 'train' or groups == 'devel': # the group names are low-level here: ('train', 'devel', 'test') + # the group names are low-level here: ('train', 'devel', 'test') + if groups == 'train' or groups == 'devel': replay_files = self.replay_db.objects( protocol='photo', groups=groups, cls=purposes, **kwargs) @@ -618,7 +620,8 @@ class AggregatedDbPadDatabase(PadDatabase): if protocol == 'video-video-photo': - if groups == 'train' or groups == 'devel': # the group names are low-level here: ('train', 'devel', 'test') + # the group names are low-level here: ('train', 'devel', 'test') + if groups == 'train' or groups == 'devel': replay_files = self.replay_db.objects( protocol='video', groups=groups, cls=purposes, **kwargs) @@ -670,53 +673,59 @@ class AggregatedDbPadDatabase(PadDatabase): mobio_files = self.get_mobio_files_given_single_group( groups=groups, purposes=purposes) - if 'grandtest-train-eval' in protocol: + if protocol is not None: - if groups == 'train': + if 'grandtest-train-eval' in protocol: - replay_files = self.replay_db.objects( - protocol='grandtest', - groups=['train', 'devel'], - cls=purposes, - **kwargs) + if groups == 'train': - replaymobile_files = self.replaymobile_db.objects( - protocol='grandtest', - groups=['train', 'devel'], - cls=purposes, - **kwargs) + replay_files = self.replay_db.objects( + protocol='grandtest', + groups=['train', 'devel'], + cls=purposes, + **kwargs) - msu_mfsd_files = self.msu_mfsd_db.objects( - group=['train', 'devel'], cls=purposes, **kwargs) + replaymobile_files = self.replaymobile_db.objects( + protocol='grandtest', + groups=['train', 'devel'], + cls=purposes, + **kwargs) - if len(protocol) > len('grandtest-train-eval'): + msu_mfsd_files = self.msu_mfsd_db.objects( + group=['train', 'devel'], cls=purposes, **kwargs) - num_train_samples = [int(s) for s in protocol.split("-") if s.isdigit()][-1] + if len(protocol) > len('grandtest-train-eval'): - replay_files = self.uniform_select_list_elements(data = replay_files, n_samples = num_train_samples) - replaymobile_files = self.uniform_select_list_elements(data = replaymobile_files, n_samples = num_train_samples) - msu_mfsd_files = self.uniform_select_list_elements(data = msu_mfsd_files, n_samples = num_train_samples) + num_train_samples = [ + int(s) for s in protocol.split("-") if s.isdigit()][-1] - if groups in ['devel', 'test']: + replay_files = self.uniform_select_list_elements( + data=replay_files, n_samples=num_train_samples) + replaymobile_files = self.uniform_select_list_elements( + data=replaymobile_files, n_samples=num_train_samples) + msu_mfsd_files = self.uniform_select_list_elements( + data=msu_mfsd_files, n_samples=num_train_samples) - replay_files = self.replay_db.objects( - protocol='grandtest', - groups='test', - cls=purposes, - **kwargs) + if groups in ['devel', 'test']: - replaymobile_files = self.replaymobile_db.objects( - protocol='grandtest', - groups='test', - cls=purposes, - **kwargs) + replay_files = self.replay_db.objects( + protocol='grandtest', + groups='test', + cls=purposes, + **kwargs) - msu_mfsd_files = self.msu_mfsd_db.objects( - group='test', cls=purposes, **kwargs) + replaymobile_files = self.replaymobile_db.objects( + protocol='grandtest', + groups='test', + cls=purposes, + **kwargs) + + msu_mfsd_files = self.msu_mfsd_db.objects( + group='test', cls=purposes, **kwargs) return replay_files, replaymobile_files, msu_mfsd_files, mobio_files - #========================================================================== + # ========================================================================= def get_files_given_groups(self, groups=None, protocol=None, @@ -841,7 +850,7 @@ class AggregatedDbPadDatabase(PadDatabase): return replay_files, replaymobile_files, msu_mfsd_files, mobio_files - #========================================================================== + # ========================================================================= def objects(self, groups=None, protocol=None, @@ -934,26 +943,22 @@ class AggregatedDbPadDatabase(PadDatabase): model_ids=model_ids, **kwargs) - # replay_files = self.replay_db.objects(protocol=protocol, groups=groups, cls=purposes, **kwargs) - # - # replaymobile_files = self.replaymobile_db.objects(protocol=protocol, groups=groups, cls=purposes, **kwargs) - # - # msu_mfsd_files = self.msu_mfsd_db.objects(group=groups, cls=purposes, **kwargs) - - files = replay_files + replaymobile_files + msu_mfsd_files + mobio_files # append all files to a single list + files = replay_files + replaymobile_files + msu_mfsd_files + \ + mobio_files # append all files to a single list files = [AggregatedDbPadFile(f) for f in files] return files - #========================================================================== + # ========================================================================= def annotations(self, f): """ Return annotations for a given file object ``f``, which is an instance of ``AggregatedDbPadFile`` defined in the HLDI of the Aggregated DB. The ``load()`` method of ``AggregatedDbPadFile`` class (see above) returns a video, therefore this method returns bounding-box annotations - for each video frame. The annotations are returned as dictionary of dictionaries. + for each video frame. The annotations are returned as dictionary of + dictionaries. **Parameters:**