Skip to content
Snippets Groups Projects
Commit c940728d authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[datasets] made sure that id labels are native python int, since otherwise it...

[datasets] made sure that id labels are native python int, since otherwise it could cause problem in the concatenation of datasets
parent fd47618d
Branches
Tags
No related merge requests found
......@@ -41,15 +41,24 @@ def map_labels(raw_labels, start_index=0):
for i in range(len(possible_labels)):
l = possible_labels[i]
labels[numpy.where(labels==l)[0]] = i + start_index
return labels
labels[numpy.where(labels==l)[0]] = i + start_index
# -----
# map back to native int, resolve the problem with dataset concatenation
# it does: line 78 is now ok
# for some reason, it was not working when the type of id labels were numpy.int64 ...
labels_int = []
for i in range(len(labels)):
labels_int.append(labels[i].item())
return labels_int
from torch.utils.data import Dataset
import bob.io.base
import bob.io.image
class ConcatDataset(Dataset):
"""
Class to concatenate two or more datasets for DR-GAN training
......@@ -64,14 +73,8 @@ class ConcatDataset(Dataset):
self.transform = datasets[0].transform
self.data_files = sum((d.data_files for d in datasets), [])
self.pose_labels = sum((d.pose_labels for d in datasets), [])
self.id_labels = sum((d.id_labels for d in datasets), [])
# TODO: for an unknown reason, the above does not work for id_labels - Guillaume HEUSCH, 01-12-2017
self.id_labels = []
for d in datasets:
self.id_labels.append(d.id_labels)
print "type of id labels = {}".format(type(d.id_labels[0]))
print len(self.id_labels)
def __len__(self):
"""
return the length of the dataset (i.e. nb of examples)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment