From c2c4286d6168fb3bd6baeb00c624ea8b2a83a8ce Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Wed, 12 Apr 2023 11:28:26 +0200
Subject: [PATCH] Remove unused optimizer argument; wire monitoring_interval and batch_chunk_count into trainer

---
 src/ptbench/engine/trainer.py |  6 ++----
 src/ptbench/scripts/train.py  | 10 ----------
 2 files changed, 2 insertions(+), 14 deletions(-)

diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index cf8fc101..bb7fedab 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -153,7 +153,6 @@ def run(
     data_loader,
     valid_loader,
     extra_valid_loaders,
-    optimizer,
     checkpoint_period,
     device,
     arguments,
@@ -187,8 +186,6 @@ def run(
         an extra column with the loss of every dataset in this list is kept on
         the final training log.
 
-    optimizer : :py:mod:`torch.optim`
-
     checkpoint_period : int
         save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
         not save intermediary checkpoints
@@ -227,7 +224,7 @@ def run(
     tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard")
 
     resource_monitor = ResourceMonitor(
-        interval=5.0,
+        interval=monitoring_interval,
         has_gpu=(device.type == "cuda"),
         main_pid=os.getpid(),
         logging_level=logging.ERROR,
@@ -254,6 +251,7 @@ def run(
             accelerator="auto",
             devices="auto",
             max_epochs=max_epoch,
+            accumulate_grad_batches=batch_chunk_count,
             logger=[csv_logger, tensorboard_logger],
             check_val_every_n_epoch=1,
             callbacks=[LoggingCallback(resource_monitor), checkpoint_callback],
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index ea82f5a7..64f66a6a 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -125,12 +125,6 @@ def set_reproducible_cuda():
     required=True,
     cls=ResourceOption,
 )
-@click.option(
-    "--optimizer",
-    help="A torch.optim.Optimizer that will be used to train the network",
-    required=True,
-    cls=ResourceOption,
-)
 @click.option(
     "--criterion",
     help="A loss function to compute the CNN error for every sample "
@@ -291,7 +285,6 @@ def set_reproducible_cuda():
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def train(
     model,
-    optimizer,
     output_folder,
     epochs,
     batch_size,
@@ -481,7 +474,6 @@ def train(
         logger.info(f"Z-normalization with mean {mean} and std {std}")
 
     arguments = {}
-    arguments["epoch"] = 0
     arguments["max_epoch"] = epochs
 
     last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
@@ -518,14 +510,12 @@ def train(
             raise FileNotFoundError(f"Could not find checkpoint {resume_from}")
 
     logger.info("Training for {} epochs".format(arguments["max_epoch"]))
-    logger.info("Continuing from epoch {}".format(arguments["epoch"]))
 
     run(
         model=model,
         data_loader=data_loader,
         valid_loader=valid_loader,
         extra_valid_loaders=extra_valid_loaders,
-        optimizer=optimizer,
         checkpoint_period=checkpoint_period,
         device=device,
         arguments=arguments,
-- 
GitLab