Skip to content
Snippets Groups Projects
Commit 0209ebe1 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

Merge branch 'checkpointing_fix' into 'main'

epoch checkpointing fix

See merge request biosignal/software/ptbench!1
parents b5c5220e 09e50ca4
No related branches found
No related tags found
1 merge request!1epoch checkpointing fix
Pipeline #71311 passed
...@@ -63,7 +63,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") ...@@ -63,7 +63,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
) )
@click.option( @click.option(
"--device", "--device",
"-d",
help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
show_default=True, show_default=True,
required=True, required=True,
......
...@@ -259,7 +259,6 @@ def set_reproducible_cuda(): ...@@ -259,7 +259,6 @@ def set_reproducible_cuda():
) )
@click.option( @click.option(
"--device", "--device",
"-d",
help='A string indicating the device to use (e.g. "cpu" or "cuda:0")', help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
show_default=True, show_default=True,
required=True, required=True,
...@@ -289,13 +288,6 @@ def set_reproducible_cuda(): ...@@ -289,13 +288,6 @@ def set_reproducible_cuda():
default=-1, default=-1,
cls=ResourceOption, cls=ResourceOption,
) )
@click.option(
"--weight",
"-w",
help="Path or URL to pretrained model file (.pth extension)",
required=False,
cls=ResourceOption,
)
@click.option( @click.option(
"--normalization", "--normalization",
"-n", "-n",
...@@ -338,7 +330,6 @@ def train( ...@@ -338,7 +330,6 @@ def train(
device, device,
seed, seed,
parallel, parallel,
weight,
normalization, normalization,
monitoring_interval, monitoring_interval,
**_, **_,
...@@ -364,7 +355,6 @@ def train( ...@@ -364,7 +355,6 @@ def train(
from ..configs.datasets import get_positive_weights, get_samples_weights from ..configs.datasets import get_positive_weights, get_samples_weights
from ..engine.trainer import run from ..engine.trainer import run
from ..utils.checkpointer import Checkpointer from ..utils.checkpointer import Checkpointer
from ..utils.download import download_to_tempfile
device = setup_pytorch_device(device) device = setup_pytorch_device(device)
...@@ -526,18 +516,11 @@ def train( ...@@ -526,18 +516,11 @@ def train(
# Checkpointer # Checkpointer
checkpointer = Checkpointer(model, optimizer, path=output_folder) checkpointer = Checkpointer(model, optimizer, path=output_folder)
# Load pretrained weights if needed # Initialize epoch information
if weight is not None:
if weight.startswith("http"):
logger.info(f"Temporarily downloading '{weight}'...")
f = download_to_tempfile(weight, progress=True)
weight_fullpath = os.path.abspath(f.name)
else:
weight_fullpath = os.path.abspath(weight)
checkpointer.load(weight_fullpath, strict=False)
arguments = {} arguments = {}
arguments["epoch"] = 0 arguments["epoch"] = 0
extra_checkpoint_data = checkpointer.load()
arguments.update(extra_checkpoint_data)
arguments["max_epoch"] = epochs arguments["max_epoch"] = epochs
logger.info("Training for {} epochs".format(arguments["max_epoch"])) logger.info("Training for {} epochs".format(arguments["max_epoch"]))
......
...@@ -226,6 +226,83 @@ def test_train_pasa_montgomery(temporary_basedir): ...@@ -226,6 +226,83 @@ def test_train_pasa_montgomery(temporary_basedir):
) )
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
    """Train pasa on montgomery for one epoch, then invoke training again
    with ``--epochs=2`` and verify the second run resumes from the stored
    checkpoint (log shows "Continuing from epoch 1") instead of restarting."""

    from ptbench.scripts.train import train

    runner = CliRunner()
    output_folder = str(temporary_basedir / "results/pasa_checkpoint")

    # First run: a single epoch, which leaves a checkpoint behind.
    result0 = runner.invoke(
        train,
        [
            "pasa",
            "montgomery",
            "-vv",
            "--epochs=1",
            "--batch-size=1",
            "--normalization=current",
            f"--output-folder={output_folder}",
        ],
    )
    _assert_exit_0(result0)

    # Artifacts expected after the first (1-epoch) run.
    expected_files = (
        "model_final_epoch.pth",
        "model_lowest_valid_loss.pth",
        "last_checkpoint",
        "constants.csv",
        "trainlog.csv",
        "model_summary.txt",
    )
    for basename in expected_files:
        assert os.path.exists(os.path.join(output_folder, basename))

    with stdout_logging() as buf:
        # Second run: same output folder, higher epoch target, so it must
        # pick up where the checkpoint left off.
        result = runner.invoke(
            train,
            [
                "pasa",
                "montgomery",
                "-vv",
                "--epochs=2",
                "--batch-size=1",
                "--normalization=current",
                f"--output-folder={output_folder}",
            ],
        )
        _assert_exit_0(result)

        # The same artifacts must still be present after resuming.
        for basename in expected_files:
            assert os.path.exists(os.path.join(output_folder, basename))

        # Expected log lines (regex -> occurrence count); the key check is
        # "Continuing from epoch 1", which proves checkpoint resumption.
        keywords = {
            r"^Found \(dedicated\) '__train__' set for training$": 1,
            r"^Found \(dedicated\) '__valid__' set for validation$": 1,
            r"^Continuing from epoch 1$": 1,
            r"^Saving model summary at.*$": 1,
            r"^Model has.*$": 1,
            r"^Saving checkpoint": 2,
            r"^Total training time:": 1,
            r"^Z-normalization with mean": 1,
        }
        buf.seek(0)
        logging_output = buf.read()

        for k, v in keywords.items():
            assert _str_counter(k, logging_output) == v, (
                f"Count for string '{k}' appeared "
                f"({_str_counter(k, logging_output)}) "
                f"instead of the expected {v}:\nOutput:\n{logging_output}"
            )
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery") @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_predict_pasa_montgomery(temporary_basedir, datadir): def test_predict_pasa_montgomery(temporary_basedir, datadir):
from ptbench.scripts.predict import predict from ptbench.scripts.predict import predict
...@@ -416,7 +493,7 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir): ...@@ -416,7 +493,7 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir):
runner = CliRunner() runner = CliRunner()
with stdout_logging() as buf: with stdout_logging() as buf:
output_folder = str(temporary_basedir / "results") output_folder = str(temporary_basedir / "results/signstotb")
result = runner.invoke( result = runner.invoke(
train, train,
[ [
...@@ -425,7 +502,6 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir): ...@@ -425,7 +502,6 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir):
"-vv", "-vv",
"--epochs=1", "--epochs=1",
"--batch-size=1", "--batch-size=1",
f"--weight={str(datadir / 'lfs' / 'models' / 'signstotb.pth')}",
f"--output-folder={output_folder}", f"--output-folder={output_folder}",
], ],
) )
...@@ -518,7 +594,7 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir): ...@@ -518,7 +594,7 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir):
runner = CliRunner() runner = CliRunner()
with stdout_logging() as buf: with stdout_logging() as buf:
output_folder = str(temporary_basedir / "results") output_folder = str(temporary_basedir / "results/logreg")
result = runner.invoke( result = runner.invoke(
train, train,
[ [
...@@ -527,7 +603,6 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir): ...@@ -527,7 +603,6 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir):
"-vv", "-vv",
"--epochs=1", "--epochs=1",
"--batch-size=1", "--batch-size=1",
f"--weight={str(datadir / 'lfs' / 'models' / 'logreg.pth')}",
f"--output-folder={output_folder}", f"--output-folder={output_folder}",
], ],
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment