Commit f979eb7d authored by Olegs NIKISINS's avatar Olegs NIKISINS
Browse files

Fixed the map_labels in the utils, fixing the unit tests

parent 720dd3ce
Pipeline #26193 passed with stage
in 7 minutes and 59 seconds
...@@ -26,7 +26,7 @@ class FaceCropper(): ...@@ -26,7 +26,7 @@ class FaceCropper():
cropped = self.face_cropper(sample['image'], sample['eyes']) cropped = self.face_cropper(sample['image'], sample['eyes'])
sample['image'] = cropped sample['image'] = cropped
return sample return sample
class RollChannels(object): class RollChannels(object):
""" """
...@@ -41,7 +41,7 @@ class RollChannels(object): ...@@ -41,7 +41,7 @@ class RollChannels(object):
class ToTensor(object): class ToTensor(object):
def __init__(self): def __init__(self):
self.op = transforms.ToTensor() self.op = transforms.ToTensor()
def __call__(self, sample): def __call__(self, sample):
sample['image'] = self.op(sample['image']) sample['image'] = self.op(sample['image'])
return sample return sample
...@@ -70,14 +70,14 @@ class Resize(object): ...@@ -70,14 +70,14 @@ class Resize(object):
def map_labels(raw_labels, start_index=0): def map_labels(raw_labels, start_index=0):
""" """
Map the ID label to [0 - # of IDs] Map the ID label to [0 - # of IDs]
""" """
possible_labels = list(set(raw_labels)) possible_labels = list(set(raw_labels))
labels = numpy.array(raw_labels) labels = numpy.array(raw_labels)
for i in range(len(possible_labels)): for i in range(len(possible_labels)):
l = possible_labels[i] l = possible_labels[i]
labels[numpy.where(labels==l)[0]] = i + start_index labels[numpy.where(labels==l)[0][0]] = i + start_index
# ----- # -----
# map back to native int, resolve the problem with dataset concatenation # map back to native int, resolve the problem with dataset concatenation
...@@ -86,7 +86,7 @@ def map_labels(raw_labels, start_index=0): ...@@ -86,7 +86,7 @@ def map_labels(raw_labels, start_index=0):
labels_int = [] labels_int = []
for i in range(len(labels)): for i in range(len(labels)):
labels_int.append(labels[i].item()) labels_int.append(labels[i].item())
return labels_int return labels_int
...@@ -105,12 +105,12 @@ class ConcatDataset(Dataset): ...@@ -105,12 +105,12 @@ class ConcatDataset(Dataset):
The list of datasets (as torch.utils.data.Dataset) The list of datasets (as torch.utils.data.Dataset)
""" """
def __init__(self, datasets): def __init__(self, datasets):
self.transform = datasets[0].transform self.transform = datasets[0].transform
self.data_files = sum((d.data_files for d in datasets), []) self.data_files = sum((d.data_files for d in datasets), [])
self.pose_labels = sum((d.pose_labels for d in datasets), []) self.pose_labels = sum((d.pose_labels for d in datasets), [])
self.id_labels = sum((d.id_labels for d in datasets), []) self.id_labels = sum((d.id_labels for d in datasets), [])
def __len__(self): def __len__(self):
""" """
return the length of the dataset (i.e. nb of examples) return the length of the dataset (i.e. nb of examples)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment