Commit dd580d02 authored by Anjith GEORGE

cleanup

parent e60850f3
Pipeline #29857 passed with stage
in 65 minutes and 43 seconds
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
import h5py
from torchvision import transforms
import numpy as np
"""
@author: Olegs Nikisins
"""
#==============================================================================
# ==============================================================================
# Import what is needed here:
import torch.utils.data as data
......@@ -13,17 +16,11 @@ import os
import random
random.seed( a = 7 )
random.seed(a=7)
import numpy as np
from torchvision import transforms
import h5py
#==============================================================================
def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type = "pad"):
# ==============================================================================
def get_file_names_and_labels(files, data_folder, extension=".hdf5", hldi_type="pad"):
"""
Get absolute names of the corresponding file objects and their class labels,
as well as keys defining name of the frame to load the data from.
......@@ -70,14 +67,14 @@ def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type
file_name = os.path.join(data_folder, f.path + extension)
if os.path.isfile(file_name): # if file is available:
if os.path.isfile(file_name): # if file is available:
with h5py.File(file_name, "r") as f_h5py:
file_keys = list(f_h5py.keys())
#removes the 'FrameIndexes' key
file_keys=[f for f in file_keys if f!='FrameIndexes' ]
# removes the 'FrameIndexes' key
file_keys = [f for f in file_keys if f != 'FrameIndexes']
# elements of tuples in the below list are as follows:
# a filename a key is extracted from,
......@@ -89,7 +86,7 @@ def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type
return file_names_labels_keys
#==============================================================================
# ==============================================================================
class DataFolderGeneric(data.Dataset):
"""
A generic data loader compatible with Bob High Level Database Interfaces
......@@ -155,14 +152,14 @@ class DataFolderGeneric(data.Dataset):
"""
def __init__(self, data_folder,
transform = None,
extension = '.hdf5',
bob_hldi_instance = None,
hldi_type = "pad",
groups = ['train', 'dev', 'eval'],
protocol = 'grandtest',
transform=None,
extension='.hdf5',
bob_hldi_instance=None,
hldi_type="pad",
groups=['train', 'dev', 'eval'],
protocol='grandtest',
purposes=['real', 'attack'],
allow_missing_files = True,custom_func=None,
allow_missing_files=True, custom_func=None,
**kwargs):
"""
Attributes
......@@ -220,19 +217,20 @@ class DataFolderGeneric(data.Dataset):
if bob_hldi_instance is not None:
files = bob_hldi_instance.objects(groups = self.groups,
protocol = self.protocol,
purposes = self.purposes,
files = bob_hldi_instance.objects(groups=self.groups,
protocol=self.protocol,
purposes=self.purposes,
**kwargs)
file_names_labels_keys = get_file_names_and_labels(files = files,
data_folder = self.data_folder,
extension = self.extension,
hldi_type = self.hldi_type)
file_names_labels_keys = get_file_names_and_labels(files=files,
data_folder=self.data_folder,
extension=self.extension,
hldi_type=self.hldi_type)
if self.allow_missing_files: # return only existing files
if self.allow_missing_files: # return only existing files
file_names_labels_keys = [f for f in file_names_labels_keys if os.path.isfile(f[0])]
file_names_labels_keys = [
f for f in file_names_labels_keys if os.path.isfile(f[0])]
else:
......@@ -241,8 +239,8 @@ class DataFolderGeneric(data.Dataset):
self.file_names_labels_keys = file_names_labels_keys
# ==========================================================================
#==========================================================================
def __getitem__(self, index):
"""
Returns a **transformed** sample/image and a target class, given index.
......@@ -269,41 +267,44 @@ class DataFolderGeneric(data.Dataset):
with h5py.File(path, "r") as f_h5py:
img_array = np.array(f_h5py.get(key+'/array')) # The size now is (3 x W x H)
# The size now is (3 x W x H)
img_array = np.array(f_h5py.get(key+'/array'))
if isinstance(self.transform, transforms.Compose): # if an instance of torchvision composed transformation
# if an instance of torchvision composed transformation
if isinstance(self.transform, transforms.Compose):
if len(img_array.shape) == 3: # for color or multi-channel images
if len(img_array.shape) == 3: # for color or multi-channel images
img_array_tr = np.swapaxes(img_array, 1, 2)
img_array_tr = np.swapaxes(img_array_tr, 0, 2)
np_img =img_array_tr.copy() # np_img is numpy.ndarray of shape HxWxC
np_img = img_array_tr.copy() # np_img is numpy.ndarray of shape HxWxC
else: # for gray-scale images
np_img=np.expand_dims(img_array_tr,2) # np_img is numpy.ndarray of size HxWx1
else: # for gray-scale images
# np_img is numpy.ndarray of size HxWx1
np_img = np.expand_dims(img_array_tr, 2)
if self.transform is not None:
np_img = self.transform(np_img) # after this transformation np_img should be a tensor
# after this transformation np_img should be a tensor
np_img = self.transform(np_img)
else: # if custom transformation function is given
else: # if custom transformation function is given
img_array_transformed = self.transform(img_array)
return img_array_transformed, target
# NOTE: make sure ``img_array_transformed`` is converted to a Tensor in your custom ``transform`` function.
if self.custom_func is not None: # custom function to change the return to something else
if self.custom_func is not None: # custom function to change the return to something else
return self.custom_func(np_img,target)
return self.custom_func(np_img, target)
return np_img, target
# ==========================================================================
#==========================================================================
def __len__(self):
"""
Returns
......@@ -313,4 +314,3 @@ class DataFolderGeneric(data.Dataset):
The length of the file list.
"""
return len(self.file_names_labels_keys)
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment