Skip to content
Snippets Groups Projects

WIP: Generic trainer

Closed Anjith GEORGE requested to merge generic_trainer into master
2 files
+ 28
38
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -99,13 +99,23 @@ def main(user_input=None):
@@ -99,13 +99,23 @@ def main(user_input=None):
logger.debug("Use GPU = {}".format(use_gpu))
logger.debug("Use GPU = {}".format(use_gpu))
logger.debug("Perform cross validation = {}".format(do_crossvalidation))
logger.debug("Perform cross validation = {}".format(do_crossvalidation))
 
# use the new `torch.device` interface to select the device
 
 
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
# process on the arguments / options
# process on the arguments / options
torch.manual_seed(seed)
torch.manual_seed(seed)
if use_gpu:
if use_gpu:
torch.cuda.manual_seed_all(seed)
torch.cuda.manual_seed_all(seed)
if torch.cuda.is_available() and not use_gpu:
if torch.cuda.is_available() and not use_gpu:
 
device="cpu"
logger.warn("You have a CUDA device, so you should probably run with --use-gpu")
logger.warn("You have a CUDA device, so you should probably run with --use-gpu")
 
logger.debug("Device used for training = {}".format(device))
 
 
 
# The device to use has been determined at this point, so the `--use-gpu` flag is no longer needed
# get data
# get data
if hasattr(configuration, 'dataset'):
if hasattr(configuration, 'dataset'):
@@ -131,7 +141,7 @@ def main(user_input=None):
@@ -131,7 +141,7 @@ def main(user_input=None):
# train the network
# train the network
if hasattr(configuration, 'network'):
if hasattr(configuration, 'network'):
trainer = GenericTrainer(configuration.network, configuration.optimizer,configuration.compute_loss,learning_rate=learning_rate, batch_size=batch_size, use_gpu=use_gpu, verbosity_level=verbosity_level,tf_logdir=output_dir+'/tf_logs',do_crossvalidation=do_crossvalidation)
trainer = GenericTrainer(configuration.network, configuration.optimizer,configuration.compute_loss,learning_rate=learning_rate, batch_size=batch_size, device=device, verbosity_level=verbosity_level,tf_logdir=output_dir+'/tf_logs',do_crossvalidation=do_crossvalidation)
trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir, model=model)
trainer.train(dataloader, n_epochs=epochs, output_dir=output_dir, model=model)
else:
else:
logger.error("Please provide a network in your configuration file !")
logger.error("Please provide a network in your configuration file !")
Loading