diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py
index 94fbcb00bacbb0cadce116e5a944c4abf57cdbb3..dbc14ebef400b23dbacd90c458d6108b4f76c69e 100644
--- a/src/mednet/libs/classification/scripts/train.py
+++ b/src/mednet/libs/classification/scripts/train.py
@@ -2,8 +2,13 @@ import click
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 from mednet.libs.common.scripts.click import ConfigCommand
-from mednet.libs.common.scripts.train import reusable_options
-from mednet.libs.common.scripts.train import train as train_script
+from mednet.libs.common.scripts.train import (
+    get_checkpoint_file,
+    load_checkpoint,
+    reusable_options,
+    save_json_data,
+    setup_datamodule,
+)
 
 logger = setup("mednet", format="%(levelname)s: %(message)s")
 
@@ -48,20 +53,72 @@ def train(
     extension that are used in subsequent tasks or from which training
     can be resumed.
     """
-    train_script(
+    from lightning.pytorch import seed_everything
+    from mednet.libs.common.engine.device import DeviceManager
+    from mednet.libs.common.engine.trainer import run
+
+    seed_everything(seed)
+    device_manager = DeviceManager(device)
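+    # the device manager encapsulates the compute-device selection made by
+    # the user; it is handed to the trainer via run() below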
+
+    # reset datamodule with user configurable options
+    setup_datamodule(
+        datamodule,
+        model,
+        batch_size,
+        accumulate_grad_batches,
+        drop_incomplete_batch,
+        cache_samples,
+        parallel,
+    )
+
+    # If asked, rebalances the train sampler based on the relative proportion
+    # of class examples available in the training set (loss-based balancing
+    # is kept below, commented out, as an alternative).
+    if balance_classes:
+        logger.info("Applying DataModule train sampler balancing...")
+        datamodule.balance_sampler_by_class = True
+        # logger.info("Applying train/valid loss balancing...")
+        # model.balance_losses_by_class(datamodule)
+    else:
+        logger.info(
+            "Skipping sample class/dataset ownership balancing on user request",
+        )
+
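+    # resume from the latest usable checkpoint, if one is present in the
+    # output folder; the same file is also passed to the trainer below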
+    checkpoint_file = get_checkpoint_file(output_folder)
+    load_checkpoint(checkpoint_file, datamodule, model)
+
+    logger.info(f"Training for at most {epochs} epochs.")
+
+    # stores all information we can think of, to reproduce this later
+    save_json_data(
+        datamodule,
         model,
         output_folder,
+        device_manager,
         epochs,
         batch_size,
         accumulate_grad_batches,
         drop_incomplete_batch,
-        datamodule,
         validation_period,
-        device,
         cache_samples,
         seed,
         parallel,
         monitoring_interval,
         balance_classes,
-        **_,
+    )
+
+    run(
+        model=model,
+        datamodule=datamodule,
+        validation_period=validation_period,
+        device_manager=device_manager,
+        max_epochs=epochs,
+        output_folder=output_folder,
+        monitoring_interval=monitoring_interval,
+        accumulate_grad_batches=accumulate_grad_batches,
+        checkpoint=checkpoint_file,
     )
diff --git a/src/mednet/libs/common/scripts/train.py b/src/mednet/libs/common/scripts/train.py
index 2228725b673068f8417fcb1a8b6fc829357dab39..32d7bd9481b0d973d0f0eb6279836d584578d84b 100644
--- a/src/mednet/libs/common/scripts/train.py
+++ b/src/mednet/libs/common/scripts/train.py
@@ -215,84 +215,50 @@ def reusable_options(f):
     return wrapper_reusable_options
 
 
-def train(
-    model,
-    output_folder,
-    epochs,
-    batch_size,
-    accumulate_grad_batches,
-    drop_incomplete_batch,
-    datamodule,
-    validation_period,
-    device,
-    cache_samples,
-    seed,
-    parallel,
-    monitoring_interval,
-    balance_classes,
-    **_,
-) -> None:  # numpydoc ignore=PR01
-    """Train an CNN to perform image classification.
+def get_checkpoint_file(results_dir) -> pathlib.Path | None:
+    """Return the path of the latest checkpoint if it exists.
 
-    Training is performed for a configurable number of epochs, and
-    generates checkpoints.  Checkpoints are model files with a .ckpt
-    extension that are used in subsequent tasks or from which training
-    can be resumed.
-    """
+    Parameters
+    ----------
+    results_dir
+        Directory in which results are saved.
 
-    import torch
-    from lightning.pytorch import seed_everything
-    from mednet.libs.common.engine.device import DeviceManager
-    from mednet.libs.common.engine.trainer import run
+    Returns
+    -------
+        Path to the latest checkpoint, or ``None`` if no checkpoint was found.
+    """
     from mednet.libs.common.utils.checkpointer import (
         get_checkpoint_to_resume_training,
     )
 
-    from .utils import (
-        device_properties,
-        execution_metadata,
-        model_summary,
-        save_json_with_backup,
-    )
-
     checkpoint_file = None
-    if output_folder.is_dir():
+    if results_dir.is_dir():
         try:
-            checkpoint_file = get_checkpoint_to_resume_training(output_folder)
+            checkpoint_file = get_checkpoint_to_resume_training(results_dir)
         except FileNotFoundError:
             logger.info(
-                f"Folder {output_folder} already exists, but I did not"
+                f"Folder {results_dir} already exists, but I did not"
                 f" find any usable checkpoint file to resume training"
                 f" from. Starting from scratch...",
             )
 
-    seed_everything(seed)
-
-    # reset datamodule with user configurable options
-    datamodule.drop_incomplete_batch = drop_incomplete_batch
-    datamodule.cache_samples = cache_samples
-    datamodule.parallel = parallel
-    datamodule.model_transforms = model.model_transforms
+    return checkpoint_file
 
-    datamodule.prepare_data()
-    datamodule.setup(stage="fit")
 
-    # If asked, rebalances the loss criterion based on the relative proportion
-    # of class examples available in the training set.  Also affects the
-    # validation loss if a validation set is available on the DataModule.
-    if balance_classes:
-        logger.info("Applying train/valid loss balancing...")
-        model.balance_losses(datamodule)
-    else:
-        logger.info(
-            "Skipping sample class/dataset ownership balancing on user request",
-        )
+def load_checkpoint(checkpoint_file, datamodule, model):
+    """Load the checkpoint.
 
-    logger.info(f"Training for at most {epochs} epochs.")
+    Parameters
+    ----------
+    checkpoint_file
+        Path to the checkpoint.
+    datamodule
+        Instance of a Datamodule, used to set the model's normalizer.
+    model
+        The model corresponding to the checkpoint.
+    """
 
-    arguments = {}
-    arguments["max_epoch"] = epochs
-    arguments["epoch"] = 0
+    import torch
 
     if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"):
         # Sets the model normalizer with the unaugmented-train-subset if we are
@@ -316,9 +282,53 @@
             f"(checkpoint file: `{str(checkpoint_file)}`)...",
         )
 
-    device_manager = DeviceManager(device)
 
-    # stores all information we can think of, to reproduce this later
+def setup_datamodule(
+    datamodule,
+    model,
+    batch_size,
+    batch_chunk_count,
+    drop_incomplete_batch,
+    cache_samples,
+    parallel,
+) -> None:  # numpydoc ignore=PR01
+    """Configure and set up the datamodule."""
+    datamodule.set_chunk_size(batch_size, batch_chunk_count)
+    datamodule.drop_incomplete_batch = drop_incomplete_batch
+    datamodule.cache_samples = cache_samples
+    datamodule.parallel = parallel
+    datamodule.model_transforms = model.model_transforms
+
+    datamodule.prepare_data()
+    datamodule.setup(stage="fit")
+
+
+def save_json_data(
+    datamodule,
+    model,
+    output_folder,
+    device_manager,
+    epochs,
+    batch_size,
+    accumulate_grad_batches,
+    drop_incomplete_batch,
+    validation_period,
+    cache_samples,
+    seed,
+    parallel,
+    monitoring_interval,
+    balance_classes,
+) -> None:  # numpydoc ignore=PR01
+    """Save training hyperparameters into a .json file."""
+    from .utils import (
+        device_properties,
+        execution_metadata,
+        model_summary,
+        save_json_with_backup,
+    )
+
     json_data: dict[str, typing.Any] = execution_metadata()
     json_data.update(device_properties(device_manager.device_type))
     json_data.update(
@@ -341,15 +349,3 @@ def train(
     json_data.update(model_summary(model))
     json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
     save_json_with_backup(output_folder / "meta.json", json_data)
-
-    run(
-        model=model,
-        datamodule=datamodule,
-        validation_period=validation_period,
-        device_manager=device_manager,
-        max_epochs=epochs,
-        output_folder=output_folder,
-        monitoring_interval=monitoring_interval,
-        accumulate_grad_batches=accumulate_grad_batches,
-        checkpoint=checkpoint_file,
-    )
diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py
index 933231c6d6aa3a6a96088d61f42407b041ac1da8..2632369c874fda1aad7b792c6f6a0a655b0cfe3b 100644
--- a/src/mednet/libs/segmentation/scripts/train.py
+++ b/src/mednet/libs/segmentation/scripts/train.py
@@ -2,14 +2,19 @@ import click
 from clapper.click import ResourceOption, verbosity_option
 from clapper.logging import setup
 from mednet.libs.common.scripts.click import ConfigCommand
-from mednet.libs.common.scripts.train import reusable_options
-from mednet.libs.common.scripts.train import train as train_script
+from mednet.libs.common.scripts.train import (
+    get_checkpoint_file,
+    load_checkpoint,
+    reusable_options,
+    save_json_data,
+    setup_datamodule,
+)
 
 logger = setup("mednet", format="%(levelname)s: %(message)s")
 
 
 @click.command(
     entry_point_group="mednet.libs.segmentation.config",
     cls=ConfigCommand,
     epilog="""Examples:
 
@@ -39,27 +44,64 @@ def train(
     balance_classes,
     **_,
 ) -> None:  # numpydoc ignore=PR01
-    """Train a model to perform image segmentation.
+    """Train an CNN to perform image classification.
 
     Training is performed for a configurable number of epochs, and
     generates checkpoints.  Checkpoints are model files with a .ckpt
     extension that are used in subsequent tasks or from which training
     can be resumed.
     """
-    train_script(
+    from lightning.pytorch import seed_everything
+    from mednet.libs.common.engine.device import DeviceManager
+    from mednet.libs.common.engine.trainer import run
+
+    seed_everything(seed)
+    device_manager = DeviceManager(device)
+
+    # reset datamodule with user configurable options
+    setup_datamodule(
+        datamodule,
+        model,
+        batch_size,
+        batch_chunk_count,
+        drop_incomplete_batch,
+        cache_samples,
+        parallel,
+    )
+
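+    # resume from the latest usable checkpoint, if one is present in the
+    # output folder; the same file is also passed to the trainer below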
+    checkpoint_file = get_checkpoint_file(output_folder)
+    load_checkpoint(checkpoint_file, datamodule, model)
+
+    logger.info(f"Training for at most {epochs} epochs.")
+
+    # stores all information we can think of, to reproduce this later
+    save_json_data(
+        datamodule,
         model,
         output_folder,
+        device_manager,
         epochs,
         batch_size,
         batch_chunk_count,
         drop_incomplete_batch,
-        datamodule,
         validation_period,
-        device,
         cache_samples,
         seed,
         parallel,
         monitoring_interval,
         balance_classes,
-        **_,
+    )
+
+    run(
+        model=model,
+        datamodule=datamodule,
+        validation_period=validation_period,
+        device_manager=device_manager,
+        max_epochs=epochs,
+        output_folder=output_folder,
+        monitoring_interval=monitoring_interval,
+        accumulate_grad_batches=batch_chunk_count,
+        checkpoint=checkpoint_file,
     )