Modified the data_folder to be more generic
Compare changes
@@ -12,16 +12,13 @@ import torch.utils.data as data
@@ -12,16 +12,13 @@ import torch.utils.data as data
@@ -36,7 +33,7 @@ def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type
@@ -36,7 +33,7 @@ def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type
@@ -65,11 +62,11 @@ def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type
@@ -65,11 +62,11 @@ def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type
@@ -93,17 +90,34 @@ def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type
@@ -93,17 +90,34 @@ def get_file_names_and_labels(files, data_folder, extension = ".hdf5", hldi_type
@@ -155,8 +169,8 @@ class DataFolder(data.Dataset):
@@ -155,8 +169,8 @@ class DataFolder(data.Dataset):
@@ -227,20 +241,21 @@ class DataFolder(data.Dataset):
@@ -227,20 +241,21 @@ class DataFolder(data.Dataset):
@@ -254,28 +269,30 @@ class DataFolder(data.Dataset):
@@ -254,28 +269,30 @@ class DataFolder(data.Dataset):
if isinstance(self.transform, transforms.Compose): # if an instance of torchvision composed transformation
return torch.Tensor(img_array_transformed).unsqueeze(0), target # convert array to Tensor, also return target
@@ -289,6 +306,3 @@ class DataFolder(data.Dataset):
@@ -289,6 +306,3 @@ class DataFolder(data.Dataset):