@@ -49,8 +49,8 @@ This package is basically organized as follows (some files are omitted for clari
+ ``trainers`` contains classes implementing training of networks. At the moment, there is only one trainer, which will train CNN for visual recognition. Note that the trainer may depend on the specific model (GANs, ...).
Training a network
------------------
Defining and training a network on an arbitrary dataset
Now let's move to a concrete example. Let's say that the goal is to train a simple network to perform face recognition on the AT&T database.
...
...
@@ -64,8 +64,8 @@ in your current directory, do the following:
$ rm atnt/README
The training script ``./bin/train_cnn.py`` has a config file as argument.
The config file should specify **at least** a dataset and a network (other parameters, such as the batch size could also be provided this way).
Let's define the dataset: here we'll define it directly in the config file, but you can of course implement it in a separate file and import it in the config file.
The config file should specify **at least** the ``dataset`` and ``network`` variables (other parameters, such as the batch size could also be provided this way).
Let's define the dataset: here we'll define it directly in a config file, but you can of course implement it in a separate file and import it.
The following is heavily inspired by what is described in the `PyTorch tutorial on data laoding and processing <https://pytorch.org/tutorials/beginner/data_loading_tutorial.html>`_.
Have a look there for more details.
...
...
@@ -74,18 +74,54 @@ Have a look there for more details.
import os
import numpy
# load images
import bob.io.base
import bob.io.image
# to build your dataset
from torch.utils.data import Dataset
# mainly use to compose transforms (i.e. apply more than one transform to an input image)
import torchvision.transforms as transforms
from bob.learn.pytorch.datasets.utils import map_labels
# wrapper around torchvision.transforms
# turns out that the original ones are 'destroying' labels ...
from bob.learn.pytorch.datasets import ToTensor
from bob.learn.pytorch.datasets import Normalize
from bob.learn.pytorch.datasets import Resize
# to get the right number of classes (between 0 and n_classes)
from bob.learn.pytorch.datasets.utils import map_labels
class AtntDataset(Dataset):
""" Class defining the AT&T face dataset as a PyTorch Dataset
Attributes
----------
root_dir: str
The path to the raw images.
transform: :py:mod:`torchvision.transforms`
The transfrom to apply to the input image
data_files: list(str)
The list of image files.
id_labels: list(int)
The subjects' identity, for each file
"""
def __init__(self, root_dir, transform=None):
""" Init method
Parameters
----------
root_dir: str
The path to the raw images.
transform: :py:mod:`torchvision.transforms`
The transfrom to apply to the input image
"""
self.root_dir = root_dir
self.transform = transform
self.data_files = []
...
...
@@ -101,15 +137,17 @@ Have a look there for more details.
self.id_labels = map_labels(id_labels)
def __len__(self):
"""
return the length of the dataset (i.e. nb of examples)
""" Return the length of the dataset (i.e. nb of examples)
"""
return len(self.data_files)
def __getitem__(self, idx):
"""
return a sample from the dataset
""" Return a sample from the dataset
The sample consists in an image and a label (i.e. a face and an ID)
"""
image = bob.io.base.load(self.data_files[idx])
...
...
@@ -118,12 +156,13 @@ Have a look there for more details.
identity = self.id_labels[idx]
sample = {'image': image, 'label': identity}
# apply transform
if self.transform:
sample = self.transform(sample)
return sample
# instantiate the dataset
dataset = AtntDataset(root_dir='./atnt',
transform=transforms.Compose([
Resize((32, 32)),
...
...
@@ -134,7 +173,146 @@ Have a look there for more details.
Now that we have a dataset, we should define a network
Now that we have a dataset, we should define a network. Again, we'll do it directly in the configuration file, but
you can also define it in ``architectures`` and import it in your configuration. For the sake of simplicity, the
architecture is directly taken from `PyTorch tutorials <https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html>`_.
Note the slight modification at the end of the ``forward`` method: it returns both the ouput (``out``) and the
*embedding* ``x`` (which may be used as a features to describe an identity).
.. code-block:: python
import torch
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 40)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
out = self.fc3(x)
return out, x
def num_flat_features(self, x):
size = x.size()[1:]
num_features = 1
for s in size:
num_features *= s
return num_features
# instantiate the network
network = Net()
Since we have both a dataset and a network define in a configuration file, we can now train the
network using the dataset. This is done by launching the following script on your terminal:
.. code-block:: bash
$ ./bin/train_cnn config.py -vvv
And the output should look like this:
.. code-block:: bash
bob.learn.pytorch@2018-05-16 10:10:47,582 -- DEBUG: Model file = None