Skip to content
Snippets Groups Projects
Commit 1f2735c1 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[datashuffler] added the doc in DrGanDisk

parent 319ef692
No related branches found
No related tags found
1 merge request!8Gan
Pipeline #
......@@ -25,9 +25,11 @@ class DrGanDisk(Base):
data:
Input data
labels:
List of list containing labels
(we consider several labels per example)
id_labels:
id labels of the retrieved faces.
pose_labels:
pose labels of the retrieved faces.
input_shape:
The shape of the inputs
......@@ -110,6 +112,19 @@ class DrGanDisk(Base):
return self.data_placeholder, self.id_label_placeholder, self.pose_label_placeholder
def load_from_file(self, file_name):
"""load_from_file(file_name) -> data
Load an image from a file, and rescale it if it does not fit the input data format
Optionnally, data augmentation is performed.
**Parameters**
file_name: path
The name of the (image) file to load.
**Returns**
data: numpy array
The image data
"""
d = bob.io.base.load(file_name)
# Applying the data augmentation
......@@ -132,6 +147,22 @@ class DrGanDisk(Base):
def get_batch(self):
"""get_batch() -> selected_data, selected_pose_labels, selected_id_labels
This function selects and returns data to be used in a minibatch iteration.
Note that returned data is randomly selected in the training set
**Returns**
selected_data:
The face images.
selected_pose_labels:
The pose labels
selected_id_labels:
The id labels
"""
# Shuffling samples
indexes = numpy.array(range(self.data.shape[0]))
......@@ -155,13 +186,30 @@ class DrGanDisk(Base):
def get_batch_epoch(self):
"""get_batch_epoch() -> selected_data, selected_pose_labels, selected_id_labels
This function selects and returns data to be used in a minibatch iterations.
Note that it works in epochs, i.e. all the training data should be seen
during one epoch, which consists in several minibatch iterations.
**Returns**
selected_data:
The face images.
selected_pose_labels:
The pose labels
selected_id_labels:
The id labels
"""
# this is done to rebuild the whole list (i.e. at the end of one epoch)
epoch_done = False
# returned mini-batch
selected_data = numpy.zeros(shape=self.shape)
selected_labels = []
selected_id_labels = []
selected_pose_labels = []
# if there is not enough available data to fill the current mini-batch
# add randomly some examples THAT ARE NOT STILL PRESENT in the dataset !
......@@ -202,18 +250,22 @@ class DrGanDisk(Base):
selected_data[i, ...] = self.normalizer(selected_data[i, ...])
# label
selected_labels.append(self.labels[self.indexes[current_index]])
selected_id_labels.append(self.id_labels[self.indexes[current_index]])
selected_pose_labels.append(self.pose_labels[self.indexes[current_index]])
# remove this example from the training set - used once in the epoch
new_indexes = numpy.delete(self.indexes, current_index)
self.indexes = new_indexes
if isinstance(selected_labels, list):
selected_labels = numpy.array(selected_labels)
if isinstance(selected_id_labels, list):
selected_id_labels = numpy.array(selected_id_labels)
if isinstance(selected_pose_labels, list):
selected_pose_labels = numpy.array(selected_pose_labels)
# rebuild whole randomly shuffled training dataset
if epoch_done:
self.indexes = numpy.array(range(self.data.shape[0]))
numpy.random.shuffle(self.indexes)
return [selected_data.astype("float32"), selected_labels.astype("int64"), epoch_done]
return [selected_data.astype("float32"), selected_id_labels.astype("int64"), selected_pose_labels.astype("int64"), epoch_done]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment