diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index cf8fc1018189b05dba7a474b5c9a493244c5dc83..bb7fedab436e4ec68db632c05194db2bea61ae16 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -153,7 +153,6 @@ def run( data_loader, valid_loader, extra_valid_loaders, - optimizer, checkpoint_period, device, arguments, @@ -187,8 +186,6 @@ def run( an extra column with the loss of every dataset in this list is kept on the final training log. - optimizer : :py:mod:`torch.optim` - checkpoint_period : int save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do not save intermediary checkpoints @@ -227,7 +224,7 @@ def run( tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard") resource_monitor = ResourceMonitor( - interval=5.0, + interval=monitoring_interval, has_gpu=(device.type == "cuda"), main_pid=os.getpid(), logging_level=logging.ERROR, @@ -254,6 +251,7 @@ def run( accelerator="auto", devices="auto", max_epochs=max_epoch, + accumulate_grad_batches=batch_chunk_count, logger=[csv_logger, tensorboard_logger], check_val_every_n_epoch=1, callbacks=[LoggingCallback(resource_monitor), checkpoint_callback], diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py index ea82f5a7b91f272893812474d6b2a19551e71600..64f66a6a03c713521463907c340027948c0a07ad 100644 --- a/src/ptbench/scripts/train.py +++ b/src/ptbench/scripts/train.py @@ -125,12 +125,6 @@ def set_reproducible_cuda(): required=True, cls=ResourceOption, ) -@click.option( - "--optimizer", - help="A torch.optim.Optimizer that will be used to train the network", - required=True, - cls=ResourceOption, -) @click.option( "--criterion", help="A loss function to compute the CNN error for every sample " @@ -291,7 +285,6 @@ def set_reproducible_cuda(): @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) def train( model, - optimizer, output_folder, epochs, batch_size, @@ -481,7 +474,6 @@ def train( logger.info(f"Z-normalization with mean {mean} and std {std}") arguments = {} - arguments["epoch"] = 0 arguments["max_epoch"] = epochs last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt") @@ -518,14 +510,12 @@ def train( raise FileNotFoundError(f"Could not find checkpoint {resume_from}") logger.info("Training for {} epochs".format(arguments["max_epoch"])) - logger.info("Continuing from epoch {}".format(arguments["epoch"])) run( model=model, data_loader=data_loader, valid_loader=valid_loader, extra_valid_loaders=extra_valid_loaders, - optimizer=optimizer, checkpoint_period=checkpoint_period, device=device, arguments=arguments,