diff --git a/bob/learn/tensorflow/datashuffler/DrGanDisk.py b/bob/learn/tensorflow/datashuffler/DrGanDisk.py index 1ab46a96f10a389a01cbdaf29d2eb6730907d7d2..b7f079be88f38a3edefa2d491c1847ef6e985a3a 100644 --- a/bob/learn/tensorflow/datashuffler/DrGanDisk.py +++ b/bob/learn/tensorflow/datashuffler/DrGanDisk.py @@ -25,9 +25,11 @@ class DrGanDisk(Base): data: Input data - labels: - List of list containing labels - (we consider several labels per example) + id_labels: + id labels of the retrieved faces. + + pose_labels: + pose labels of the retrieved faces. input_shape: The shape of the inputs @@ -110,6 +112,19 @@ class DrGanDisk(Base): return self.data_placeholder, self.id_label_placeholder, self.pose_label_placeholder def load_from_file(self, file_name): + """load_from_file(file_name) -> data + + Load an image from a file, and rescale it if it does not fit the input data format. + Optionally, data augmentation is performed. + + **Parameters** + file_name: path + The name of the (image) file to load. + + **Returns** + data: numpy array + The image data + """ d = bob.io.base.load(file_name) # Applying the data augmentation @@ -132,6 +147,22 @@ class DrGanDisk(Base): def get_batch(self): + """get_batch() -> selected_data, selected_pose_labels, selected_id_labels + + This function selects and returns data to be used in a minibatch iteration. + Note that the returned data is randomly selected from the training set. + + **Returns** + + selected_data: + The face images. + + selected_pose_labels: + The pose labels + + selected_id_labels: + The id labels + """ # Shuffling samples indexes = numpy.array(range(self.data.shape[0])) @@ -155,13 +186,30 @@ class DrGanDisk(Base): def get_batch_epoch(self): + """get_batch_epoch() -> selected_data, selected_id_labels, selected_pose_labels, epoch_done + + This function selects and returns data to be used in minibatch iterations. + Note that it works in epochs, i.e. 
all the training data should be seen + during one epoch, which consists of several minibatch iterations. + + **Returns** + selected_data: + The face images. + + selected_id_labels: + The id labels + + selected_pose_labels: + The pose labels + """ # this is done to rebuild the whole list (i.e. at the end of one epoch) epoch_done = False # returned mini-batch selected_data = numpy.zeros(shape=self.shape) - selected_labels = [] + selected_id_labels = [] + selected_pose_labels = [] # if there is not enough available data to fill the current mini-batch # add randomly some examples THAT ARE NOT STILL PRESENT in the dataset ! @@ -202,18 +250,22 @@ class DrGanDisk(Base): selected_data[i, ...] = self.normalizer(selected_data[i, ...]) # label - selected_labels.append(self.labels[self.indexes[current_index]]) + selected_id_labels.append(self.id_labels[self.indexes[current_index]]) + selected_pose_labels.append(self.pose_labels[self.indexes[current_index]]) # remove this example from the training set - used once in the epoch new_indexes = numpy.delete(self.indexes, current_index) self.indexes = new_indexes - if isinstance(selected_labels, list): - selected_labels = numpy.array(selected_labels) + if isinstance(selected_id_labels, list): + selected_id_labels = numpy.array(selected_id_labels) + + if isinstance(selected_pose_labels, list): + selected_pose_labels = numpy.array(selected_pose_labels) # rebuild whole randomly shuffled training dataset if epoch_done: self.indexes = numpy.array(range(self.data.shape[0])) numpy.random.shuffle(self.indexes) - return [selected_data.astype("float32"), selected_labels.astype("int64"), epoch_done] + return [selected_data.astype("float32"), selected_id_labels.astype("int64"), selected_pose_labels.astype("int64"), epoch_done]