Skip to content
Snippets Groups Projects

Modified the data_folder to be more generic

Merged Guillaume HEUSCH requested to merge generic_loader into master
1 file
+ 12
12
Compare changes
  • Side-by-side
  • Inline
@@ -12,9 +12,8 @@ import torch.utils.data as data
@@ -12,9 +12,8 @@ import torch.utils.data as data
import os
import os
import random
import random
random.seed( a = 7 )
import PIL
random.seed( a = 7 )
import numpy as np
import numpy as np
@@ -102,7 +101,7 @@ class DataFolder(data.Dataset):
@@ -102,7 +101,7 @@ class DataFolder(data.Dataset):
A directory containing the training data.
A directory containing the training data.
transform : object
transform : object
A function/transform that takes in a PIL image, and returns a
A function/transform that takes in a numpy.ndarray image, and returns a
transformed version. E.g, ``transforms.RandomCrop``. Default: None.
transformed version. E.g, ``transforms.RandomCrop``. Default: None.
extension : str
extension : str
@@ -155,7 +154,7 @@ class DataFolder(data.Dataset):
@@ -155,7 +154,7 @@ class DataFolder(data.Dataset):
A directory containing the training data.
A directory containing the training data.
transform : object
transform : object
A function/transform that takes in a PIL image, and returns a
A custom function/transform (torchvision.transforms.Compose) that takes in a numpy.ndarray image , and returns a
transformed version. E.g, ``transforms.RandomCrop``. Default: None.
transformed version. E.g, ``transforms.RandomCrop``. Default: None.
extension : str
extension : str
@@ -238,9 +237,9 @@ class DataFolder(data.Dataset):
@@ -238,9 +237,9 @@ class DataFolder(data.Dataset):
Returns
Returns
-------
-------
pil_img : Tensor or PIL Image
np_img : Tensor image
If ``self.transform`` is defined the output is the torch.Tensor,
If ``self.transform`` is defined as a custom function, the output is the torch.Tensor,
otherwise the output is an instance of the PIL.Image.Image class.
otherwise the last transform in the transforms.Compose should be transforms.ToTensor().
target : int
target : int
Index of the class.
Index of the class.
@@ -254,20 +253,21 @@ class DataFolder(data.Dataset):
@@ -254,20 +253,21 @@ class DataFolder(data.Dataset):
if isinstance(self.transform, transforms.Compose): # if an instance of torchvision composed transformation
if isinstance(self.transform, transforms.Compose): # if an instance of torchvision composed transformation
if len(img_array.shape) == 3: # for color images
if len(img_array.shape) == 3: # for color or multi-channel images
img_array_tr = np.swapaxes(img_array, 1, 2)
img_array_tr = np.swapaxes(img_array, 1, 2)
img_array_tr = np.swapaxes(img_array_tr, 0, 2)
img_array_tr = np.swapaxes(img_array_tr, 0, 2)
pil_img = PIL.Image.fromarray( img_array_tr ) # convert to PIL from array of size (H x W x 3)
np_img =img_array_tr.copy() # np_img is numpy.ndarray of shape HxWxC
else: # for gray-scale images
else: # for gray-scale images
pil_img = PIL.Image.fromarray( img_array, 'L' ) # convert to PIL from array of size (H x W)
np_img=np.expand_dims(img_array_tr,2) # np_img is numpy.ndarray of size HxWx1
 
if self.transform is not None:
if self.transform is not None:
pil_img = self.transform(pil_img)
np_img = self.transform(np_img) # after this transformation np_img should be a tensor
else: # if custom transformation function is given
else: # if custom transformation function is given
@@ -275,7 +275,7 @@ class DataFolder(data.Dataset):
@@ -275,7 +275,7 @@ class DataFolder(data.Dataset):
return torch.Tensor(img_array_transformed).unsqueeze(0), target # convert array to Tensor, also return target
return torch.Tensor(img_array_transformed).unsqueeze(0), target # convert array to Tensor, also return target
return pil_img, target
return np_img, target
#==========================================================================
#==========================================================================
Loading