class BaseDataSet(Dataset):
    """Base dataset exposing a bob BIO or PAD database to PyTorch.

    File paths and labels are collected once, at construction time; the
    data itself is only read from disk when a sample is requested.

    Attributes
    ----------
    data_files : list
        Paths of the samples, as returned by ``make_path`` called with the
        database's ``original_directory`` and ``original_extension``.
    labels : list
        One label per entry of ``data_files``. For a PAD database the label
        is 1 when the object's ``attack_type`` is None and 0 otherwise; for
        any other database it is the object's ``client_id``.
    transform_fn : callable or None
        Function that takes a sample and returns a transformed version.

    """

    # NOTE: the list default for ``groups`` is kept as-is on purpose. It is
    # never mutated here, and an explicit ``groups=None`` is forwarded
    # untouched to ``db.objects``, so it cannot double as a "use the default"
    # sentinel.
    def __init__(self, db, transform_fn, groups=['train'], purposes=None, verbosity_level=3):
        """Collect the file paths and labels of the requested samples.

        Parameters
        ----------
        db : :py:obj:`bob.bio.base.database.BioDatabase` or :py:obj:`bob.pad.base.database.PadDatabase`
            The database instance
        transform_fn : callable or None
            Function applied to every sample returned by ``__getitem__``.
            ``None`` disables the transformation.
        groups : list of str
            The groups to consider (i.e. train, dev, eval)
        purposes :
            Forwarded unchanged to ``db.objects``.
        verbosity_level : int
            Verbosity level set on the ``bob.learn.pytorch`` logger.

        """
        bob.core.log.set_verbosity_level(logger, verbosity_level)

        # The kind of database decides how labels are derived. It cannot
        # change while iterating, so it is resolved once instead of once per
        # object.
        is_pad = isinstance(db, bob.pad.base.database.PadDatabase)
        if is_pad:
            logger.info("PAD database")
        else:
            logger.info("BIO database")

        self.transform_fn = transform_fn
        self.data_files = []
        self.labels = []

        objs = db.objects(protocol=db.protocol, groups=groups, purposes=purposes)
        logger.info("{} samples will (theoretically) be added to the dataset".format(len(objs)))
        for o in objs:
            self.data_files.append(o.make_path(db.original_directory, db.original_extension))
            if is_pad:
                # No attack type -> label 1, any attack type -> label 0.
                self.labels.append(1 if o.attack_type is None else 0)
            else:
                self.labels.append(o.client_id)
            logger.debug("Added file {} with label {}".format(self.data_files[-1], self.labels[-1]))

    def __len__(self):
        """Returns the length of the dataset (i.e. number of examples)

        Returns
        -------
        int
            the number of examples in the dataset

        """
        return len(self.data_files)

    def __getitem__(self, idx):
        """Returns a sample from the dataset

        Parameters
        ----------
        idx : int
            The index of the item to load

        Returns
        -------
        dict
            The example at position ``idx``: the loaded data under the
            ``'image'`` key and its label under the ``'label'`` key. When
            ``transform_fn`` is not None, the value returned by
            ``transform_fn(sample)`` is returned instead.

        """
        # Imported here so that the dependency on bob.io.base is explicit
        # rather than relying on another bob package importing it first.
        import bob.io.base

        data = bob.io.base.load(self.data_files[idx])
        label = self.labels[idx]

        # NOTE(review): the key names are a choice made here -- confirm they
        # match what the transforms used with this dataset expect.
        sample = {'image': data, 'label': label}
        if self.transform_fn is not None:
            sample = self.transform_fn(sample)
        return sample
def main(user_input=None):
    """Load a configuration file and check that its dataset can be handled.

    Parameters
    ----------
    user_input : list of str, optional
        Command-line arguments to parse. When None, ``sys.argv[1:]`` is
        used.

    Returns
    -------
    int
        0 when the configuration provides a ``dataset``, 1 otherwise, so
        that a missing dataset is visible in the exit status when the
        return value is handed to ``sys.exit``.

    """
    # Parse the command-line arguments
    arguments = sys.argv[1:] if user_input is None else user_input

    prog = os.path.basename(sys.argv[0])
    completions = dict(prog=prog, version=version)
    args = docopt(
        __doc__ % completions,
        argv=arguments,
        version='Test data handling (%s)' % version,
    )

    # load configuration file
    configuration = load([args['<configuration>']])

    verbosity_level = get_parameter(args, configuration, 'verbose', 0)
    bob.core.log.set_verbosity_level(logger, verbosity_level)

    # Guard clause: without a dataset there is nothing to test.
    if not hasattr(configuration, 'dataset'):
        logger.error("Please provide a dataset in your configuration file !")
        return 1

    batch_size = 32

    # get data
    dataloader = torch.utils.data.DataLoader(configuration.dataset, batch_size=batch_size, shuffle=True)
    print(len(configuration.dataset))
    logger.info("{} batches of at most {} samples".format(len(dataloader), batch_size))
    return 0