diff --git a/bob/learn/pytorch/datasets/__init__.py b/bob/learn/pytorch/datasets/__init__.py
index a14a0b8fc118cb13594cd924c20eb8523a2f6e7e..5695850766f86cac48ba539850cbe626ee09110c 100644
--- a/bob/learn/pytorch/datasets/__init__.py
+++ b/bob/learn/pytorch/datasets/__init__.py
@@ -1,6 +1,7 @@
 from .casia_webface import CasiaDataset
 from .casia_webface import CasiaWebFaceDataset
 from .data_folder import DataFolder
+from .base import BaseDataSet
 
 # transforms
 from .utils import FaceCropper
diff --git a/bob/learn/pytorch/datasets/base.py b/bob/learn/pytorch/datasets/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..e57c29ca2c8416c7091dc1a2b0b895f0eb2f2b3e
--- /dev/null
+++ b/bob/learn/pytorch/datasets/base.py
@@ -0,0 +1,103 @@
+#!/usr/bin/env python
+# encoding: utf-8
+
+import torch
+import bob.io.base
+import bob.pad.base
+import bob.bio.base
+
+from torch.utils.data import Dataset
+
+import bob.core
+logger = bob.core.log.setup("bob.learn.pytorch")
+
+
+class BaseDataSet(Dataset):
+  """ Class implementing a base dataset
+  to be used with PyTorch
+
+  Attributes
+  ----------
+  data_files:
+    List of filenames
+  labels:
+    List of labels
+  transform_fn :
+    function that takes a sample and returns
+    a transformed version.
+
+  """
+
+  def __init__(self, db, transform_fn, groups=['train'], purposes=None, verbosity_level=3):
+    """ Init function
+
+    Parameters
+    ----------
+    db: :py:obj:`bob.bio.base.database.BioDatabase` or :py:obj:`bob.pad.base.database.PadDatabase`
+      The database instance
+    transform_fn: callable or None
+      Applied to each sample by ``__getitem__``; ``None`` disables it
+    groups: list of str
+      The groups to consider (i.e. train, dev, eval)
+    purposes: list of str or None
+      Passed unchanged to ``db.objects``
+    verbosity_level: int
+      Verbosity level applied to the module logger
+    """
+    bob.core.log.set_verbosity_level(logger, verbosity_level)
+
+    # the database type is loop-invariant: test it once
+    is_pad = isinstance(db, bob.pad.base.database.PadDatabase)
+    logger.info("PAD database" if is_pad else "BIO database")
+
+    self.transform_fn = transform_fn
+    self.data_files = []
+    self.labels = []
+
+    objs = db.objects(protocol=db.protocol, groups=groups, purposes=purposes)
+    logger.info("{} samples will (theoretically) be added to the dataset".format(len(objs)))
+    for o in objs:
+      self.data_files.append(o.make_path(db.original_directory, db.original_extension))
+      if is_pad:
+        # PAD: label 1 when there is no attack type, 0 for attacks
+        self.labels.append(1 if o.attack_type is None else 0)
+      else:
+        # BIO: the label is the client id
+        self.labels.append(o.client_id)
+      logger.debug("Added file {} with label {}".format(self.data_files[-1], self.labels[-1]))
+
+  def __len__(self):
+    """Returns the length of the dataset (i.e. nb of examples)
+
+    Returns
+    -------
+    int
+      the number of examples in the dataset
+
+    """
+    return len(self.data_files)
+
+  def __getitem__(self, idx):
+    """Returns a sample from the dataset
+
+    Parameters
+    ----------
+    idx : int
+      The index of the item to load
+
+    Returns
+    -------
+    dict
+      an example of the dataset: the loaded data under ``'image'`` and
+      its label under ``'label'``, after ``transform_fn`` was applied
+
+    """
+    data = bob.io.base.load(self.data_files[idx])
+    label = self.labels[idx]
+
+    # NOTE(review): key names are a choice made here - confirm they match
+    # what the transforms used with this dataset expect
+    sample = {'image': data, 'label': label}
+    if self.transform_fn is not None:
+      sample = self.transform_fn(sample)
+    return sample
diff --git a/bob/learn/pytorch/scripts/test_data.py b/bob/learn/pytorch/scripts/test_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..7109221d4f703c7c3b0931e55177819d7e8f7e1f
--- /dev/null
+++ b/bob/learn/pytorch/scripts/test_data.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+# encoding: utf-8
+
+""" Test data handling
+
+Usage:
+  %(prog)s <configuration> [--verbose ...]
+
+Arguments:
+  <configuration>   A configuration file, defining the dataset and the network
+
+Options:
+  -h, --help        Shows this help message and exits
+  -v, --verbose     Increase the verbosity (may appear multiple times).
+
+Note that arguments provided directly by command-line will override the ones in the configuration file.
+
+Example:
+
+  To run the test
+
+    $ %(prog)s config.py
+
+See '%(prog)s --help' for more information.
+
+"""
+
+import os, sys
+import pkg_resources
+
+from docopt import docopt
+
+import bob.core
+logger = bob.core.log.setup("bob.learn.pytorch")
+
+import torch
+
+from bob.extension.config import load
+from bob.learn.pytorch.utils import get_parameter
+
+version = pkg_resources.require('bob.learn.pytorch')[0].version
+
+def main(user_input=None):
+  """Build a DataLoader on the configured dataset and print the dataset size.
+
+  ``user_input`` replaces ``sys.argv[1:]`` when it is not None. Returns 1 when
+  the configuration defines no ``dataset``, and None otherwise.
+  """
+  arguments = user_input if user_input is not None else sys.argv[1:]
+
+  prog = os.path.basename(sys.argv[0])
+  completions = dict(prog=prog, version=version,)
+  args = docopt(__doc__ % completions,argv=arguments,version='Test data handling (%s)' % version,)
+
+  # load configuration file
+  configuration = load([args['<configuration>']])
+
+  verbosity_level = get_parameter(args, configuration, 'verbose', 0)
+  bob.core.log.set_verbosity_level(logger, verbosity_level)
+
+  # nothing can be tested without a dataset: report it and signal failure
+  if not hasattr(configuration, 'dataset'):
+    logger.error("Please provide a dataset in your configuration file !")
+    return 1
+
+  batch_size = 32
+  # NOTE(review): the loader is built but no batch is drawn from it yet
+  dataloader = torch.utils.data.DataLoader(configuration.dataset, batch_size=batch_size, shuffle=True)
+  print(len(configuration.dataset))
diff --git a/setup.py b/setup.py
index 7098af0e535276634f69a689062199793a8df81e..fb7c1b2ecf1675c2b0f0256caebf807cb404f170 100644
--- a/setup.py
+++ b/setup.py
@@ -75,6 +75,7 @@ setup(
         'train_dcgan.py = bob.learn.pytorch.scripts.train_dcgan:main',
         'train_conditionalgan.py = bob.learn.pytorch.scripts.train_conditionalgan:main',
         'train_network.py = bob.learn.pytorch.scripts.train_network:main',
+        'test_data.py = bob.learn.pytorch.scripts.test_data:main',
       ],
     },
 