Skip to content
Snippets Groups Projects
Commit c2c4286d authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Removed unused arguments

parent e1215e99
No related branches found
No related tags found
1 merge request!4Moved code to lightning
...@@ -153,7 +153,6 @@ def run( ...@@ -153,7 +153,6 @@ def run(
data_loader, data_loader,
valid_loader, valid_loader,
extra_valid_loaders, extra_valid_loaders,
optimizer,
checkpoint_period, checkpoint_period,
device, device,
arguments, arguments,
...@@ -187,8 +186,6 @@ def run( ...@@ -187,8 +186,6 @@ def run(
an extra column with the loss of every dataset in this list is kept on an extra column with the loss of every dataset in this list is kept on
the final training log. the final training log.
optimizer : :py:mod:`torch.optim`
checkpoint_period : int checkpoint_period : int
save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
not save intermediary checkpoints not save intermediary checkpoints
...@@ -227,7 +224,7 @@ def run( ...@@ -227,7 +224,7 @@ def run(
tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard") tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard")
resource_monitor = ResourceMonitor( resource_monitor = ResourceMonitor(
interval=5.0, interval=monitoring_interval,
has_gpu=(device.type == "cuda"), has_gpu=(device.type == "cuda"),
main_pid=os.getpid(), main_pid=os.getpid(),
logging_level=logging.ERROR, logging_level=logging.ERROR,
...@@ -254,6 +251,7 @@ def run( ...@@ -254,6 +251,7 @@ def run(
accelerator="auto", accelerator="auto",
devices="auto", devices="auto",
max_epochs=max_epoch, max_epochs=max_epoch,
accumulate_grad_batches=batch_chunk_count,
logger=[csv_logger, tensorboard_logger], logger=[csv_logger, tensorboard_logger],
check_val_every_n_epoch=1, check_val_every_n_epoch=1,
callbacks=[LoggingCallback(resource_monitor), checkpoint_callback], callbacks=[LoggingCallback(resource_monitor), checkpoint_callback],
......
...@@ -125,12 +125,6 @@ def set_reproducible_cuda(): ...@@ -125,12 +125,6 @@ def set_reproducible_cuda():
required=True, required=True,
cls=ResourceOption, cls=ResourceOption,
) )
@click.option(
"--optimizer",
help="A torch.optim.Optimizer that will be used to train the network",
required=True,
cls=ResourceOption,
)
@click.option( @click.option(
"--criterion", "--criterion",
help="A loss function to compute the CNN error for every sample " help="A loss function to compute the CNN error for every sample "
...@@ -291,7 +285,6 @@ def set_reproducible_cuda(): ...@@ -291,7 +285,6 @@ def set_reproducible_cuda():
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def train( def train(
model, model,
optimizer,
output_folder, output_folder,
epochs, epochs,
batch_size, batch_size,
...@@ -481,7 +474,6 @@ def train( ...@@ -481,7 +474,6 @@ def train(
logger.info(f"Z-normalization with mean {mean} and std {std}") logger.info(f"Z-normalization with mean {mean} and std {std}")
arguments = {} arguments = {}
arguments["epoch"] = 0
arguments["max_epoch"] = epochs arguments["max_epoch"] = epochs
last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt") last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
...@@ -518,14 +510,12 @@ def train( ...@@ -518,14 +510,12 @@ def train(
raise FileNotFoundError(f"Could not find checkpoint {resume_from}") raise FileNotFoundError(f"Could not find checkpoint {resume_from}")
logger.info("Training for {} epochs".format(arguments["max_epoch"])) logger.info("Training for {} epochs".format(arguments["max_epoch"]))
logger.info("Continuing from epoch {}".format(arguments["epoch"]))
run( run(
model=model, model=model,
data_loader=data_loader, data_loader=data_loader,
valid_loader=valid_loader, valid_loader=valid_loader,
extra_valid_loaders=extra_valid_loaders, extra_valid_loaders=extra_valid_loaders,
optimizer=optimizer,
checkpoint_period=checkpoint_period, checkpoint_period=checkpoint_period,
device=device, device=device,
arguments=arguments, arguments=arguments,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment