This package is organized as follows (some files are omitted for clarity):

.. code-block:: text

   bob/
   +-- learn/
       +-- pytorch/
           +-- architectures/
           |   +-- ...
           +-- datasets/
           |   +-- casia_webface.py
           |   +-- fargo.py
           |   +-- multipie.py
           |   +-- ...
           +-- scripts/
           |   +-- ...
           +-- trainers/
               +-- CNNTrainer.py

+ ``architectures`` contains files defining the different network architectures. A network is defined as a class derived from :py:class:`torch.nn.Module` and must implement a ``forward`` method. See for instance the simple examples provided in the `PyTorch documentation <https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html>`_.
+ ``datasets`` contains files implementing datasets as :py:class:`torch.utils.data.Dataset`. The dataset classes are meant to provide a bridge between Bob's databases (i.e. ``bob.db.*``) and the datasets used by PyTorch. To see how this can be achieved, have a look at ``bob/learn/pytorch/datasets/multipie.py``. This directory also contains some utility functions, such as wrappers around some `torchvision.transforms <https://pytorch.org/docs/stable/torchvision/transforms.html>`_ (note that this is needed since some built-in ``torchvision.transforms`` ignore the labels).
+ ``scripts`` contains the various scripts to perform training. More on that below.
+ ``trainers`` contains classes implementing the training of networks. At the moment, there is only one trainer, which trains a CNN for visual recognition. Note that a trainer may depend on the specific type of model (GANs, ...).

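To make the ``architectures`` entry above concrete, here is a minimal sketch of what such a network class could look like (``TinyNet`` is a hypothetical toy example for illustration, not one of the architectures shipped with this package):

.. code-block:: python

   import torch
   import torch.nn as nn

   class TinyNet(nn.Module):
       """A toy architecture: one convolution, pooling, and a linear classifier."""

       def __init__(self, num_classes=10):
           super(TinyNet, self).__init__()
           self.conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)
           self.pool = nn.AdaptiveAvgPool2d((4, 4))
           self.fc = nn.Linear(8 * 4 * 4, num_classes)

       def forward(self, x):
           # the forward method defines how a batch flows through the layers
           x = torch.relu(self.conv(x))
           x = self.pool(x)
           x = x.view(x.size(0), -1)  # flatten before the linear layer
           return self.fc(x)

   net = TinyNet(num_classes=5)
   output = net(torch.randn(2, 1, 32, 32))  # a batch of 2 grayscale 32x32 images

Any class following this pattern (derived from :py:class:`torch.nn.Module`, with a ``forward`` method) can be passed to the training scripts described below.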
Training a network
------------------
Now let's move to a concrete example. Let's say that the goal is to train a simple network to perform face recognition on the AT&T database.
First, you'll have to get the dataset; you can download it `here <http://www.cl.cam.ac.uk/Research/DTG/attarchive/pub/data/att_faces.zip>`_. Assuming that you downloaded the zip archive into your current directory, do the following:

.. code-block:: bash

   $ mkdir atnt
   $ unzip att_faces.zip -d atnt/
   $ rm atnt/README

The training script ``./bin/train_cnn.py`` takes a configuration file as argument.
The config file should specify **at least** a dataset and a network (other parameters, such as the batch size, can also be provided this way).
Let's define the dataset: here we'll define it directly in the config file, but you can of course implement it in a separate file and import it in the config file.
The following is heavily inspired by the `PyTorch tutorial on data loading and processing <https://pytorch.org/tutorials/beginner/data_loading_tutorial.html>`_;
have a look there for more details.

.. code-block:: python

   import os
   import numpy

   import bob.io.base
   import bob.io.image

   from torch.utils.data import Dataset
   import torchvision.transforms as transforms

   from bob.learn.pytorch.datasets.utils import map_labels
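A note on the last import: client IDs in Bob databases are typically arbitrary values, while PyTorch loss functions such as :py:class:`torch.nn.CrossEntropyLoss` expect labels in the range ``[0, n_classes - 1]``; ``map_labels`` takes care of this remapping. The idea can be sketched as follows (a hypothetical re-implementation for illustration only; see ``bob/learn/pytorch/datasets/utils.py`` for the actual function):

.. code-block:: python

   def remap_labels(raw_labels):
       # map arbitrary label values to consecutive integers starting at 0,
       # keeping identical raw labels mapped to the same integer
       mapping = {label: index for index, label in enumerate(sorted(set(raw_labels)))}
       return [mapping[label] for label in raw_labels]

   print(remap_labels([13, 7, 13, 42]))  # prints [1, 0, 1, 2]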