Skip to content
Snippets Groups Projects
Commit f979eb7d authored by Olegs NIKISINS's avatar Olegs NIKISINS
Browse files

Fixed the map_labels in the utils, fixing the unit tests

parent 720dd3ce
Branches
Tags
1 merge request!6autoencoders pretraining using RGB faces
Pipeline #26193 passed
......@@ -26,7 +26,7 @@ class FaceCropper():
cropped = self.face_cropper(sample['image'], sample['eyes'])
sample['image'] = cropped
return sample
class RollChannels(object):
"""
......@@ -41,7 +41,7 @@ class RollChannels(object):
class ToTensor(object):
def __init__(self):
self.op = transforms.ToTensor()
def __call__(self, sample):
sample['image'] = self.op(sample['image'])
return sample
......@@ -70,14 +70,14 @@ class Resize(object):
def map_labels(raw_labels, start_index=0):
"""
Map the ID label to [0 - # of IDs]
Map the ID label to [0 - # of IDs]
"""
possible_labels = list(set(raw_labels))
labels = numpy.array(raw_labels)
for i in range(len(possible_labels)):
l = possible_labels[i]
labels[numpy.where(labels==l)[0]] = i + start_index
labels[numpy.where(labels==l)[0][0]] = i + start_index
# -----
# map back to native int, resolve the problem with dataset concatenation
......@@ -86,7 +86,7 @@ def map_labels(raw_labels, start_index=0):
labels_int = []
for i in range(len(labels)):
labels_int.append(labels[i].item())
return labels_int
......@@ -105,12 +105,12 @@ class ConcatDataset(Dataset):
The list of datasets (as torch.utils.data.Dataset)
"""
def __init__(self, datasets):
self.transform = datasets[0].transform
self.data_files = sum((d.data_files for d in datasets), [])
self.pose_labels = sum((d.pose_labels for d in datasets), [])
self.id_labels = sum((d.id_labels for d in datasets), [])
def __len__(self):
"""
return the length of the dataset (i.e. nb of examples)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment