Commit 613ecd6a authored by Guillaume HEUSCH's avatar Guillaume HEUSCH
Browse files

[script] added the number of classes in the CNNtrainer init

parent 79cecbd2
......@@ -103,14 +103,15 @@ def main(user_input=None):
# get data
if hasattr(configuration, 'dataset'):
dataloader = torch.utils.data.DataLoader(configuration.dataset, batch_size=batch_size, shuffle=True)
logger.info("There are {} training images from {} categories".format(len(configuration.dataset), numpy.max(configuration.dataset.id_labels)))
num_classes = numpy.max(configuration.dataset.id_labels)
logger.info("There are {} training images from {} categories".format(len(configuration.dataset), num_classes))
else:
logger.error("Please provide a dataset in your configuration file !")
sys.exit()
# train the network
if hasattr(configuration, 'network'):
trainer = CNNTrainer(configuration.network, batch_size=batch_size, use_gpu=use_gpu, verbosity_level=verbosity_level)
trainer = CNNTrainer(configuration.network, batch_size=batch_size, use_gpu=use_gpu, verbosity_level=verbosity_level, num_classes=num_classes)
trainer.train(dataloader, n_epochs=epochs, learning_rate=learning_rate, output_dir=output_dir, model=model)
else:
logger.error("Please provide a network in your configuration file !")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment