Skip to content
Snippets Groups Projects

Modified the data_folder to be more generic

Merged Guillaume HEUSCH requested to merge generic_loader into master
1 file
+ 41
27
Compare changes
  • Side-by-side
  • Inline
@@ -12,16 +12,13 @@ import torch.utils.data as data
@@ -12,16 +12,13 @@ import torch.utils.data as data
import os
import os
import random
import random
random.seed( a = 7 )
import PIL
random.seed( a = 7 )
import numpy as np
import numpy as np
from torchvision import transforms
from torchvision import transforms
import torch
import h5py
import h5py
@@ -36,7 +33,7 @@ def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type
@@ -36,7 +33,7 @@ def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type
files : [File]
files : [File]
A list of files objects defined in the High Level Database Interface
A list of files objects defined in the High Level Database Interface
of the particular datbase.
of the particular database.
data_folder : str
data_folder : str
A directory containing the training data.
A directory containing the training data.
@@ -65,11 +62,11 @@ def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type
@@ -65,11 +62,11 @@ def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type
if f.attack_type is None:
if f.attack_type is None:
label = 0
label = 1
else:
else:
label = 1
label = 0
file_name = os.path.join(data_folder, f.path + extension)
file_name = os.path.join(data_folder, f.path + extension)
@@ -93,17 +90,34 @@ def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type
@@ -93,17 +90,34 @@ def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type
class DataFolder(data.Dataset):
class DataFolder(data.Dataset):
"""
"""
A generic data loader compatible with Bob High Level Database Interfaces
A generic data loader compatible with Bob High Level Database Interfaces
(HLDI). Only HLDI's of bob.pad.face are currently supported.
(HLDI). Only HLDI's of ``bob.pad.face`` are currently supported.
 
 
The basic functionality is composed of two steps: load the data from hdf5
 
file, and transform it using user defined transformation function.
 
 
Two types of user defined transformations are supported:
 
 
1. An instance of ``Compose`` transformation class from ``torchvision``
 
package.
 
 
2. A custom transformation function, which takes numpy.ndarray as input,
 
and returns a transformed Tensor. The dimensionality of the output tensor
 
must match the format expected by the network to be trained.
 
 
Note: if no special transformation is needed, the ``transform``
 
must at least convert an input numpy array to Tensor.
Attributes
Attributes
----------
----------
data_folder : str
data_folder : str
A directory containing the training data.
A directory containing the training data. Note, that the training data
 
must be stored as a FrameContainers written to the hdf5 files. Other
 
formats are currently not supported.
transform : object
transform : object
A function/transform that takes in a PIL image, and returns a
A function ``transform`` takes an input numpy.ndarray sample/image,
transformed version. E.g, ``transforms.RandomCrop``. Default: None.
and returns a transformed version as a Tensor. Default: None.
extension : str
extension : str
Extension of the data files. Default: ".hdf5".
Extension of the data files. Default: ".hdf5".
@@ -155,8 +169,8 @@ class DataFolder(data.Dataset):
@@ -155,8 +169,8 @@ class DataFolder(data.Dataset):
A directory containing the training data.
A directory containing the training data.
transform : object
transform : object
A function/transform that takes in a PIL image, and returns a
A function ``transform`` takes an input numpy.ndarray sample/image,
transformed version. E.g, ``transforms.RandomCrop``. Default: None.
and returns a transformed version as a Tensor. Default: None.
extension : str
extension : str
Extension of the data files. Default: ".hdf5".
Extension of the data files. Default: ".hdf5".
@@ -227,20 +241,21 @@ class DataFolder(data.Dataset):
@@ -227,20 +241,21 @@ class DataFolder(data.Dataset):
#==========================================================================
#==========================================================================
def __getitem__(self, index):
def __getitem__(self, index):
"""
"""
Returns an image, possibly transformed, and a target class given index.
Returns a **transformed** sample/image and a target class, given index.
 
Two types of transformations are handled, see the doc-string of the
 
class.
Attributes
Attributes
----------
----------
index : int.
index : int
An index of the sample to return.
An index of the sample to return.
Returns
Returns
-------
-------
pil_img : Tensor or PIL Image
np_img : Tensor
If ``self.transform`` is defined the output is the torch.Tensor,
Transformed sample.
otherwise the output is an instance of the PIL.Image.Image class.
target : int
target : int
Index of the class.
Index of the class.
@@ -254,28 +269,30 @@ class DataFolder(data.Dataset):
@@ -254,28 +269,30 @@ class DataFolder(data.Dataset):
if isinstance(self.transform, transforms.Compose): # if an instance of torchvision composed transformation
if isinstance(self.transform, transforms.Compose): # if an instance of torchvision composed transformation
if len(img_array.shape) == 3: # for color images
if len(img_array.shape) == 3: # for color or multi-channel images
img_array_tr = np.swapaxes(img_array, 1, 2)
img_array_tr = np.swapaxes(img_array, 1, 2)
img_array_tr = np.swapaxes(img_array_tr, 0, 2)
img_array_tr = np.swapaxes(img_array_tr, 0, 2)
pil_img = PIL.Image.fromarray( img_array_tr ) # convert to PIL from array of size (H x W x 3)
np_img =img_array_tr.copy() # np_img is numpy.ndarray of shape HxWxC
else: # for gray-scale images
else: # for gray-scale images
pil_img = PIL.Image.fromarray( img_array, 'L' ) # convert to PIL from array of size (H x W)
np_img=np.expand_dims(img_array_tr,2) # np_img is numpy.ndarray of size HxWx1
 
if self.transform is not None:
if self.transform is not None:
pil_img = self.transform(pil_img)
np_img = self.transform(np_img) # after this transformation np_img should be a tensor
else: # if custom transformation function is given
else: # if custom transformation function is given
img_array_transformed = self.transform(img_array)
img_array_transformed = self.transform(img_array)
return torch.Tensor(img_array_transformed).unsqueeze(0), target # convert array to Tensor, also return target
return img_array_transformed, target
 
# NOTE: make sure ``img_array_transformed`` converted to Tensor in your custom ``transform`` function.
return pil_img, target
return np_img, target
#==========================================================================
#==========================================================================
@@ -289,6 +306,3 @@ class DataFolder(data.Dataset):
@@ -289,6 +306,3 @@ class DataFolder(data.Dataset):
"""
"""
return len(self.file_names_labels_keys)
return len(self.file_names_labels_keys)
Loading