Skip to content
Snippets Groups Projects
Commit 09e50ca4 authored by ogueler@idiap.ch's avatar ogueler@idiap.ch
Browse files

removed weight loading functionality

parent 010a0147
No related branches found
No related tags found
1 merge request: !1 "epoch checkpointing fix"
Pipeline #71287 passed
This commit is part of merge request !1. Comments created here will be created in the context of that merge request.
......@@ -288,13 +288,6 @@ def set_reproducible_cuda():
default=-1,
cls=ResourceOption,
)
@click.option(
"--weight",
"-w",
help="Path or URL to pretrained model file (.pth extension)",
required=False,
cls=ResourceOption,
)
@click.option(
"--normalization",
"-n",
......@@ -337,7 +330,6 @@ def train(
device,
seed,
parallel,
weight,
normalization,
monitoring_interval,
**_,
......@@ -363,7 +355,6 @@ def train(
from ..configs.datasets import get_positive_weights, get_samples_weights
from ..engine.trainer import run
from ..utils.checkpointer import Checkpointer
from ..utils.download import download_to_tempfile
device = setup_pytorch_device(device)
......@@ -528,29 +519,7 @@ def train(
# Initialize epoch information
arguments = {}
arguments["epoch"] = 0
# Load pretrained weights if needed
if weight is not None:
if checkpointer.has_checkpoint():
logger.warning(
"Weights are being ignored because a checkpoint already exists. "
"Weights from checkpoint will be loaded instead."
)
extra_checkpoint_data = checkpointer.load()
else:
if weight.startswith("http"):
logger.info(f"Temporarily downloading '{weight}'...")
f = download_to_tempfile(weight, progress=True)
weight_fullpath = os.path.abspath(f.name)
else:
weight_fullpath = os.path.abspath(weight)
extra_checkpoint_data = checkpointer.load(
weight_fullpath, strict=False
)
else:
extra_checkpoint_data = checkpointer.load()
# Update epoch information with checkpoint data
extra_checkpoint_data = checkpointer.load()
arguments.update(extra_checkpoint_data)
arguments["max_epoch"] = epochs
......
......@@ -226,6 +226,83 @@ def test_train_pasa_montgomery(temporary_basedir):
)
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
    """Resume training from a saved checkpoint.

    Trains the ``pasa`` model on Montgomery for a single epoch, then invokes
    training again asking for two epochs on the same output folder.  The
    second run must pick up the epoch-1 checkpoint (logging "Continuing from
    epoch 1") instead of starting over, and both runs must leave the expected
    training artifacts on disk.
    """
    from ptbench.scripts.train import train

    runner = CliRunner()
    output_folder = str(temporary_basedir / "results/pasa_checkpoint")

    # Artifacts every successful training run is expected to produce.
    expected_artifacts = (
        "model_final_epoch.pth",
        "model_lowest_valid_loss.pth",
        "last_checkpoint",
        "constants.csv",
        "trainlog.csv",
        "model_summary.txt",
    )

    # First run: train for a single epoch to create the checkpoint.
    result0 = runner.invoke(
        train,
        [
            "pasa",
            "montgomery",
            "-vv",
            "--epochs=1",
            "--batch-size=1",
            "--normalization=current",
            f"--output-folder={output_folder}",
        ],
    )
    _assert_exit_0(result0)

    for artifact in expected_artifacts:
        assert os.path.exists(os.path.join(output_folder, artifact))

    # Second run: same output folder, two epochs — should resume at epoch 1.
    with stdout_logging() as buf:
        result = runner.invoke(
            train,
            [
                "pasa",
                "montgomery",
                "-vv",
                "--epochs=2",
                "--batch-size=1",
                "--normalization=current",
                f"--output-folder={output_folder}",
            ],
        )
        _assert_exit_0(result)

        for artifact in expected_artifacts:
            assert os.path.exists(os.path.join(output_folder, artifact))

        # Each pattern must appear exactly this many times in the log; the
        # "Continuing from epoch 1" line is the proof of checkpoint resumption.
        keywords = {
            r"^Found \(dedicated\) '__train__' set for training$": 1,
            r"^Found \(dedicated\) '__valid__' set for validation$": 1,
            r"^Continuing from epoch 1$": 1,
            r"^Saving model summary at.*$": 1,
            r"^Model has.*$": 1,
            r"^Saving checkpoint": 2,
            r"^Total training time:": 1,
            r"^Z-normalization with mean": 1,
        }

        buf.seek(0)
        logging_output = buf.read()

        for pattern, expected in keywords.items():
            assert _str_counter(pattern, logging_output) == expected, (
                f"Count for string '{pattern}' appeared "
                f"({_str_counter(pattern, logging_output)}) "
                f"instead of the expected {expected}:\nOutput:\n{logging_output}"
            )
@pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
def test_predict_pasa_montgomery(temporary_basedir, datadir):
from ptbench.scripts.predict import predict
......@@ -423,9 +500,8 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir):
"signs_to_tb",
"montgomery_rs",
"-vv",
"--epochs=15",
"--epochs=1",
"--batch-size=1",
f"--weight={str(datadir / 'lfs' / 'models' / 'signstotb.pth')}",
f"--output-folder={output_folder}",
],
)
......@@ -445,7 +521,7 @@ def test_train_signstotb_montgomery_rs(temporary_basedir, datadir):
keywords = {
r"^Found \(dedicated\) '__train__' set for training$": 1,
r"^Found \(dedicated\) '__valid__' set for validation$": 1,
r"^Continuing from epoch 14$": 1,
r"^Continuing from epoch 0$": 1,
r"^Saving model summary at.*$": 1,
r"^Model has.*$": 1,
r"^Saving checkpoint": 2,
......@@ -525,9 +601,8 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir):
"logistic_regression",
"montgomery_rs",
"-vv",
"--epochs=43",
"--epochs=1",
"--batch-size=1",
f"--weight={str(datadir / 'lfs' / 'models' / 'logreg.pth')}",
f"--output-folder={output_folder}",
],
)
......@@ -547,7 +622,7 @@ def test_train_logreg_montgomery_rs(temporary_basedir, datadir):
keywords = {
r"^Found \(dedicated\) '__train__' set for training$": 1,
r"^Found \(dedicated\) '__valid__' set for validation$": 1,
r"^Continuing from epoch 42$": 1,
r"^Continuing from epoch 0$": 1,
r"^Saving model summary at.*$": 1,
r"^Model has.*$": 1,
r"^Saving checkpoint": 2,
......
Loading… (0%)
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment