From c2c4286d6168fb3bd6baeb00c624ea8b2a83a8ce Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 12 Apr 2023 11:28:26 +0200 Subject: [PATCH] Removed unused arguments --- src/ptbench/engine/trainer.py | 6 ++---- src/ptbench/scripts/train.py | 10 ---------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index cf8fc101..bb7fedab 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -153,7 +153,6 @@ def run( data_loader, valid_loader, extra_valid_loaders, - optimizer, checkpoint_period, device, arguments, @@ -187,8 +186,6 @@ def run( an extra column with the loss of every dataset in this list is kept on the final training log. - optimizer : :py:mod:`torch.optim` - checkpoint_period : int save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do not save intermediary checkpoints @@ -227,7 +224,7 @@ def run( tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard") resource_monitor = ResourceMonitor( - interval=5.0, + interval=monitoring_interval, has_gpu=(device.type == "cuda"), main_pid=os.getpid(), logging_level=logging.ERROR, @@ -254,6 +251,7 @@ def run( accelerator="auto", devices="auto", max_epochs=max_epoch, + accumulate_grad_batches=batch_chunk_count, logger=[csv_logger, tensorboard_logger], check_val_every_n_epoch=1, callbacks=[LoggingCallback(resource_monitor), checkpoint_callback], diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index ea82f5a7..64f66a6a 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -125,12 +125,6 @@ def set_reproducible_cuda(): required=True, cls=ResourceOption, ) -@click.option( - "--optimizer", - help="A torch.optim.Optimizer that will be used to train the network", - required=True, - cls=ResourceOption, -) @click.option( "--criterion", help="A loss function to compute the CNN error for every sample " @@ -291,7 +285,6 @@ def set_reproducible_cuda(): @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) def train( model, - optimizer, output_folder, epochs, batch_size, @@ -481,7 +474,6 @@ def train( logger.info(f"Z-normalization with mean {mean} and std {std}") arguments = {} - arguments["epoch"] = 0 arguments["max_epoch"] = epochs last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt") @@ -518,14 +510,12 @@ def train( raise FileNotFoundError(f"Could not find checkpoint {resume_from}") logger.info("Training for {} epochs".format(arguments["max_epoch"])) - logger.info("Continuing from epoch {}".format(arguments["epoch"])) run( model=model, data_loader=data_loader, valid_loader=valid_loader, extra_valid_loaders=extra_valid_loaders, - optimizer=optimizer, checkpoint_period=checkpoint_period, device=device, arguments=arguments, -- GitLab