Skip to content
Snippets Groups Projects
Commit 0980593a authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[test] the the CNNTrainer using LightCNN9

parent e698cd3b
Branches
Tags
1 merge request!9Light cnn
Pipeline #26312 passed
......@@ -124,7 +124,7 @@ class DummyDataSet(Dataset):
def __len__(self):
return 100
def __getitem__(self, idx):
data = numpy.random.rand(3, 128, 128).astype("float32")
data = numpy.random.rand(1, 128, 128).astype("float32")
label = numpy.random.randint(20)
sample = {'image': torch.from_numpy(data), 'label': label}
return sample
......@@ -132,8 +132,8 @@ class DummyDataSet(Dataset):
def test_CNNtrainer():
from ..architectures import CNN8
net = CNN8(20)
from ..architectures import LightCNN9
net = LightCNN9(20)
dataloader = torch.utils.data.DataLoader(DummyDataSet(), batch_size=32, shuffle=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment