Skip to content
Snippets Groups Projects
Commit 45119335 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[doc] worked on the user guide, add list to nitpick exceptions

parent 10d0f1b2
No related branches found
No related tags found
No related merge requests found
#!/usr/bin/env python
# encoding: utf-8
import os
import torch
import numpy
from torch.utils.data import Dataset
import bob.db.fargo
import bob.io.base
import bob.io.image
from .utils import map_labels
class FargoDataset(Dataset):
"""Fargo dataset.
Class representing the FARGO dataset.
Note that it only consider access to training data.
**Parameters**
root-dir: path
The path to the data
annotation_dir: path
The path to the annotation (eyes center)
transform: torchvision.transforms
The transform(s) to apply to the face images
"""
def __init__(self, root_dir, annotation_dir, transform=None):
self.root_dir = root_dir
self.annotation_dir = annotation_dir
self.transform = transform
self.data_files = []
self.annotations_files = []
self.pose_labels = []
id_labels = []
db = bob.db.fargo.Database()
protocol = 'public_MC_RGB'
objs = db.objects(protocol=protocol, groups=['world'])
objs += db.objects(protocol=protocol, groups=['dev'], purpose='enroll')
for obj in objs:
self.data_files.append(os.path.join(root_dir, obj.path) + ".png")
self.annotations_files.append(os.path.join(annotation_dir, obj.path) + ".pos")
identity = int(obj.path.split('/')[0])
id_labels.append(identity)
self.pose_labels.append(6) # all frontal
self.id_labels = map_labels(id_labels)
def __len__(self):
"""
return the length of the dataset (i.e. nb of examples)
"""
return len(self.data_files)
def __getitem__(self, idx):
"""
return a sample from the dataset
"""
eyesfiles = open(self.annotations_files[idx], 'r')
eyes = eyesfiles.readline()
reyex, reyey, leyex, leyey = eyes.split()
eyes={'reye' : (int(reyey), int(reyex)), 'leye' : (int(leyey), int(leyex))}
identity = self.id_labels[idx]
pose = self.pose_labels[idx]
image = bob.io.base.load(self.data_files[idx])
sample = {'image': image, 'eyes': eyes, 'id': identity, 'pose': pose}
if self.transform:
sample = self.transform(sample)
return sample
py:class torch.nn.modules.module.Module
py:class torch.utils.data.dataset.Dataset
py:obj list
......@@ -4,12 +4,20 @@
As the name suggest, this package makes heavy use of pytorch_, so make sure you have it installed on your environment.
It also relies on bob_ (and in particular for I/O and databases interfaces), so you may want to refer
to their respective documentation as well.
to their respective documentation as well. In particular, the following assumes that you have a conda environment
with bob installed (see `installation instructions <https://www.idiap.ch/software/bob/docs/bob/docs/stable/bob/doc/install.html>`_)
and that you cloned and built this package:
.. code-block:: bash
$ git clone git@gitlab.idiap.ch:bob/bob.learn.pytorch.git
$ cd bob.learn.pytorch
$ buildout
Anatomy of the package
----------------------
This package is basically organized as follows (some files are omitted for clarity purposes):
.. code-block:: text
......@@ -23,7 +31,6 @@ This package is basically organized as follows (some files are omitted for clari
+-- ...
+-- datasets/
+-- casia_webface.py
+-- fargo.py
+-- multipie.py
+-- ...
+-- scripts/
......@@ -33,16 +40,100 @@ This package is basically organized as follows (some files are omitted for clari
+-- CNNTrainer.py
+ ``architectures`` contains files defining the different network architectures. The network is defined as a derived class of ``torch.nn.Module`` and must contain a ``forward`` method.
See for instance the simple examples provided in `PyTorch documentation <https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html>_`.
+ ``architectures`` contains files defining the different network architectures. The network is defined as a derived class of :py:class:`torch.nn.Module` and must contain a ``forward`` method. See for instance the simple examples provided in `PyTorch documentation <https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html>`_.
+ ``datasets`` contains files implementing datasets as :py:class:`torch.utils.data.Dataset`. The dataset classes are meant to provide a bridge between bob's databases (i.e. ``bob.db.*``) and the dataset used by PyTorch. To see an example on how this could be achieved, have a look at ``bob/learn/pytorch/datasets/multipie.py``. This directory also contains some utility functions, such as wrappers around some `torchvision.transforms <https://pytorch.org/docs/stable/torchvision/transforms.html>`_ (note that this is needed since, some built-in ``torchvision.transforms`` ignore the labels).
+ ``scripts`` contains the various scripts to perform training. More on that below.
+ ``trainers`` contains classes implementing training of networks. At the moment, there is only one trainer, which will train CNN for visual recognition. Note that the trainer may depend on the specific model (GANs, ...).
Training a network
------------------
Now let's move to a concrete example. Let's say that the goal is to train a simple network to perform face recognition on the AT&T database.
First, you'll have to get the dataset, you can download it `here <http://www.cl.cam.ac.uk/Research/DTG/attarchive/pub/data/att_faces.zip>`_. Assuming that you downloaded the zip archive
in your current directory, do the following:
.. code-block:: bash
$ mkdir atnt
$ unzip att_faces.zip -d atnt/
$ rm atnt/README
The training script ``./bin/train_cnn.py`` has a config file as argument.
The config file should specify **at least** a dataset and a network (other parameters, such as the batch size could also be provided this way).
Let's define the dataset: here we'll define it directly in the config file, but you can of course implement it in a separate file and import it in the config file.
The following is heavily inspired by what is described in the `PyTorch tutorial on data laoding and processing <https://pytorch.org/tutorials/beginner/data_loading_tutorial.html>`_.
Have a look there for more details.
.. code-block:: python
import os
import numpy
import bob.io.base
import bob.io.image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from bob.learn.pytorch.datasets.utils import map_labels
from bob.learn.pytorch.datasets import ToTensor
from bob.learn.pytorch.datasets import Normalize
from bob.learn.pytorch.datasets import Resize
class AtntDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.data_files = []
id_labels = []
for root, dirs, files in os.walk(self.root_dir):
for name in files:
filename = os.path.split(os.path.join(root, name))[-1]
path = root.split(os.sep)
subject = int(path[-1].replace('s', ''))
self.data_files.append(os.path.join(root, name))
id_labels.append(subject)
self.id_labels = map_labels(id_labels)
def __len__(self):
"""
return the length of the dataset (i.e. nb of examples)
"""
return len(self.data_files)
def __getitem__(self, idx):
"""
return a sample from the dataset
"""
image = bob.io.base.load(self.data_files[idx])
# add an empty dimension so that the array is HxWxC (as expected by PyTorch)
image = image[..., numpy.newaxis]
identity = self.id_labels[idx]
sample = {'image': image, 'label': identity}
if self.transform:
sample = self.transform(sample)
return sample
+ ``datasets`` contains files implementing datasets as ``torch.utils.data.DataSet``,
and some utility functions (wrapper around torch.transforms for instance)
+ ``scripts`` contains the various scripts to perform training.
dataset = AtntDataset(root_dir='./atnt',
transform=transforms.Compose([
Resize((32, 32)),
ToTensor(),
Normalize((0.5,), (0.5,))
])
)
+ ``trainers`` contains files implementing the different training procedure
Now that we have a dataset,
.. _bob: http://idiap.github.io/bob/
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment