Skip to content
Snippets Groups Projects
Commit 5303e516 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH Committed by Anjith GEORGE
Browse files

[script] corrected the call to GenericTrainer in train_generic script: removed...

[script] corrected the call to GenericTrainer in train_generic script: removed the batch_size argument
parent f9c86339
Branches
Tags
1 merge request!32Generic trainer cleaned
...@@ -147,7 +147,8 @@ def main(user_input=None): ...@@ -147,7 +147,8 @@ def main(user_input=None):
# train the network # train the network
if hasattr(configuration, 'network'): if hasattr(configuration, 'network'):
trainer = GenericTrainer(configuration.network, configuration.optimizer,configuration.compute_loss,learning_rate=learning_rate, batch_size=batch_size, device=device, verbosity_level=verbosity_level,tf_logdir=output_dir+'/tf_logs',do_crossvalidation=do_crossvalidation, save_interval=save_interval) trainer = GenericTrainer(configuration.network, configuration.optimizer,configuration.compute_loss,learning_rate=learning_rate,
device=device, verbosity_level=verbosity_level,tf_logdir=output_dir+'/tf_logs',do_crossvalidation=do_crossvalidation, save_interval=save_interval)
trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir, model=model) trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir, model=model)
else: else:
logger.error("Please provide a network in your configuration file !") logger.error("Please provide a network in your configuration file !")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment