Commit 9d36fa1f authored by Daniel CARRON

Moved train.py to lightning

parent 7ed6145b
1 merge request: !4 Moved code to lightning
@@ -512,15 +512,12 @@ def run(
valid_loader,
extra_valid_loaders,
optimizer,
criterion,
checkpointer,
checkpoint_period,
device,
arguments,
output_folder,
monitoring_interval,
batch_chunk_count,
criterion_valid,
):
"""Fits a CNN model using supervised learning and save it to disk.
@@ -549,12 +546,6 @@ def run(
optimizer : :py:mod:`torch.optim`
criterion : :py:class:`torch.nn.modules.loss._Loss`
loss function
checkpointer : :py:class:`ptbench.utils.checkpointer.Checkpointer`
checkpointer implementation
checkpoint_period : int
save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
not save intermediary checkpoints
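For reference, a minimal sketch (not part of this commit) of how the ``checkpoint_period`` semantics map onto Lightning's built-in ModelCheckpoint callback; the directory and period values below are hypothetical:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_period = 2  # hypothetical; 0 would mean "no intermediary checkpoints"
callbacks = []
if checkpoint_period:
    callbacks.append(
        ModelCheckpoint(
            dirpath="results/checkpoints",  # hypothetical output location
            every_n_epochs=checkpoint_period,
            save_top_k=-1,  # keep every periodic checkpoint, not only the best one
        )
    )
trainer = Trainer(max_epochs=10, callbacks=callbacks)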
@@ -578,9 +569,6 @@ def run(
mini-batch. This is particularly useful when GPU RAM is limited, but one
would still like to train with a larger effective batch size. The
trade-off is longer processing times.
criterion_valid : :py:class:`torch.nn.modules.loss._Loss`
specific loss function for the validation set
"""
max_epoch = arguments["max_epoch"]
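As a reference point for ``batch_chunk_count``, a minimal sketch (not part of this commit) of Lightning's built-in gradient accumulation, which obtains a similar memory/throughput trade-off by accumulating gradients over several smaller batches before stepping the optimizer:

from pytorch_lightning import Trainer

batch_chunk_count = 4  # hypothetical value; the loader batch size would shrink accordingly
trainer = Trainer(max_epochs=10, accumulate_grad_batches=batch_chunk_count)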
@@ -621,7 +609,7 @@ def run(
],
)
_ = trainer.fit(model, data_loader)
_ = trainer.fit(model, data_loader, valid_loader)
"""# write static information to a CSV file
static_logfile_name = os.path.join(output_folder, "constants.csv")
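A minimal sketch of how the ``extra_valid_loaders`` kept in the signature of ``run()`` could also be forwarded, assuming a recent pytorch_lightning version where ``val_dataloaders`` accepts a list of loaders; the commit itself only forwards ``valid_loader``:

from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=max_epoch)  # max_epoch as read from ``arguments`` above
trainer.fit(model, data_loader, [valid_loader, *extra_valid_loaders])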
@@ -8,6 +8,7 @@ import click
from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup
from pytorch_lightning import seed_everything
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@@ -53,42 +54,6 @@ def setup_pytorch_device(name):
return torch.device(name)
def set_seeds(value, all_gpus):
"""Sets up all relevant random seeds (numpy, python, cuda)
If running with multiple GPUs **at the same time**, set ``all_gpus`` to
``True`` to force all GPU seeds to be initialized.
Reference: `PyTorch page for reproducibility
<https://pytorch.org/docs/stable/notes/randomness.html>`_.
Parameters
----------
value : int
The random seed value to use
all_gpus : :py:class:`bool`, Optional
If set, then reset the seed on all GPUs available at once. This is
normally **not** what you want if running on a single GPU
"""
import random
import numpy.random
import torch
import torch.cuda
random.seed(value)
numpy.random.seed(value)
torch.manual_seed(value)
torch.cuda.manual_seed(value) # noop if cuda not available
# set seeds for all gpus
if all_gpus:
torch.cuda.manual_seed_all(value) # noop if cuda not available
def set_reproducible_cuda():
"""Turns-off all CUDA optimizations that would affect reproducibility.
@@ -252,13 +217,14 @@ def set_reproducible_cuda():
"last saved checkpoint if training is restarted with the same "
"configuration.",
show_default=True,
required=True,
default=0,
required=False,
default=None,
type=click.IntRange(min=0),
cls=ResourceOption,
)
@click.option(
"--device",
"-d",
help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
show_default=True,
required=True,
@@ -288,6 +254,13 @@ def set_reproducible_cuda():
default=-1,
cls=ResourceOption,
)
@click.option(
"--weight",
"-w",
help="Path or URL to pretrained model file (.pth extension)",
required=False,
cls=ResourceOption,
)
@click.option(
"--normalization",
"-n",
@@ -330,6 +303,7 @@ def train(
device,
seed,
parallel,
weight,
normalization,
monitoring_interval,
**_,
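The hunks shown here do not include the code that consumes ``--weight``; purely as an illustration of how such a flag is commonly handled, assuming the file stores a plain ``state_dict`` (a URL would typically go through ``torch.hub.load_state_dict_from_url`` instead):

import torch

if weight is not None:
    state = torch.load(weight, map_location="cpu")   # assumes a local .pth state_dict
    model.load_state_dict(state, strict=False)        # strict=False is an assumption, not from this repo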
@@ -354,11 +328,10 @@ def train(
from ..configs.datasets import get_positive_weights, get_samples_weights
from ..engine.trainer import run
from ..utils.checkpointer import Checkpointer
device = setup_pytorch_device(device)
set_seeds(seed, all_gpus=False)
seed_everything(seed)
use_dataset = dataset
validation_dataset = None
@@ -418,9 +391,6 @@ def train(
# Create weighted random sampler
train_samples_weights = get_samples_weights(use_dataset)
train_samples_weights = train_samples_weights.to(
device=device, non_blocking=torch.cuda.is_available()
)
train_sampler = WeightedRandomSampler(
train_samples_weights, len(train_samples_weights), replacement=True
)
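A small, self-contained sketch of the sampler configured above; the weights are hypothetical and, after this change, stay on the CPU since the ``.to(device)`` call was dropped:

import torch
from torch.utils.data import WeightedRandomSampler

train_samples_weights = torch.tensor([0.2, 0.2, 0.8, 0.8])  # hypothetical per-sample weights
train_sampler = WeightedRandomSampler(
    train_samples_weights, len(train_samples_weights), replacement=True
)
# the sampler is then passed to the training DataLoader via ``sampler=train_sampler``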
@@ -428,10 +398,7 @@ def train(
# Redefine a weighted criterion if possible
if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
positive_weights = get_positive_weights(use_dataset)
positive_weights = positive_weights.to(
device=device, non_blocking=torch.cuda.is_available()
)
criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
model.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
else:
logger.warning("Weighted criterion not supported")
@@ -454,10 +421,9 @@ def train(
or criterion_valid is None
):
positive_weights = get_positive_weights(validation_dataset)
positive_weights = positive_weights.to(
device=device, non_blocking=torch.cuda.is_available()
model.criterion_valid = BCEWithLogitsLoss(
pos_weight=positive_weights
)
criterion_valid = BCEWithLogitsLoss(pos_weight=positive_weights)
else:
logger.warning("Weighted valid criterion not supported")
@@ -513,14 +479,8 @@ def train(
)
logger.info(f"Z-normalization with mean {mean} and std {std}")
# Checkpointer
checkpointer = Checkpointer(model, optimizer, path=output_folder)
# Initialize epoch information
arguments = {}
arguments["epoch"] = 0
extra_checkpoint_data = checkpointer.load()
arguments.update(extra_checkpoint_data)
arguments["max_epoch"] = epochs
logger.info("Training for {} epochs".format(arguments["max_epoch"]))
@@ -532,13 +492,10 @@ def train(
valid_loader=valid_loader,
extra_valid_loaders=extra_valid_loaders,
optimizer=optimizer,
criterion=criterion,
checkpointer=checkpointer,
checkpoint_period=checkpoint_period,
device=device,
arguments=arguments,
output_folder=output_folder,
monitoring_interval=monitoring_interval,
batch_chunk_count=batch_chunk_count,
criterion_valid=criterion_valid,
)