diff --git a/bob/learn/pytorch/datasets/__init__.py b/bob/learn/pytorch/datasets/__init__.py
index a14a0b8fc118cb13594cd924c20eb8523a2f6e7e..0eaa91d69367fe0d7fbf3ca49dfdf5bd07c78d87 100644
--- a/bob/learn/pytorch/datasets/__init__.py
+++ b/bob/learn/pytorch/datasets/__init__.py
@@ -1,6 +1,8 @@
 from .casia_webface import CasiaDataset
 from .casia_webface import CasiaWebFaceDataset
 from .data_folder import DataFolder
+from .data_folder_generic import DataFolderGeneric
 
 # transforms
 from .utils import FaceCropper
diff --git a/bob/learn/pytorch/datasets/data_folder_generic.py b/bob/learn/pytorch/datasets/data_folder_generic.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7010c46d415e5b8cd8157034fd8caf531d90105
--- /dev/null
+++ b/bob/learn/pytorch/datasets/data_folder_generic.py
@@ -0,0 +1,316 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: Olegs Nikisins
+"""
+
+#==============================================================================
+# Import what is needed here:
+
+import torch.utils.data as data
+
+import os
+
+import random
+
+random.seed(a=7)
+
+import numpy as np
+
+from torchvision import transforms
+
+import h5py
+
+
+#==============================================================================
+def get_file_names_and_labels(files, data_folder, extension=".hdf5", hldi_type="pad"):
+    """
+    Get the absolute names of the corresponding file objects and their class
+    labels, as well as the keys defining the names of the frames to load the
+    data from.
+
+    Parameters
+    ----------
+
+    files : [File]
+        A list of file objects defined in the High Level Database Interface
+        of the particular database.
+
+    data_folder : str
+        A directory containing the training data.
+
+    extension : str
+        Extension of the data files. Default: ".hdf5".
+
+    hldi_type : str
+        Type of the high level database interface. Default: "pad".
+        Note: this is the only type supported at the moment.
+
+    Returns
+    -------
+
+    file_names_labels_keys : [(str, int, str)]
+        A list of tuples, where each tuple contains an absolute filename,
+        the corresponding class label, and a key defining the name of the
+        frame to extract the data from.
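+
+    Example
+    -------
+
+    A returned list might look like this (paths and frame keys are
+    illustrative)::
+
+        [('/data/real/client001.hdf5', 1, 'Frame_0'),
+         ('/data/real/client001.hdf5', 1, 'Frame_1'),
+         ('/data/attack/client001.hdf5', 0, 'Frame_0')]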
+    """
+
+    file_names_labels_keys = []
+
+    if hldi_type == "pad":
+
+        for f in files:
+
+            if f.attack_type is None:
+
+                label = 1
+
+            else:
+
+                label = 0
+
+            file_name = os.path.join(data_folder, f.path + extension)
+
+            if os.path.isfile(file_name): # if the file is available
+
+                with h5py.File(file_name, "r") as f_h5py:
+
+                    file_keys = list(f_h5py.keys())
+
+                # remove the 'FrameIndexes' key, which does not point to a frame
+                file_keys = [key for key in file_keys if key != 'FrameIndexes']
+
+                # each tuple below contains: the filename the key was extracted
+                # from, the label corresponding to the file, and a key defining
+                # a frame from that file
+                file_names_labels_keys += [(file_name, label, key) for key in file_keys]
+
+    return file_names_labels_keys
+
+
+#==============================================================================
+class DataFolderGeneric(data.Dataset):
+    """
+    A generic data loader compatible with Bob High Level Database Interfaces
+    (HLDI). Only HLDIs of ``bob.pad.face`` are currently supported.
+
+    The basic functionality is composed of two steps: load the data from an
+    hdf5 file, and transform it using a user-defined transformation function.
+
+    Two types of user-defined transformations are supported:
+
+    1. An instance of the ``Compose`` transformation class from the
+    ``torchvision`` package.
+
+    2. A custom transformation function, which takes a numpy.ndarray as input,
+    and returns a transformed Tensor. The dimensionality of the output tensor
+    must match the format expected by the network to be trained.
+
+    Note: if no special transformation is needed, the ``transform``
+    must at least convert an input numpy array to a Tensor.
+
+    Attributes
+    ----------
+
+    data_folder : str
+        A directory containing the training data. Note that the training data
+        must be stored as FrameContainers written to hdf5 files. Other
+        formats are currently not supported.
+
+    transform : object
+        A function ``transform`` takes an input numpy.ndarray sample/image,
+        and returns a transformed version as a Tensor. Default: None.
+
+    extension : str
+        Extension of the data files. Default: ".hdf5".
+        Note: this is the only extension supported at the moment.
+
+    bob_hldi_instance : object
+        An instance of the HLDI interface. Only HLDI's of bob.pad.face
+        are currently supported.
+
+    hldi_type : str
+        String defining the type of the HLDI. Default: "pad".
+        Note: this is the only option currently supported.
+
+    groups : str or [str]
+        The groups for which the clients should be returned.
+        Usually, groups are one or more elements of ['train', 'dev', 'eval'].
+        Default: ['train', 'dev', 'eval'].
+
+    protocol : str
+        The protocol for which the clients should be retrieved.
+        Default: 'grandtest'.
+
+    purposes : str or [str]
+        The purposes for which File objects should be retrieved.
+        Usually it is either 'real' or 'attack'.
+        Default: ['real', 'attack'].
+
+    allow_missing_files : bool
+        Missing files in the ``data_folder`` will not break the
+        execution if set to True.
+        Default: True.
+
+    custom_func : object
+        A custom function applied to the transformed sample and its target
+        before they are returned, e.g. to change the format of the return
+        value. Default: None.
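+
+    Example
+    -------
+
+    A minimal usage sketch; ``db`` is assumed to be an instance of a
+    ``bob.pad.face`` HLDI, and the data folder path is a placeholder::
+
+        from torchvision import transforms
+        from bob.learn.pytorch.datasets import DataFolderGeneric
+
+        img_transform = transforms.Compose([transforms.ToTensor()])
+
+        dataset = DataFolderGeneric(data_folder='/path/to/preprocessed',
+                                    transform=img_transform,
+                                    bob_hldi_instance=db,
+                                    groups=['train'],
+                                    protocol='grandtest')
+
+        img, label = dataset[0] # a Tensor and the class label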
+    """
+
+    def __init__(self, data_folder,
+                 transform=None,
+                 extension='.hdf5',
+                 bob_hldi_instance=None,
+                 hldi_type="pad",
+                 groups=['train', 'dev', 'eval'],
+                 protocol='grandtest',
+                 purposes=['real', 'attack'],
+                 allow_missing_files=True,
+                 custom_func=None,
+                 **kwargs):
+        """
+        Parameters
+        ----------
+
+        data_folder : str
+            A directory containing the training data.
+
+        transform : object
+            A function ``transform`` takes an input numpy.ndarray sample/image,
+            and returns a transformed version as a Tensor. Default: None.
+
+        extension : str
+            Extension of the data files. Default: ".hdf5".
+            Note: this is the only extension supported at the moment.
+
+        bob_hldi_instance : object
+            An instance of the HLDI interface. Only HLDIs of bob.pad.face
+            are currently supported.
+
+        hldi_type : str
+            String defining the type of the HLDI. Default: "pad".
+            Note: this is the only option currently supported.
+
+        groups : str or [str]
+            The groups for which the clients should be returned.
+            Usually, groups are one or more elements of ['train', 'dev', 'eval'].
+            Default: ['train', 'dev', 'eval'].
+
+        protocol : str
+            The protocol for which the clients should be retrieved.
+            Default: 'grandtest'.
+
+        purposes : str or [str]
+            The purposes for which File objects should be retrieved.
+            Usually it is either 'real' or 'attack'.
+            Default: ['real', 'attack'].
+
+        allow_missing_files : bool
+            Missing files in the ``data_folder`` will not break the
+            execution if set to True.
+            Default: True.
+
+        custom_func : object
+            A custom function applied to the transformed sample and its target
+            before they are returned. Default: None.
+        """
+
+        self.data_folder = data_folder
+        self.transform = transform
+        self.extension = extension
+        self.bob_hldi_instance = bob_hldi_instance
+        self.hldi_type = hldi_type
+        self.groups = groups
+        self.protocol = protocol
+        self.purposes = purposes
+        self.allow_missing_files = allow_missing_files
+        self.custom_func = custom_func
+
+        if bob_hldi_instance is not None:
+
+            files = bob_hldi_instance.objects(groups=self.groups,
+                                              protocol=self.protocol,
+                                              purposes=self.purposes,
+                                              **kwargs)
+
+            file_names_labels_keys = get_file_names_and_labels(files=files,
+                                        data_folder=self.data_folder,
+                                        extension=self.extension,
+                                        hldi_type=self.hldi_type)
+
+            if self.allow_missing_files: # return only existing files
+
+                file_names_labels_keys = [f for f in file_names_labels_keys if os.path.isfile(f[0])]
+
+        else:
+
+            # TODO - add behaviour similar to image folder
+            file_names_labels_keys = []
+
+        self.file_names_labels_keys = file_names_labels_keys
+
+
+    #==========================================================================
+    def __getitem__(self, index):
+        """
+        Returns a **transformed** sample/image and a target class, given an
+        index. Two types of transformations are handled, see the doc-string of
+        the class.
+
+        Parameters
+        ----------
+
+        index : int
+            An index of the sample to return.
+
+        Returns
+        -------
+
+        np_img : Tensor
+            Transformed sample.
+
+        target : int
+            Index of the class.
+        """
+
+        path, target, key = self.file_names_labels_keys[index]
+
+        with h5py.File(path, "r") as f_h5py:
+
+            img_array = np.array(f_h5py.get(key + '/array')) # of shape (C x W x H) for color, (W x H) for gray-scale
+
+        if isinstance(self.transform, transforms.Compose): # an instance of torchvision composed transformation
+
+            if len(img_array.shape) == 3: # for color or multi-channel images
+
+                img_array_tr = np.swapaxes(img_array, 1, 2)
+                img_array_tr = np.swapaxes(img_array_tr, 0, 2)
+
+                np_img = img_array_tr.copy() # np_img is a numpy.ndarray of shape HxWxC
+
+            else: # for gray-scale images
+
+                np_img = np.expand_dims(img_array, 2) # np_img is a numpy.ndarray of shape HxWx1
+
+            np_img = self.transform(np_img) # after this transformation np_img should be a tensor
+
+        else: # a custom transformation function is given
+
+            # NOTE: make sure ``img_array_transformed`` is converted to a Tensor
+            # in your custom ``transform`` function.
+            img_array_transformed = self.transform(img_array)
+
+            return img_array_transformed, target
+
+        if self.custom_func is not None: # custom function to change the return value to something else
+
+            return self.custom_func(np_img, target)
+
+        return np_img, target
+
+
+    #==========================================================================
+    def __len__(self):
+        """
+        Returns
+        -------
+
+        len : int
+            The length of the file list.
+        """
+        return len(self.file_names_labels_keys)
+
diff --git a/bob/learn/pytorch/scripts/train_generic.py b/bob/learn/pytorch/scripts/train_generic.py
new file mode 100644
index 0000000000000000000000000000000000000000..8607bcd10c50061828621c9a2b1ebfa0538a1357
--- /dev/null
+++ b/bob/learn/pytorch/scripts/train_generic.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python
+# encoding: utf-8
+
+
+""" Train a Generic Net
+
+Usage:
+  %(prog)s <configuration>
+            [--model=<string>] [--batch-size=<int>] [--num-workers=<int>] [--epochs=<int>]
+            [--save-interval=<int>] [--learning-rate=<float>] [--do-crossvalidation]
+            [--seed=<int>] [--output-dir=<path>] [--use-gpu] [--verbose ...]
+
+Arguments:
+  <configuration>  A configuration file, defining the dataset and the network
+
+Options:
+  -h, --help                            Shows this help message and exits
+      --model=<string>                  Filename of the model to load (if any). 
+      --batch-size=<int>                Batch size [default: 64]
+      --num-workers=<int>               Number of subprocesses to use for data loading [default: 0]
+      --epochs=<int>                    Number of training epochs [default: 20]
+      --save-interval=<int>             Interval between saving epochs [default: 5]
+      --learning-rate=<float>           Learning rate [default: 0.01]
+      --do-crossvalidation              Whether to perform cross validation [default: False]
+  -S, --seed=<int>                      The random seed [default: 3] 
+  -o, --output-dir=<path>               Dir to save stuff [default: training]
+  -g, --use-gpu                         Use the GPU
+  -v, --verbose                         Increase the verbosity (may appear multiple times).
+
+Note that arguments provided directly by command-line will override the ones in the configuration file.
+
+Example:
+
+  To run the training process:
+
+    $ %(prog)s config.py
+
+See '%(prog)s --help' for more information.
+
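+Configuration:
+
+  The configuration file is expected to define at least ``dataset``,
+  ``network``, ``optimizer`` and ``compute_loss``. A minimal sketch, in which
+  the network, dataset, loss and learning rate are illustrative placeholders
+  to be replaced by your own setup:
+
+    import torch
+    from bob.learn.pytorch.architectures import LightCNN9
+
+    network = LightCNN9(20)
+
+    dataset = {}
+    dataset['train'] = my_training_dataset  # e.g. a DataFolderGeneric instance
+
+    optimizer = torch.optim.Adam(network.parameters(), lr=0.0001)
+
+    criterion = torch.nn.CrossEntropyLoss()
+
+    def compute_loss(network, img, labels, device):
+      # LightCNN9 is assumed to return a (logits, features) tuple
+      out, _ = network(img.to(device))
+      return criterion(out, labels.to(device))
+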
+"""
+
+import os, sys
+import pkg_resources
+
+import torch
+import numpy
+from docopt import docopt
+
+import bob.core
+import bob.io.base
+logger = bob.core.log.setup("bob.learn.pytorch")
+
+from bob.extension.config import load
+from bob.learn.pytorch.trainers import GenericTrainer
+from bob.learn.pytorch.utils import get_parameter
+
+version = pkg_resources.require('bob.learn.pytorch')[0].version
+
+def main(user_input=None):
+  
+  # Parse the command-line arguments
+  if user_input is not None:
+      arguments = user_input
+  else:
+      arguments = sys.argv[1:]
+
+  prog = os.path.basename(sys.argv[0])
+  completions = dict(prog=prog, version=version,)
+  args = docopt(__doc__ % completions, argv=arguments, version='Train a Generic Network (%s)' % version)
+
+  # load configuration file
+  configuration = load([args['<configuration>']])
+
+  # get the pre-trained model file, if any
+  model = args['--model']
+  if hasattr(configuration, 'model'):
+    model = configuration.model
+  
+  # get various parameters, either from config file or command-line 
+  batch_size = get_parameter(args, configuration, 'batch_size', 64)
+  num_workers = get_parameter(args, configuration, 'num_workers', 0)
+  epochs = get_parameter(args, configuration, 'epochs', 20)
+  save_interval = get_parameter(args, configuration, 'save_interval', 5)
+  learning_rate = get_parameter(args, configuration, 'learning_rate', 0.01)
+  seed = get_parameter(args, configuration, 'seed', 3)
+  output_dir = get_parameter(args, configuration, 'output_dir', 'training')
+  use_gpu = get_parameter(args, configuration, 'use_gpu', False)
+  verbosity_level = get_parameter(args, configuration, 'verbose', 0)
+  do_crossvalidation = get_parameter(args, configuration, 'do_crossvalidation', False)
+
+  bob.core.log.set_verbosity_level(logger, verbosity_level)
+  bob.io.base.create_directories_safe(output_dir)
+
+  # print parameters
+  logger.debug("Model file = {}".format(model))
+  logger.debug("Batch size = {}".format(batch_size))
+  logger.debug("Num workers = {}".format(num_workers))
+  logger.debug("Epochs = {}".format(epochs))
+  logger.debug("Save interval = {}".format(save_interval))
+  logger.debug("Learning rate = {}".format(learning_rate))
+  logger.debug("Seed = {}".format(seed))
+  logger.debug("Output directory = {}".format(output_dir))
+  logger.debug("Use GPU = {}".format(use_gpu))
+  logger.debug("Perform cross validation = {}".format(do_crossvalidation))
+
+  # select the device: the GPU is used only when requested and available
+  device = torch.device("cuda:0" if (use_gpu and torch.cuda.is_available()) else "cpu")
+
+  # process the arguments / options
+  torch.manual_seed(seed)
+  if use_gpu:
+    torch.cuda.manual_seed_all(seed)
+  if torch.cuda.is_available() and not use_gpu:
+    logger.warning("You have a CUDA device, so you should probably run with --use-gpu")
+
+  logger.debug("Device used for training = {}".format(device))
+
+
+  # Which device to use is figured out at this point, no need to use `use-gpu` flag anymore
+  # get data
+  if hasattr(configuration, 'dataset'):
+
+    dataloader = {}
+
+    dataloader['train'] = torch.utils.data.DataLoader(configuration.dataset['train'], batch_size=batch_size, num_workers=num_workers, shuffle=True)
+    logger.info("There are {} training samples".format(len(configuration.dataset['train'])))
+
+    if do_crossvalidation:
+
+      dataloader['val'] = torch.utils.data.DataLoader(configuration.dataset['val'], batch_size=batch_size, num_workers=num_workers, shuffle=True)
+      logger.info("There are {} validation samples".format(len(configuration.dataset['val'])))
+
+  else:
+    logger.error("Please provide a dataset in your configuration file !")
+    sys.exit()
+
+  assert hasattr(configuration, 'optimizer'), "Please provide an optimizer in your configuration file!"
+
+  # train the network
+  if hasattr(configuration, 'network'):
+    trainer = GenericTrainer(configuration.network, configuration.optimizer, configuration.compute_loss,
+                             learning_rate=learning_rate, batch_size=batch_size, device=device,
+                             verbosity_level=verbosity_level, tf_logdir=os.path.join(output_dir, 'tf_logs'),
+                             do_crossvalidation=do_crossvalidation, save_interval=save_interval)
+    trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir, model=model)
+  else:
+    logger.error("Please provide a network in your configuration file !")
+    sys.exit()
diff --git a/bob/learn/pytorch/test/test.py b/bob/learn/pytorch/test/test.py
index 7a2de3bfd4e22532ac25ae222aa556c629a22046..16ae124935299ab518e1b943321892f44871df19 100644
--- a/bob/learn/pytorch/test/test.py
+++ b/bob/learn/pytorch/test/test.py
@@ -233,6 +233,29 @@ def test_CNNtrainer():
 
   os.remove('model_1_0.pth')
 
+def test_Generictrainer():
+
+  from ..architectures import LightCNN9
+  net = LightCNN9(20)
+
+  dataloader = {}
+  dataloader['train'] = torch.utils.data.DataLoader(DummyDataSet(), batch_size=32, shuffle=True)
+  dataloader['val'] = torch.utils.data.DataLoader(DummyDataSet(), batch_size=32, shuffle=True)
+
+  optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.1)
+
+  # the criterion and the loss computation below are one possible choice;
+  # LightCNN9 is assumed to return a (logits, features) tuple
+  criterion = torch.nn.CrossEntropyLoss()
+
+  def compute_loss(network, img, labels, device):
+    out, _ = network(img)
+    return criterion(out, labels)
+
+  from ..trainers import GenericTrainer
+  trainer = GenericTrainer(net, optimizer, compute_loss, verbosity_level=3)
+  trainer.train(dataloader, n_epochs=1, output_dir='.')
+
+  import os
+  assert os.path.isfile('model_1_0.pth')
+
+  os.remove('model_1_0.pth')
+
 class DummyDataSetMCCNN(Dataset):
   def __init__(self):
     pass
diff --git a/bob/learn/pytorch/trainers/GenericTrainer.py b/bob/learn/pytorch/trainers/GenericTrainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..51b16bc0a353c3205abf14f0af024607f7b87868
--- /dev/null
+++ b/bob/learn/pytorch/trainers/GenericTrainer.py
@@ -0,0 +1,297 @@
+#!/usr/bin/env python
+# encoding: utf-8
+
+import numpy as np
+import torch
+from .tflog import Logger
+
+import bob.core
+logger = bob.core.log.setup("bob.learn.pytorch")
+
+import time
+import os
+
+import copy
+
+class GenericTrainer(object):
+	"""
+	Class to train a generic NN; all the parameters are provided in the config file
+
+	Attributes
+	----------
+	network: :py:class:`torch.nn.Module`
+		The network to train
+	optimizer: :py:class:`torch.optim.Optimizer`
+		Optimizer object to be used. Initialized in the config file.
+	compute_loss: object
+		A function computing the loss given the network, a batch of images,
+		the corresponding labels and the device. Initialized in the config file.
+	batch_size: int
+		The size of your minibatch
+	device: str
+		Device which will be used for training the model
+	verbosity_level: int
+		The level of verbosity output to stdout
+	do_crossvalidation: bool
+		If set to `True`, performs validation in each epoch and stores the
+		best model based on validation loss.
+	save_interval: int
+		Interval (in epochs) between model saves
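+
+	Example
+	-------
+	A sketch of a ``compute_loss`` function matching the call made in
+	:py:meth:`train`; the criterion below is an illustrative choice, not a
+	requirement::
+
+		criterion = torch.nn.CrossEntropyLoss()
+
+		def compute_loss(network, img, labels, device):
+			img, labels = img.to(device), labels.to(device)
+			output = network(img)
+			return criterion(output, labels)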
+	"""
+ 
+	def __init__(self, network, optimizer, compute_loss, learning_rate=0.0001,
+			batch_size=64, device='cpu', verbosity_level=2, tf_logdir='tf_logs',
+			do_crossvalidation=False, save_interval=5):
+		""" Init function
+
+		Parameters
+		----------
+		network: :py:class:`torch.nn.Module`
+			The network to train
+		optimizer: :py:class:`torch.optim.Optimizer`
+			Optimizer object to be used. Initialized in the config file.
+		compute_loss: object
+			A function computing the loss given the network, a batch of images,
+			the corresponding labels and the device.
+		learning_rate: float
+			The learning rate; kept for reference, since the optimizer is
+			initialized in the config file
+		batch_size: int
+			The size of your minibatch
+		device: str
+			Device which will be used for training the model
+		verbosity_level: int
+			The level of verbosity output to stdout
+		tf_logdir: str
+			The directory to write TensorBoard-style logs to
+		do_crossvalidation: bool
+			If set to `True`, performs validation in each epoch and stores the best model based on validation loss.
+		save_interval: int
+			Interval (in epochs) between model saves
+		"""
+		self.network = network
+		self.batch_size = batch_size
+		self.optimizer = optimizer
+		self.compute_loss = compute_loss
+		self.device = device
+		self.learning_rate = learning_rate
+		self.save_interval = save_interval
+
+		self.do_crossvalidation = do_crossvalidation
+
+		if self.do_crossvalidation:
+			phases = ['train', 'val']
+		else:
+			phases = ['train']
+		self.phases = phases
+
+		# move the network to the device
+		self.network.to(self.device)
+
+		bob.core.log.set_verbosity_level(logger, verbosity_level)
+
+		self.tf_logger = Logger(tf_logdir)
+
+
+	def load_model(self, model_filename):
+		"""Loads an existing model
+
+		Parameters
+		----------
+		model_filename: str
+			The filename of the model to load
+
+		Returns
+		-------
+		start_epoch: int
+			The epoch to start with
+		start_iteration: int
+			The iteration to start with
+		losses: list(float)
+			The list of losses from previous training 
+		
+		"""
+		
+		cp = torch.load(model_filename)
+		self.network.load_state_dict(cp['state_dict'])
+		start_epoch = cp['epoch']
+		start_iter = cp['iteration']
+		losses = cp['loss']
+		return start_epoch, start_iter, losses
+
+
+	def save_model(self, output_dir, epoch=0, iteration=0, losses=None):
+		"""Save the trained network
+
+		Parameters
+		----------
+		output_dir: str
+			The directory to write the models to
+		epoch: int
+			the current epoch
+		iteration: int
+			the current (last) iteration
+		losses: list(float)
+			The list of losses since the beginning of training
+
+		"""
+
+		saved_filename = 'model_{}_{}.pth'.format(epoch, iteration)
+		saved_path = os.path.join(output_dir, saved_filename)
+		logger.info('Saving model to {}'.format(saved_path))
+		cp = {'epoch': epoch,
+			'iteration': iteration,
+			'loss': losses,
+			'state_dict': self.network.cpu().state_dict()
+			}
+		torch.save(cp, saved_path)
+		
+		self.network.to(self.device)
+
+
+	def train(self, dataloader, n_epochs=25, output_dir='out', model=None):
+		"""Performs the training.
+
+		Parameters
+		----------
+		dataloader: dict of :py:class:`torch.utils.data.DataLoader`
+			The dataloaders for your data, indexed by phase ('train', and
+			optionally 'val')
+		n_epochs: int
+			The number of epochs you would like to train for
+		output_dir: str
+			The directory where you would like to save models
+		model: str
+			The path to a pretrained model file to resume training from
+
+		"""
+
+		# if model exists, load it
+		if model is not None:
+			start_epoch, start_iter, losses = self.load_model(model)
+			logger.info('Starting training at epoch {}, iteration {} - last loss value is {}'.format(start_epoch, start_iter, losses[-1]))
+		else:
+			start_epoch = 0
+			start_iter = 0
+			losses = []
+			logger.info('Starting training from scratch')
+
+
+		for name, param in self.network.named_parameters():
+
+			if param.requires_grad:
+				logger.info('Layer to be adapted from grad check : {}'.format(name))
+
+		self.network.train(True)
+
+		best_model_wts = copy.deepcopy(self.network.state_dict())
+			
+		best_loss = float("inf")
+
+		# let's go
+		for epoch in range(start_epoch, n_epochs):
+
+			# in the epoch
+
+			train_loss_history=[]
+
+			val_loss_history = []
+
+			for phase in self.phases:
+
+				if phase == 'train':
+					self.network.train()  # Set model to training mode
+				else:
+					self.network.eval()   # Set model to evaluate mode
+
+
+				for i, data in enumerate(dataloader[phase], 0):
+
+					if i >= start_iter:
+
+						start = time.time()
+
+						# get data from dataset
+
+						img, labels = data
+						
+						self.optimizer.zero_grad()
+
+						with torch.set_grad_enabled(phase == 'train'):
+							
+							loss = self.compute_loss(self.network, img, labels, self.device)
+
+							if phase == 'train':
+
+								loss.backward()
+
+								self.optimizer.step()
+
+								train_loss_history.append(loss.item())
+							else:
+
+								val_loss_history.append(loss.item())
+
+
+						end = time.time()
+
+						logger.info("[{}/{}][{}/{}] => Loss = {} (time spent: {}), Phase {}".format(epoch, n_epochs, i, len(dataloader[phase]), loss.item(), (end-start),phase))
+
+						losses.append(loss.item())
+
+						
+			epoch_train_loss = np.mean(train_loss_history)
+
+			logger.info("Train Loss : {}  epoch : {}".format(epoch_train_loss, epoch))
+
+			if self.do_crossvalidation:
+
+				epoch_val_loss = np.mean(val_loss_history)
+
+				logger.info("Val Loss : {}  epoch : {}".format(epoch_val_loss, epoch))
+
+				if epoch_val_loss < best_loss:
+
+					logger.debug("New val loss : {} is better than old : {}, copying over the new weights".format(epoch_val_loss, best_loss))
+
+					best_loss = epoch_val_loss
+
+					best_model_wts = copy.deepcopy(self.network.state_dict())
+
+
+			########################################  <Logging> ###################################
+			if self.do_crossvalidation:
+
+				info = {'train_loss': epoch_train_loss, 'val_loss': epoch_val_loss}
+			else:
+
+				info = {'train_loss': epoch_train_loss}
+
+			# scalar logs
+			
+			for tag, value in info.items():
+				self.tf_logger.scalar_summary(tag, value, epoch+1)
+
+			# Log values and gradients of the parameters (histogram summary)
+
+			for tag, value in self.network.named_parameters():
+				tag = tag.replace('.', '/')
+				try:
+					self.tf_logger.histo_summary(tag, value.data.cpu().numpy(), epoch+1)
+					self.tf_logger.histo_summary(tag+'/grad', value.grad.data.cpu().numpy(), epoch+1)
+				except AttributeError:
+					# parameters which did not receive a gradient have ``grad`` set to None
+					pass
+
+			########################################  </Logging>  ###################################
+
+			logger.info("EPOCH {} DONE".format(epoch+1))
+
+			# save the last model, and the ones in the specified interval
+			if (epoch+1) == n_epochs or epoch % self.save_interval == 0:
+				self.save_model(output_dir, epoch=(epoch+1), iteration=0, losses=losses)
+
+		if self.do_crossvalidation:
+
+			# load the weights which performed best on the validation set and
+			# save them as the epoch-0 model; without cross-validation,
+			# ``best_model_wts`` would still hold the initial weights, so it
+			# must not be loaded back in that case
+			self.network.load_state_dict(best_model_wts)
+
+			self.save_model(output_dir, epoch=0, iteration=0, losses=losses)
diff --git a/bob/learn/pytorch/trainers/__init__.py b/bob/learn/pytorch/trainers/__init__.py
index dc8bfef189eb361a21df66296001d8107cf6386b..2493093896cdadb64d7ecfedaf96862c025eeb33 100644
--- a/bob/learn/pytorch/trainers/__init__.py
+++ b/bob/learn/pytorch/trainers/__init__.py
@@ -3,6 +3,8 @@ from .MCCNNTrainer import MCCNNTrainer
 from .DCGANTrainer import DCGANTrainer
 from .ConditionalGANTrainer import ConditionalGANTrainer
 from .FASNetTrainer import FASNetTrainer
+from .GenericTrainer import GenericTrainer
 
 # gets sphinx autodoc done right - don't remove it
 __all__ = [_ for _ in dir() if not _.startswith('_')]
diff --git a/setup.py b/setup.py
index 7098af0e535276634f69a689062199793a8df81e..7c23666b01393f064b0c505f70697cf5effe2438 100644
--- a/setup.py
+++ b/setup.py
@@ -71,6 +71,7 @@ setup(
       'console_scripts' : [
         'train_cnn.py = bob.learn.pytorch.scripts.train_cnn:main',
         'train_mccnn.py = bob.learn.pytorch.scripts.train_mccnn:main',
+        'train_generic.py = bob.learn.pytorch.scripts.train_generic:main',
         'train_fasnet.py = bob.learn.pytorch.scripts.train_fasnet:main',
         'train_dcgan.py = bob.learn.pytorch.scripts.train_dcgan:main',
         'train_conditionalgan.py = bob.learn.pytorch.scripts.train_conditionalgan:main',