Commit 800081d1 authored by Daniel CARRON, committed by André Anjos

[train] Split common training operations into classes

parent 252156bb
Merge request: !46 "Create common library"
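
In short, the monolithic train() helper in mednet.libs.common.scripts.train is split into smaller helpers (setup_datamodule(), get_checkpoint_file(), load_checkpoint() and save_json_data()) that each task-specific train command now composes itself before handing control to the trainer's run(). Three files appear to change below: a task-specific train command, the shared mednet.libs.common.scripts.train module, and the classification train command (whose entry-point group and docstring are also corrected). The following is a minimal sketch of the intended call sequence, modelled on the classification command in the last diff; train_sketch is a hypothetical name, the click wiring is omitted, and the parameters stand for the values collected by the CLI options:

# Sketch only: mirrors the refactored train command shown in the diffs below.
def train_sketch(model, datamodule, output_folder, epochs, batch_size,
                 batch_chunk_count, drop_incomplete_batch, validation_period,
                 device, cache_samples, seed, parallel, monitoring_interval,
                 balance_classes):
    from lightning.pytorch import seed_everything

    from mednet.libs.common.engine.device import DeviceManager
    from mednet.libs.common.engine.trainer import run
    from mednet.libs.common.scripts.train import (
        get_checkpoint_file,
        load_checkpoint,
        save_json_data,
        setup_datamodule,
    )

    seed_everything(seed)                    # reproducible runs
    device_manager = DeviceManager(device)   # e.g. "cpu" or "cuda:0"

    # apply the user-configurable options and prepare the datamodule for "fit"
    setup_datamodule(datamodule, model, batch_size, batch_chunk_count,
                     drop_incomplete_batch, cache_samples, parallel)

    # pick up the latest usable checkpoint, if the output folder holds one
    checkpoint_file = get_checkpoint_file(output_folder)
    load_checkpoint(checkpoint_file, datamodule, model)

    # record hyperparameters and environment metadata next to the results
    save_json_data(datamodule, model, output_folder, device_manager, epochs,
                   batch_size, batch_chunk_count, drop_incomplete_batch,
                   validation_period, cache_samples, seed, parallel,
                   monitoring_interval, balance_classes)

    # hand everything to the common trainer
    run(model=model, datamodule=datamodule, validation_period=validation_period,
        device_manager=device_manager, max_epochs=epochs,
        output_folder=output_folder, monitoring_interval=monitoring_interval,
        batch_chunk_count=batch_chunk_count, checkpoint=checkpoint_file)
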
@@ -2,8 +2,13 @@ import click
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 from mednet.libs.common.scripts.click import ConfigCommand
-from mednet.libs.common.scripts.train import reusable_options
-from mednet.libs.common.scripts.train import train as train_script
+from mednet.libs.common.scripts.train import (
+    get_checkpoint_file,
+    load_checkpoint,
+    reusable_options,
+    save_json_data,
+    setup_datamodule,
+)

 logger = setup("mednet", format="%(levelname)s: %(message)s")
@@ -48,20 +53,68 @@ def train(
     extension that are used in subsequent tasks or from which training
     can be resumed.
     """
-    train_script(
+    from lightning.pytorch import seed_everything
+
+    from mednet.libs.common.engine.device import DeviceManager
+    from mednet.libs.common.engine.trainer import run
+
+    seed_everything(seed)
+
+    device_manager = DeviceManager(device)
+
+    # reset datamodule with user configurable options
+    setup_datamodule(
+        datamodule,
+        model,
+        batch_size,
+        batch_chunk_count,
+        drop_incomplete_batch,
+        cache_samples,
+        parallel,
+    )
+
+    # If asked, rebalances the loss criterion based on the relative proportion
+    # of class examples available in the training set. Also affects the
+    # validation loss if a validation set is available on the DataModule.
+    if balance_classes:
+        logger.info("Applying DataModule train sampler balancing...")
+        datamodule.balance_sampler_by_class = True
+        # logger.info("Applying train/valid loss balancing...")
+        # model.balance_losses_by_class(datamodule)
+    else:
+        logger.info(
+            "Skipping sample class/dataset ownership balancing on user request",
+        )
+
+    checkpoint_file = get_checkpoint_file(output_folder)
+    load_checkpoint(checkpoint_file, datamodule, model)
+
+    logger.info(f"Training for at most {epochs} epochs.")
+
+    # stores all information we can think of, to reproduce this later
+    save_json_data(
+        datamodule,
         model,
         output_folder,
+        device_manager,
         epochs,
         batch_size,
         accumulate_grad_batches,
         drop_incomplete_batch,
-        datamodule,
         validation_period,
-        device,
         cache_samples,
         seed,
         parallel,
         monitoring_interval,
         balance_classes,
-        **_,
+    )
+
+    run(
+        model=model,
+        datamodule=datamodule,
+        validation_period=validation_period,
+        device_manager=device_manager,
+        max_epochs=epochs,
+        output_folder=output_folder,
+        monitoring_interval=monitoring_interval,
+        batch_chunk_count=batch_chunk_count,
+        checkpoint=checkpoint_file,
     )
@@ -215,84 +215,50 @@ def reusable_options(f):
     return wrapper_reusable_options


-def train(
-    model,
-    output_folder,
-    epochs,
-    batch_size,
-    accumulate_grad_batches,
-    drop_incomplete_batch,
-    datamodule,
-    validation_period,
-    device,
-    cache_samples,
-    seed,
-    parallel,
-    monitoring_interval,
-    balance_classes,
-    **_,
-) -> None:  # numpydoc ignore=PR01
-    """Train an CNN to perform image classification.
-
-    Training is performed for a configurable number of epochs, and
-    generates checkpoints. Checkpoints are model files with a .ckpt
-    extension that are used in subsequent tasks or from which training
-    can be resumed.
-    """
-
-    import torch
-    from lightning.pytorch import seed_everything
-    from mednet.libs.common.engine.device import DeviceManager
-    from mednet.libs.common.engine.trainer import run
+def get_checkpoint_file(results_dir) -> pathlib.Path:
+    """Return the path of the latest checkpoint if it exists.
+
+    Parameters
+    ----------
+    results_dir
+        Directory in which results are saved.
+
+    Returns
+    -------
+        Path to the latest checkpoint
+    """
     from mednet.libs.common.utils.checkpointer import (
         get_checkpoint_to_resume_training,
     )

-    from .utils import (
-        device_properties,
-        execution_metadata,
-        model_summary,
-        save_json_with_backup,
-    )
-
     checkpoint_file = None
-    if output_folder.is_dir():
+    if results_dir.is_dir():
         try:
-            checkpoint_file = get_checkpoint_to_resume_training(output_folder)
+            checkpoint_file = get_checkpoint_to_resume_training(results_dir)
         except FileNotFoundError:
             logger.info(
-                f"Folder {output_folder} already exists, but I did not"
+                f"Folder {results_dir} already exists, but I did not"
                 f" find any usable checkpoint file to resume training"
                 f" from. Starting from scratch...",
             )

-    seed_everything(seed)
-
-    # reset datamodule with user configurable options
-    datamodule.drop_incomplete_batch = drop_incomplete_batch
-    datamodule.cache_samples = cache_samples
-    datamodule.parallel = parallel
-    datamodule.model_transforms = model.model_transforms
-    datamodule.prepare_data()
-    datamodule.setup(stage="fit")
-
-    # If asked, rebalances the loss criterion based on the relative proportion
-    # of class examples available in the training set. Also affects the
-    # validation loss if a validation set is available on the DataModule.
-    if balance_classes:
-        logger.info("Applying train/valid loss balancing...")
-        model.balance_losses(datamodule)
-    else:
-        logger.info(
-            "Skipping sample class/dataset ownership balancing on user request",
-        )
-
-    logger.info(f"Training for at most {epochs} epochs.")
-
-    arguments = {}
-    arguments["max_epoch"] = epochs
-    arguments["epoch"] = 0
+    return checkpoint_file
+
+
+def load_checkpoint(checkpoint_file, datamodule, model):
+    """Load the checkpoint.
+
+    Parameters
+    ----------
+    checkpoint_file
+        Path to the checkpoint.
+    datamodule
+        Instance of a Datamodule, used to set the model's normalizer.
+    model
+        The model corresponding to the checkpoint.
+    """
+    import torch

     if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"):
         # Sets the model normalizer with the unaugmented-train-subset if we are
@@ -316,9 +282,51 @@ def train(
             f"(checkpoint file: `{str(checkpoint_file)}`)...",
         )

-    device_manager = DeviceManager(device)
-
-    # stores all information we can think of, to reproduce this later
+
+def setup_datamodule(
+    datamodule,
+    model,
+    batch_size,
+    batch_chunk_count,
+    drop_incomplete_batch,
+    cache_samples,
+    parallel,
+) -> None:  # numpydoc ignore=PR01
+    """Configure and set up the datamodule."""
+    datamodule.set_chunk_size(batch_size, batch_chunk_count)
+    datamodule.drop_incomplete_batch = drop_incomplete_batch
+    datamodule.cache_samples = cache_samples
+    datamodule.parallel = parallel
+    datamodule.model_transforms = model.model_transforms
+    datamodule.prepare_data()
+    datamodule.setup(stage="fit")
+
+
+def save_json_data(
+    datamodule,
+    model,
+    output_folder,
+    device_manager,
+    epochs,
+    batch_size,
+    accumulate_grad_batches,
+    drop_incomplete_batch,
+    validation_period,
+    cache_samples,
+    seed,
+    parallel,
+    monitoring_interval,
+    balance_classes,
+) -> None:  # numpydoc ignore=PR01
+    """Save training hyperparameters into a .json file."""
+    from .utils import (
+        device_properties,
+        execution_metadata,
+        model_summary,
+        save_json_with_backup,
+    )
+
     json_data: dict[str, typing.Any] = execution_metadata()
     json_data.update(device_properties(device_manager.device_type))
     json_data.update(
@@ -341,15 +349,3 @@ def train(
     json_data.update(model_summary(model))
     json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
     save_json_with_backup(output_folder / "meta.json", json_data)
-
-    run(
-        model=model,
-        datamodule=datamodule,
-        validation_period=validation_period,
-        device_manager=device_manager,
-        max_epochs=epochs,
-        output_folder=output_folder,
-        monitoring_interval=monitoring_interval,
-        accumulate_grad_batches=accumulate_grad_batches,
-        checkpoint=checkpoint_file,
-    )
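
The resume behaviour of the old train() survives in the new helper pair: get_checkpoint_file() returns the newest usable checkpoint under the results directory, or None when the folder is missing or holds nothing to resume from, and load_checkpoint() then either prepares the model from that file or, per its docstring and in-code comment, sets the model's normalizer from the datamodule on a fresh start; the actual restart happens when the path is later passed to run(checkpoint=...). A small, hedged usage sketch follows; the result path is illustrative and model/datamodule are placeholders for already-instantiated mednet objects:

import pathlib

from mednet.libs.common.scripts.train import get_checkpoint_file, load_checkpoint

model = ...       # placeholder: an instantiated mednet model
datamodule = ...  # placeholder: an instantiated mednet datamodule

# returns None when "results" does not exist or holds no usable checkpoint
output_folder = pathlib.Path("results")  # illustrative location
checkpoint_file = get_checkpoint_file(output_folder)

# on a fresh run this sets the model normalizer from the datamodule; otherwise it
# prepares the model so training can resume from `checkpoint_file`
load_checkpoint(checkpoint_file, datamodule, model)
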
@@ -2,14 +2,19 @@ import click
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 from mednet.libs.common.scripts.click import ConfigCommand
-from mednet.libs.common.scripts.train import reusable_options
-from mednet.libs.common.scripts.train import train as train_script
+from mednet.libs.common.scripts.train import (
+    get_checkpoint_file,
+    load_checkpoint,
+    reusable_options,
+    save_json_data,
+    setup_datamodule,
+)

 logger = setup("mednet", format="%(levelname)s: %(message)s")


 @click.command(
-    entry_point_group="mednet.libs.segmentation.config",
+    entry_point_group="mednet.libs.classification.config",
     cls=ConfigCommand,
     epilog="""Examples:
@@ -39,27 +44,62 @@ def train(
     balance_classes,
     **_,
 ) -> None:  # numpydoc ignore=PR01
-    """Train a model to perform image segmentation.
+    """Train an CNN to perform image classification.

     Training is performed for a configurable number of epochs, and
     generates checkpoints. Checkpoints are model files with a .ckpt
     extension that are used in subsequent tasks or from which training
     can be resumed.
     """
-    train_script(
+    from lightning.pytorch import seed_everything
+
+    from mednet.libs.common.engine.device import DeviceManager
+    from mednet.libs.common.engine.trainer import run
+
+    seed_everything(seed)
+
+    device_manager = DeviceManager(device)
+
+    # reset datamodule with user configurable options
+    setup_datamodule(
+        datamodule,
+        model,
+        batch_size,
+        batch_chunk_count,
+        drop_incomplete_batch,
+        cache_samples,
+        parallel,
+    )
+
+    checkpoint_file = get_checkpoint_file(output_folder)
+    load_checkpoint(checkpoint_file, datamodule, model)
+
+    logger.info(f"Training for at most {epochs} epochs.")
+
+    # stores all information we can think of, to reproduce this later
+    save_json_data(
+        datamodule,
         model,
         output_folder,
+        device_manager,
         epochs,
         batch_size,
         batch_chunk_count,
         drop_incomplete_batch,
-        datamodule,
         validation_period,
-        device,
         cache_samples,
         seed,
         parallel,
         monitoring_interval,
         balance_classes,
-        **_,
+    )
+
+    run(
+        model=model,
+        datamodule=datamodule,
+        validation_period=validation_period,
+        device_manager=device_manager,
+        max_epochs=epochs,
+        output_folder=output_folder,
+        monitoring_interval=monitoring_interval,
+        batch_chunk_count=batch_chunk_count,
+        checkpoint=checkpoint_file,
+    )