Skip to content
Snippets Groups Projects
Commit 46fc1f92 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[utils] simplify the map_labels function

parent 77b652ca
Branches
Tags
1 merge request!9Light cnn
Pipeline #26377 passed
...@@ -55,9 +55,10 @@ class CasiaWebFaceDataset(Dataset): ...@@ -55,9 +55,10 @@ class CasiaWebFaceDataset(Dataset):
subject = int(path[-1]) subject = int(path[-1])
self.data_files.append(os.path.join(root, name)) self.data_files.append(os.path.join(root, name))
id_labels.append(subject) id_labels.append(subject)
self.id_labels = map_labels(id_labels, start_index) self.id_labels = map_labels(id_labels, start_index)
def __len__(self): def __len__(self):
"""Returns the length of the dataset (i.e. nb of examples) """Returns the length of the dataset (i.e. nb of examples)
......
...@@ -97,12 +97,12 @@ def map_labels(raw_labels, start_index=0): ...@@ -97,12 +97,12 @@ def map_labels(raw_labels, start_index=0):
# map back to native int, resolve the problem with dataset concatenation # map back to native int, resolve the problem with dataset concatenation
# it does: line 78 is now ok # it does: line 78 is now ok
# for some reason, it was not working when the type of id labels were numpy.int64 ... # for some reason, it was not working when the type of id labels were numpy.int64 ...
labels_int = [] #labels_int = []
for i in range(len(labels)): #for i in range(len(labels)):
labels_int.append(labels[i].item()) # labels_int.append(labels[i].item())
return labels_int
#return labels_int
return labels
from torch.utils.data import Dataset from torch.utils.data import Dataset
import bob.io.base import bob.io.base
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment