Skip to content
Snippets Groups Projects
Commit 79efeb39 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[dataset] first commit with base dataset class, to handle BIO and PAD databases

parent 03ca71be
Branches
Tags v2.5.1
1 merge request!31WIP: Resolve "Dataset handling"
Pipeline #29011 failed
from .casia_webface import CasiaDataset
from .casia_webface import CasiaWebFaceDataset
from .data_folder import DataFolder
from .base import BaseDataSet
# transforms
from .utils import FaceCropper
......
#!/usr/bin/env python
# encoding: utf-8
import torch
import bob.pad.base
import bob.bio.base
from torch.utils.data import Dataset
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
class BaseDataSet(Dataset):
""" Class implementing a base dataset
to be used with PyTorch
Attributes
----------
data_files:
List of filenames
labels:
List of labels
transform_fn :
function that takes a sample and returns
a transformed version.
"""
def __init__(self, db, transform_fn, groups=['train'], purposes=None, verbosity_level=3):
""" Init function
Parameters
----------
db: :py:obj:`bob.bio.base.database.BioDatabase` or :py:obj:`bob.pad.base.database.PadDatabase`
The database instance
groups: list of str
The groups to consider (i.e. train, dev, eval)
"""
bob.core.log.set_verbosity_level(logger, verbosity_level)
if isinstance(db, bob.pad.base.database.PadDatabase):
logger.info("PAD database")
else:
logger.info("BIO database")
self.data_files = []
self.labels = []
objs = db.objects(protocol=db.protocol, groups=groups, purposes=purposes)
logger.info("{} samples will (theoretically) be added to the dataset".format(len(objs)))
for o in objs:
self.data_files.append(o.make_path(db.original_directory, db.original_extension))
if isinstance(db, bob.pad.base.database.PadDatabase):
if o.attack_type is None:
self.labels.append(1)
else:
self.labels.append(0)
else:
self.labels.append(o.client_id)
logger.debug("Added file {} with label {}".format(self.data_files[-1], self.labels[-1]))
def __len__(self):
"""Returns the length of the dataset (i.e. nb of examples)
Returns
-------
int
the number of examples in the dataset
"""
return len(self.data_files)
def __getitem__(self, idx):
"""Returns a sample from the dataset
Parameters
----------
idx : int
The index of the item to load
Returns
-------
dict
an example of the dataset, containing the
transformed face image, its identity and pose information
"""
data = bob.io.base.load(self.data_files[idx])
label = self.labels[idx]
#!/usr/bin/env python
# encoding: utf-8
""" Test data handling
Usage:
%(prog)s <configuration> [--verbose ...]
Arguments:
<configuration> A configuration file, defining the dataset and the network
Options:
-h, --help Shows this help message and exits
-v, --verbose Increase the verbosity (may appear multiple times).
Note that arguments provided directly by command-line will override the ones in the configuration file.
Example:
To run the test
$ %(prog)s config.py
See '%(prog)s --help' for more information.
"""
import os, sys
import pkg_resources
from docopt import docopt
import bob.core
logger = bob.core.log.setup("bob.learn.pytorch")
import torch
from bob.extension.config import load
from bob.learn.pytorch.utils import get_parameter
version = pkg_resources.require('bob.learn.pytorch')[0].version
def main(user_input=None):
# Parse the command-line arguments
if user_input is not None:
arguments = user_input
else:
arguments = sys.argv[1:]
prog = os.path.basename(sys.argv[0])
completions = dict(prog=prog, version=version,)
args = docopt(__doc__ % completions,argv=arguments,version='Train a CNN (%s)' % version,)
# load configuration file
configuration = load([os.path.join(args['<configuration>'])])
verbosity_level = get_parameter(args, configuration, 'verbose', 0)
bob.core.log.set_verbosity_level(logger, verbosity_level)
batch_size = 32
# get data
if hasattr(configuration, 'dataset'):
dataloader = torch.utils.data.DataLoader(configuration.dataset, batch_size=batch_size, shuffle=True)
print(len(configuration.dataset))
else:
logger.error("Please provide a dataset in your configuration file !")
......@@ -75,6 +75,7 @@ setup(
'train_dcgan.py = bob.learn.pytorch.scripts.train_dcgan:main',
'train_conditionalgan.py = bob.learn.pytorch.scripts.train_conditionalgan:main',
'train_network.py = bob.learn.pytorch.scripts.train_network:main',
'test_data.py = bob.learn.pytorch.scripts.test_data:main',
],
},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment