Commit 0980593a authored by Guillaume HEUSCH's avatar Guillaume HEUSCH

[test] the the CNNTrainer using LightCNN9

parent e698cd3b
Pipeline #26312 passed with stage
in 6 minutes and 27 seconds
......@@ -124,7 +124,7 @@ class DummyDataSet(Dataset):
def __len__(self):
return 100
def __getitem__(self, idx):
data = numpy.random.rand(3, 128, 128).astype("float32")
data = numpy.random.rand(1, 128, 128).astype("float32")
label = numpy.random.randint(20)
sample = {'image': torch.from_numpy(data), 'label': label}
return sample
......@@ -132,8 +132,8 @@ class DummyDataSet(Dataset):
def test_CNNtrainer():
from ..architectures import CNN8
net = CNN8(20)
from ..architectures import LightCNN9
net = LightCNN9(20)
dataloader = torch.utils.data.DataLoader(DummyDataSet(), batch_size=32, shuffle=True)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment