Skip to content
Snippets Groups Projects

Gan

Closed Guillaume HEUSCH requested to merge gan into master
1 file
+ 60
8
Compare changes
  • Side-by-side
  • Inline
@@ -25,9 +25,11 @@ class DrGanDisk(Base):
@@ -25,9 +25,11 @@ class DrGanDisk(Base):
data:
data:
Input data
Input data
labels:
id_labels:
List of list containing labels
id labels of the retrieved faces.
(we consider several labels per example)
 
pose_labels:
 
pose labels of the retrieved faces.
input_shape:
input_shape:
The shape of the inputs
The shape of the inputs
@@ -110,6 +112,19 @@ class DrGanDisk(Base):
@@ -110,6 +112,19 @@ class DrGanDisk(Base):
return self.data_placeholder, self.id_label_placeholder, self.pose_label_placeholder
return self.data_placeholder, self.id_label_placeholder, self.pose_label_placeholder
def load_from_file(self, file_name):
def load_from_file(self, file_name):
 
"""load_from_file(file_name) -> data
 
 
Load an image from a file, and rescale it if it does not fit the input data format
 
Optionnally, data augmentation is performed.
 
 
**Parameters**
 
file_name: path
 
The name of the (image) file to load.
 
 
**Returns**
 
data: numpy array
 
The image data
 
"""
d = bob.io.base.load(file_name)
d = bob.io.base.load(file_name)
# Applying the data augmentation
# Applying the data augmentation
@@ -132,6 +147,22 @@ class DrGanDisk(Base):
@@ -132,6 +147,22 @@ class DrGanDisk(Base):
def get_batch(self):
def get_batch(self):
 
"""get_batch() -> selected_data, selected_pose_labels, selected_id_labels
 
 
This function selects and returns data to be used in a minibatch iteration.
 
Note that returned data is randomly selected in the training set
 
 
**Returns**
 
 
selected_data:
 
The face images.
 
 
selected_pose_labels:
 
The pose labels
 
 
selected_id_labels:
 
The id labels
 
"""
# Shuffling samples
# Shuffling samples
indexes = numpy.array(range(self.data.shape[0]))
indexes = numpy.array(range(self.data.shape[0]))
@@ -155,13 +186,30 @@ class DrGanDisk(Base):
@@ -155,13 +186,30 @@ class DrGanDisk(Base):
def get_batch_epoch(self):
def get_batch_epoch(self):
 
"""get_batch_epoch() -> selected_data, selected_pose_labels, selected_id_labels
 
 
This function selects and returns data to be used in a minibatch iterations.
 
Note that it works in epochs, i.e. all the training data should be seen
 
during one epoch, which consists in several minibatch iterations.
 
 
**Returns**
 
selected_data:
 
The face images.
 
 
selected_pose_labels:
 
The pose labels
 
 
selected_id_labels:
 
The id labels
 
"""
# this is done to rebuild the whole list (i.e. at the end of one epoch)
# this is done to rebuild the whole list (i.e. at the end of one epoch)
epoch_done = False
epoch_done = False
# returned mini-batch
# returned mini-batch
selected_data = numpy.zeros(shape=self.shape)
selected_data = numpy.zeros(shape=self.shape)
selected_labels = []
selected_id_labels = []
 
selected_pose_labels = []
# if there is not enough available data to fill the current mini-batch
# if there is not enough available data to fill the current mini-batch
# add randomly some examples THAT ARE NOT STILL PRESENT in the dataset !
# add randomly some examples THAT ARE NOT STILL PRESENT in the dataset !
@@ -202,18 +250,22 @@ class DrGanDisk(Base):
@@ -202,18 +250,22 @@ class DrGanDisk(Base):
selected_data[i, ...] = self.normalizer(selected_data[i, ...])
selected_data[i, ...] = self.normalizer(selected_data[i, ...])
# label
# label
selected_labels.append(self.labels[self.indexes[current_index]])
selected_id_labels.append(self.id_labels[self.indexes[current_index]])
 
selected_pose_labels.append(self.pose_labels[self.indexes[current_index]])
# remove this example from the training set - used once in the epoch
# remove this example from the training set - used once in the epoch
new_indexes = numpy.delete(self.indexes, current_index)
new_indexes = numpy.delete(self.indexes, current_index)
self.indexes = new_indexes
self.indexes = new_indexes
if isinstance(selected_labels, list):
if isinstance(selected_id_labels, list):
selected_labels = numpy.array(selected_labels)
selected_id_labels = numpy.array(selected_id_labels)
 
 
if isinstance(selected_pose_labels, list):
 
selected_pose_labels = numpy.array(selected_pose_labels)
# rebuild whole randomly shuffled training dataset
# rebuild whole randomly shuffled training dataset
if epoch_done:
if epoch_done:
self.indexes = numpy.array(range(self.data.shape[0]))
self.indexes = numpy.array(range(self.data.shape[0]))
numpy.random.shuffle(self.indexes)
numpy.random.shuffle(self.indexes)
return [selected_data.astype("float32"), selected_labels.astype("int64"), epoch_done]
return [selected_data.astype("float32"), selected_id_labels.astype("int64"), selected_pose_labels.astype("int64"), epoch_done]
Loading