Commit 51345d11 authored by Olegs NIKISINS's avatar Olegs NIKISINS
Browse files

Updated the docs of DataFolder class, simplified the __getitem__ method

parent 63609e20
Pipeline #26506 passed with stage
in 8 minutes and 5 seconds
......@@ -19,8 +19,6 @@ import numpy as np
from torchvision import transforms
import torch
import h5py
......@@ -35,7 +33,7 @@ def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type
files : [File]
A list of files objects defined in the High Level Database Interface
of the particular datbase.
of the particular database.
data_folder : str
A directory containing the training data.
......@@ -92,17 +90,34 @@ def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type
class DataFolder(data.Dataset):
"""
A generic data loader compatible with Bob High Level Database Interfaces
(HLDI). Only HLDI's of bob.pad.face are currently supported.
(HLDI). Only HLDI's of ``bob.pad.face`` are currently supported.
The basic functionality is composed of two steps: load the data from hdf5
file, and transform it using user defined transformation function.
Two types of user defined transformations are supported:
1. An instance of ``Compose`` transformation class from ``torchvision``
package.
2. A custom transformation function, which takes numpy.ndarray as input,
and returns a transformed Tensor. The dimensionality of the output tensor
must match the format expected by the network to be trained.
Note: if no special transformation is needed, the ``transform``
must at least convert an input numpy array to Tensor.
Attributes
----------
data_folder : str
A directory containing the training data.
A directory containing the training data. Note, that the training data
must be stored as a FrameContainers written to the hdf5 files. Other
formats are currently not supported.
transform : object
A function/transform that takes in a numpy.ndarray image, and returns a
transformed version. E.g, ``transforms.RandomCrop``. Default: None.
A function ``transform`` takes an input numpy.ndarray sample/image,
and returns a transformed version as a Tensor. Default: None.
extension : str
Extension of the data files. Default: ".hdf5".
......@@ -154,8 +169,8 @@ class DataFolder(data.Dataset):
A directory containing the training data.
transform : object
A custom function/transform (torchvision.transforms.Compose) that takes in a numpy.ndarray image , and returns a
transformed version. E.g, ``transforms.RandomCrop``. Default: None.
A function ``transform`` takes an input numpy.ndarray sample/image,
and returns a transformed version as a Tensor. Default: None.
extension : str
Extension of the data files. Default: ".hdf5".
......@@ -226,20 +241,21 @@ class DataFolder(data.Dataset):
#==========================================================================
def __getitem__(self, index):
"""
Returns an image, possibly transformed, and a target class given index.
Returns a **transformed** sample/image and a target class, given index.
Two types of transformations are handled, see the doc-string of the
class.
Attributes
----------
index : int.
index : int
An index of the sample to return.
Returns
-------
np_img : Tensor image
If ``self.transform`` is defined as a custom function, the output is the torch.Tensor,
otherwise the last transform in the transforms.Compose should be transforms.ToTensor().
np_img : Tensor
Transformed sample.
target : int
Index of the class.
......@@ -273,7 +289,8 @@ class DataFolder(data.Dataset):
img_array_transformed = self.transform(img_array)
return torch.Tensor(img_array_transformed).unsqueeze(0), target # convert array to Tensor, also return target
return img_array_transformed, target
# NOTE: make sure ``img_array_transformed`` converted to Tensor in your custom ``transform`` function.
return np_img, target
......@@ -289,6 +306,3 @@ class DataFolder(data.Dataset):
"""
return len(self.file_names_labels_keys)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment