Skip to content
Snippets Groups Projects
Commit c2c4286d authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Removed unused arguments

parent e1215e99
No related branches found
No related tags found
1 merge request: !4 "Moved code to lightning"
......@@ -153,7 +153,6 @@ def run(
data_loader,
valid_loader,
extra_valid_loaders,
optimizer,
checkpoint_period,
device,
arguments,
......@@ -187,8 +186,6 @@ def run(
an extra column with the loss of every dataset in this list is kept on
the final training log.
optimizer : :py:mod:`torch.optim`
checkpoint_period : int
save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
not save intermediary checkpoints
......@@ -227,7 +224,7 @@ def run(
tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard")
resource_monitor = ResourceMonitor(
interval=5.0,
interval=monitoring_interval,
has_gpu=(device.type == "cuda"),
main_pid=os.getpid(),
logging_level=logging.ERROR,
......@@ -254,6 +251,7 @@ def run(
accelerator="auto",
devices="auto",
max_epochs=max_epoch,
accumulate_grad_batches=batch_chunk_count,
logger=[csv_logger, tensorboard_logger],
check_val_every_n_epoch=1,
callbacks=[LoggingCallback(resource_monitor), checkpoint_callback],
......
......@@ -125,12 +125,6 @@ def set_reproducible_cuda():
required=True,
cls=ResourceOption,
)
@click.option(
"--optimizer",
help="A torch.optim.Optimizer that will be used to train the network",
required=True,
cls=ResourceOption,
)
@click.option(
"--criterion",
help="A loss function to compute the CNN error for every sample "
......@@ -291,7 +285,6 @@ def set_reproducible_cuda():
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def train(
model,
optimizer,
output_folder,
epochs,
batch_size,
......@@ -481,7 +474,6 @@ def train(
logger.info(f"Z-normalization with mean {mean} and std {std}")
arguments = {}
arguments["epoch"] = 0
arguments["max_epoch"] = epochs
last_checkpoint_path = os.path.join(output_folder, "model_final_epoch.ckpt")
......@@ -518,14 +510,12 @@ def train(
raise FileNotFoundError(f"Could not find checkpoint {resume_from}")
logger.info("Training for {} epochs".format(arguments["max_epoch"]))
logger.info("Continuing from epoch {}".format(arguments["epoch"]))
run(
model=model,
data_loader=data_loader,
valid_loader=valid_loader,
extra_valid_loaders=extra_valid_loaders,
optimizer=optimizer,
checkpoint_period=checkpoint_period,
device=device,
arguments=arguments,
......
0% — Loading, or try again.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment