Skip to content
Snippets Groups Projects
Commit 889ced71 authored by Anjith GEORGE's avatar Anjith GEORGE Committed by Olegs NIKISINS
Browse files

Modified the data_folder to be more generic

parent 85834fef
No related branches found
No related tags found
1 merge request!12Modified the data_folder to be more generic
......@@ -12,9 +12,8 @@ import torch.utils.data as data
import os
import random
random.seed( a = 7 )
import PIL
random.seed( a = 7 )
import numpy as np
......@@ -102,7 +101,7 @@ class DataFolder(data.Dataset):
A directory containing the training data.
transform : object
A function/transform that takes in a PIL image, and returns a
A function/transform that takes in a numpy.ndarray image, and returns a
transformed version. E.g, ``transforms.RandomCrop``. Default: None.
extension : str
......@@ -155,7 +154,7 @@ class DataFolder(data.Dataset):
A directory containing the training data.
transform : object
A function/transform that takes in a PIL image, and returns a
A custom function/transform (torchvision.transforms.Compose) that takes in a numpy.ndarray image , and returns a
transformed version. E.g, ``transforms.RandomCrop``. Default: None.
extension : str
......@@ -238,9 +237,9 @@ class DataFolder(data.Dataset):
Returns
-------
pil_img : Tensor or PIL Image
If ``self.transform`` is defined the output is the torch.Tensor,
otherwise the output is an instance of the PIL.Image.Image class.
np_img : Tensor image
If ``self.transform`` is defined as a custom function, the output is the torch.Tensor,
otherwise the last transform in the transforms.Compose should be transforms.ToTensor().
target : int
Index of the class.
......@@ -254,20 +253,21 @@ class DataFolder(data.Dataset):
if isinstance(self.transform, transforms.Compose): # if an instance of torchvision composed transformation
if len(img_array.shape) == 3: # for color images
if len(img_array.shape) == 3: # for color or multi-channel images
img_array_tr = np.swapaxes(img_array, 1, 2)
img_array_tr = np.swapaxes(img_array_tr, 0, 2)
pil_img = PIL.Image.fromarray( img_array_tr ) # convert to PIL from array of size (H x W x 3)
np_img =img_array_tr.copy() # np_img is numpy.ndarray of shape HxWxC
else: # for gray-scale images
pil_img = PIL.Image.fromarray( img_array, 'L' ) # convert to PIL from array of size (H x W)
np_img=np.expand_dims(img_array_tr,2) # np_img is numpy.ndarray of size HxWx1
if self.transform is not None:
pil_img = self.transform(pil_img)
np_img = self.transform(np_img) # after this transformation np_img should be a tensor
else: # if custom transformation function is given
......@@ -275,7 +275,7 @@ class DataFolder(data.Dataset):
return torch.Tensor(img_array_transformed).unsqueeze(0), target # convert array to Tensor, also return target
return pil_img, target
return np_img, target
#==========================================================================
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment