Commit 0152c9c3 authored by André Anjos, committed by Daniel CARRON

[utils.checkpointer] Refactor checkpoint saving and loading

parent 6c09539b
1 merge request: !12 Adds grad-cam support on classifiers
@@ -580,8 +580,10 @@ class ConcatDataModule(lightning.LightningDataModule):
         if value < 0:
             num_workers = 0
         else:
             num_workers = value or multiprocessing.cpu_count()
         self._dataloader_multiproc["num_workers"] = num_workers

         if num_workers > 0 and sys.platform == "darwin":
@@ -589,6 +591,9 @@ class ConcatDataModule(lightning.LightningDataModule):
                 "multiprocessing_context"
             ] = multiprocessing.get_context("spawn")

+        # keep workers hanging around if we have multiple
+        self._dataloader_multiproc["persistent_workers"] = True
+
     @property
     def model_transforms(self) -> list[Transform] | None:
         """Transforms required to fit data into the model.

...
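For context, the dataloader keyword arguments assembled above translate into a plain torch.utils.data.DataLoader call as sketched below. This is a standalone illustration with a synthetic dataset, not code from this repository:

    import multiprocessing
    import sys

    import torch
    from torch.utils.data import DataLoader, TensorDataset


    def main():
        # synthetic stand-in for a real dataset
        dataset = TensorDataset(torch.randn(64, 3), torch.randint(0, 2, (64,)))

        num_workers = multiprocessing.cpu_count()
        kwargs = {"num_workers": num_workers}

        if num_workers > 0 and sys.platform == "darwin":
            # explicitly request the "spawn" start method, as the datamodule
            # above does on macOS
            kwargs["multiprocessing_context"] = multiprocessing.get_context("spawn")

        if num_workers > 0:
            # keep worker processes alive between epochs instead of
            # re-creating them, avoiding the startup cost at every epoch
            kwargs["persistent_workers"] = True

        loader = DataLoader(dataset, batch_size=8, **kwargs)
        for _batch in loader:  # workers are only spawned once iteration starts
            pass


    if __name__ == "__main__":  # required with the "spawn" start method
        main()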
@@ -13,6 +13,7 @@ import lightning.pytorch.callbacks
 import lightning.pytorch.loggers
 import torch.nn

+from ..utils.checkpointer import CHECKPOINT_ALIASES
 from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
 from .callbacks import LoggingCallback
 from .device import DeviceManager
@@ -47,13 +48,13 @@ def save_model_summary(
     summary_path = output_folder / "model-summary.txt"
     logger.info(f"Saving model summary at {summary_path}...")
     with summary_path.open("w") as f:
-        summary = lightning.pytorch.utilities.model_summary.ModelSummary(
+        summary = lightning.pytorch.utilities.model_summary.ModelSummary(  # type: ignore
             model, max_depth=-1
         )
         f.write(str(summary))
     return (
         summary,
-        lightning.pytorch.utilities.model_summary.ModelSummary(
+        lightning.pytorch.utilities.model_summary.ModelSummary(  # type: ignore
             model
         ).total_parameters,
     )
@@ -99,13 +100,13 @@ def static_information_to_csv(
 def run(
     model: lightning.pytorch.LightningModule,
     datamodule: lightning.pytorch.LightningDataModule,
-    checkpoint_period: int,
+    validation_period: int,
     device_manager: DeviceManager,
     max_epochs: int,
     output_folder: pathlib.Path,
     monitoring_interval: int | float,
     batch_chunk_count: int,
-    checkpoint: str | None,
+    checkpoint: pathlib.Path | None,
 ):
     """Fits a CNN model using supervised learning and save it to disk.
@@ -122,9 +123,15 @@ def run(
     datamodule
         The lightning datamodule to use for training **and** validation
-    checkpoint_period
-        Save a checkpoint every ``n`` epochs. If set to ``0`` (zero), then do
-        not save intermediary checkpoints.
+    validation_period
+        Number of epochs after which validation happens. By default, we run
+        validation after every training epoch (period=1). You can change this
+        to make validation more sparse, by increasing the validation period.
+        Notice that this affects checkpoint saving: while checkpoints are
+        created at the end of every training epoch (always overwriting the
+        latest checkpoint), independently of validation runs, the 'best' model
+        (lowest validation loss) can only be updated on epochs where
+        validation actually runs, and is therefore influenced by this setting.
     device_manager
         An internal device representation, to be used for training and
@@ -177,17 +184,22 @@ def run(
         logging_level=logging.ERROR,
     )

-    checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint(
-        output_folder,
-        "model_lowest_valid_loss",
-        save_last=True,
+    # This checkpointer operates at the end of every validation epoch (which
+    # happens at each validation period); it saves the model with the lowest
+    # validation loss observed so far. It also saves the last trained model.
+    checkpoint_minvalloss_callback = lightning.pytorch.callbacks.ModelCheckpoint(
+        dirpath=output_folder,
+        filename=CHECKPOINT_ALIASES["best"],
+        save_last=True,  # will (re)create the last trained model, at every iteration
         monitor="loss/validation",
         mode="min",
-        save_on_train_epoch_end=True,
-        every_n_epochs=checkpoint_period,
+        save_on_train_epoch_end=True,  # run checks at the end of validation
+        every_n_epochs=validation_period,  # frequency at which it checks the "monitor"
+        enable_version_counter=False,  # no versioning of aliased checkpoints
     )
-
-    checkpoint_callback.CHECKPOINT_NAME_LAST = "model_final_epoch"
+    checkpoint_minvalloss_callback.CHECKPOINT_NAME_LAST = CHECKPOINT_ALIASES[  # type: ignore
+        "periodic"
+    ]

     # write static information to a CSV file
     static_information_to_csv(
@@ -204,9 +216,13 @@ def run(
         max_epochs=max_epochs,
         accumulate_grad_batches=batch_chunk_count,
         logger=tensorboard_logger,
-        check_val_every_n_epoch=1,
+        check_val_every_n_epoch=validation_period,
         log_every_n_steps=len(datamodule.train_dataloader()),
-        callbacks=[LoggingCallback(resource_monitor), checkpoint_callback],
+        callbacks=[
+            LoggingCallback(resource_monitor),
+            checkpoint_minvalloss_callback,
+        ],
     )

-    _ = trainer.fit(model, datamodule, ckpt_path=checkpoint)
+    checkpoint_str = checkpoint if checkpoint is None else str(checkpoint)
+    _ = trainer.fit(model, datamodule, ckpt_path=checkpoint_str)
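The interaction between the validation period and the two kinds of checkpoints can be reproduced on a bare Lightning setup. The snippet below is only a sketch: no model or datamodule from this repository is referenced, and any LightningModule that logs a "loss/validation" metric would do:

    import lightning.pytorch as pl
    from lightning.pytorch.callbacks import ModelCheckpoint

    validation_period = 2  # validate (and re-assess the "best" model) every 2 epochs

    checkpoint_callback = ModelCheckpoint(
        dirpath="results",
        filename="model-at-lowest-validation-loss-{epoch}",  # the "best" alias template
        save_last=True,             # also keep a rolling "last" checkpoint
        monitor="loss/validation",  # only refreshed on epochs where validation runs
        mode="min",
        every_n_epochs=validation_period,
        save_on_train_epoch_end=True,
    )

    trainer = pl.Trainer(
        max_epochs=10,
        check_val_every_n_epoch=validation_period,  # keep in sync with the callback
        callbacks=[checkpoint_callback],
    )

    # trainer.fit(model, datamodule)  # any LightningModule logging "loss/validation"

The "last" checkpoint is refreshed at the end of every training epoch, while the "best" one can only change on epochs where validation actually runs, which is what the new docstring above describes.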
@@ -42,7 +42,7 @@ def experiment(
     batch_chunk_count,
     drop_incomplete_batch,
     datamodule,
-    checkpoint_period,
+    validation_period,
     device,
     cache_samples,
     seed,
@@ -84,7 +84,7 @@ def experiment(
         batch_chunk_count=batch_chunk_count,
         drop_incomplete_batch=drop_incomplete_batch,
         datamodule=datamodule,
-        checkpoint_period=checkpoint_period,
+        validation_period=validation_period,
         device=device,
         cache_samples=cache_samples,
         seed=seed,
@@ -111,13 +111,12 @@ def experiment(
     logger.info("Started predicting")

-    from .predict import predict
-
-    # preferably, we use the best model on the validation set
-    # otherwise, we get the last saved model
-    model_file = train_output_folder / "model_lowest_valid_loss.ckpt"
-    if not model_file.exists():
-        model_file = train_output_folder / "model_final_epoch.ckpt"
+    from ..utils.checkpointer import get_checkpoint_to_run_inference
+
+    model_file = get_checkpoint_to_run_inference(train_output_folder)
+    logger.info(f"Found `{str(model_file)}`. Continuing...")
+
+    from .predict import predict

     predictions_output = output_folder / "predictions.json"

...
@@ -125,16 +125,21 @@ def reusable_options(f):
         cls=ResourceOption,
     )
     @click.option(
-        "--checkpoint-period",
+        "--validation-period",
         "-p",
-        help="""Number of epochs after which a checkpoint is saved. A value of
-        zero will disable check-pointing. If checkpointing is enabled and
-        training stops, it is automatically resumed from the last saved
-        checkpoint if training is restarted with the same configuration.""",
+        help="""Number of epochs after which validation happens. By default,
+        we run validation after every training epoch (period=1). You can
+        change this to make validation more sparse, by increasing the
+        validation period. Notice that this affects checkpoint saving: while
+        checkpoints are created at the end of every training epoch (always
+        overwriting the latest checkpoint), independently of validation runs,
+        the 'best' model (lowest validation loss) can only be updated on
+        epochs where validation actually runs, and is therefore influenced by
+        this setting.""",
         show_default=True,
-        required=False,
-        default=None,
-        type=click.IntRange(min=0),
+        required=True,
+        default=1,
+        type=click.IntRange(min=1),
         cls=ResourceOption,
     )
     @click.option(
@@ -183,27 +188,19 @@ def reusable_options(f):
         "--monitoring-interval",
         "-I",
         help="""Time between checks for the use of resources during each training
-        epoch. An interval of 5 seconds, for example, will lead to CPU and GPU
-        resources being probed every 5 seconds during each training epoch.
-        Values registered in the training logs correspond to averages (or maxima)
-        observed through possibly many probes in each epoch. Notice that setting a
-        very small value may cause the probing process to become extremely busy,
-        potentially biasing the overall perception of resource usage.""",
+        epoch, in seconds. An interval of 5 seconds, for example, will lead to
+        CPU and GPU resources being probed every 5 seconds during each training
+        epoch. Values registered in the training logs correspond to averages
+        (or maxima) observed through possibly many probes in each epoch.
+        Notice that setting a very small value may cause the probing process to
+        become extremely busy, potentially biasing the overall perception of
+        resource usage.""",
         type=click.FloatRange(min=0.1),
         show_default=True,
         required=True,
         default=5.0,
         cls=ResourceOption,
     )
-    @click.option(
-        "--resume-from",
-        help="""Which checkpoint to resume training from. If set, can be one of
-        `best`, `last`, or a path to a model checkpoint.""",
-        type=click.STRING,
-        required=False,
-        default=None,
-        cls=ResourceOption,
-    )
     @click.option(
         "--balance-classes/--no-balance-classes",
         "-B/-N",
@@ -244,13 +241,12 @@ def train(
     batch_chunk_count,
     drop_incomplete_batch,
     datamodule,
-    checkpoint_period,
+    validation_period,
     device,
     cache_samples,
     seed,
     parallel,
     monitoring_interval,
-    resume_from,
     balance_classes,
     **_,
 ) -> None:
@@ -263,20 +259,31 @@ def train(
     resume the procedure in case it stops abruptly.
     """

+    import os
+
     import torch

     from lightning.pytorch import seed_everything

     from ..engine.device import DeviceManager
     from ..engine.trainer import run
-    from ..utils.checkpointer import get_checkpoint
+    from ..utils.checkpointer import get_checkpoint_to_resume_training
     from .utils import save_sh_command

+    checkpoint_file = None
+    if os.path.isdir(output_folder):
+        try:
+            checkpoint_file = get_checkpoint_to_resume_training(output_folder)
+        except FileNotFoundError:
+            logger.info(
+                f"Folder {output_folder} already exists, but I did not"
+                f" find any usable checkpoint file to resume training"
+                f" from. Starting from scratch..."
+            )
+
     save_sh_command(output_folder / "command.sh")
     seed_everything(seed)

-    checkpoint_file = get_checkpoint(output_folder, resume_from)
-
     # reset datamodule with user configurable options
     datamodule.set_chunk_size(batch_size, batch_chunk_count)
     datamodule.drop_incomplete_batch = drop_incomplete_batch
@@ -307,25 +314,31 @@ def train(
     arguments["epoch"] = 0

     if checkpoint_file is None or not hasattr(model, "on_load_checkpoint"):
-        # Sets the model normalizer with the unaugmented-train-subset.
-        # this call may be a NOOP, if the model was pre-trained and expects
-        # different weights for the normalisation layer.
+        # Sets the model normalizer with the unaugmented-train-subset if we
+        # are starting from scratch and/or the model does not contain its own
+        # checkpoint loading strategy (e.g. a pytorch stock checkpoint). This
+        # call may be a NOOP, if the model comes from outside this framework
+        # and expects different weights for the normalisation layer.
         if hasattr(model, "set_normalizer"):
             model.set_normalizer(datamodule.unshuffled_train_dataloader())
         else:
             logger.warning(
-                f"Model {model.name} has no 'set_normalizer' method. Skipping."
+                f"Model {model.name} has no `set_normalizer` method. "
+                "Skipping normalization setup (unsupported external model)."
             )
     else:
         # Normalizer will be loaded during model.on_load_checkpoint
         checkpoint = torch.load(checkpoint_file)
         start_epoch = checkpoint["epoch"]
-        logger.info(f"Resuming from epoch {start_epoch}...")
+        logger.info(
+            f"Resuming from epoch {start_epoch} "
+            f"(checkpoint file: `{str(checkpoint_file)}`)..."
+        )

     run(
         model=model,
         datamodule=datamodule,
-        checkpoint_period=checkpoint_period,
+        validation_period=validation_period,
         device_manager=DeviceManager(device),
         max_epochs=epochs,
         output_folder=output_folder,

...
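With --resume-from removed, resuming becomes implicit: re-running training with the same --output-folder makes the script look for the newest periodic checkpoint in that folder and continue from it, or start from scratch if none is usable. A sketch of that workflow, modelled on the test shown further below; the "pasa" and "montgomery" configuration names, and an available dataset, are assumptions:

    import pathlib

    from click.testing import CliRunner

    from ptbench.scripts.train import train
    from ptbench.utils.checkpointer import get_checkpoint_to_resume_training

    runner = CliRunner()
    output_folder = pathlib.Path("results")  # any writable location

    base = [
        "pasa",        # model configuration (assumed name)
        "montgomery",  # datamodule configuration (assumed name)
        "--batch-size=1",
        f"--output-folder={output_folder}",
    ]

    # first run: the folder has no checkpoints, so training starts from scratch
    runner.invoke(train, base + ["--epochs=1"])

    # a periodic checkpoint is now present in the output folder...
    print(get_checkpoint_to_resume_training(output_folder))

    # ...so a second run over the same folder resumes from it automatically,
    # without any --resume-from flag
    runner.invoke(train, base + ["--epochs=2"])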
@@ -4,76 +4,146 @@
 import logging
 import pathlib
+import re
 import typing

 logger = logging.getLogger(__name__)

-def get_checkpoint(
-    output_folder: pathlib.Path,
-    resume_from: typing.Literal["last", "best"] | str | None,
-) -> str | None:
-    """Gets a checkpoint file.
-
-    Can return the best or last checkpoint, or a checkpoint at a specific path.
-    Ensures the checkpoint exists, raising an error if it is not the case.
-
-    If ``resume_from`` is ``None``, checks the output directory if a "last"
-    checkpoint file already exists and returns it. If no checkpoint is found,
-    returns ``None``.
-
-    ``resume_from`` can also be a path to an existing checkpoint file. In this
-    case, we check it and return if it exists.
-
-    Parameters
-    ----------
-    output_folder
-        Folder in which checkpoints are stored.
-
-    resume_from
-        Which model to get. Can be one of "best", "last", or a path to a checkpoint.
-        If ``None``, gets the last checkpoint if it exists, otherwise returns
-        ``None`` (signal to start from scratch).
-
-    Returns
-    -------
-        Path to the requested checkpoint (as a plain string) or ``None`` (start
-        from scratch).
-
-    Raises
-    ------
-    FileNotFoundError
-        In case a required file cannot be found.
-    """
-    # standard paths where checkpoints may be (if produced with this framework)
-    last_path = output_folder / "model_final_epoch.ckpt"
-    best_path = output_folder / "model_lowest_valid_loss.ckpt"
-
-    if resume_from in ("last", "best"):
-        use_file = last_path if resume_from == "last" else best_path
-        if use_file.is_file():
-            logger.info(f"Found checkpoint at `{str(use_file)}`")
-            return str(use_file)
-        else:
-            raise FileNotFoundError(
-                f"Could not find a checkpoint file at `{str(use_file)}`"
-            )
-
-    elif resume_from is None:
-        # use-case: user is re-starting a crashed/cancelled job
-        if last_path.is_file():
-            logger.info(f"Found checkpoint at `{str(last_path)}`")
-            return str(last_path)
-        else:
-            return None
-
-    elif isinstance(resume_from, str):
-        if pathlib.Path(resume_from).is_file():
-            logger.info(f"Found checkpoint at `{resume_from}`")
-            return resume_from
-        else:
-            raise FileNotFoundError(
-                f"Could not find a checkpoint file at `{resume_from}`"
-            )
+CHECKPOINT_ALIASES = {
+    "best": "model-at-lowest-validation-loss-{epoch}",
+    "periodic": "model-at-{epoch}",
+}
+"""Standard names checkpoints may have (if produced with this framework)."""
+
+CHECKPOINT_EXTENSION = ".ckpt"
+
+
+def _get_checkpoint_from_alias(
+    path: pathlib.Path,
+    alias: typing.Literal["best", "periodic"],
+) -> pathlib.Path:
+    """Gets an existing checkpoint file path.
+
+    This function searches for names matching the checkpoint alias "stem"
+    (i.e. the prefix), and then assumes the ``{epoch}`` placeholder, expanded
+    by lightning to ``epoch=<number>``, follows that prefix before the
+    expected file extension. The number is parsed and considered to be an
+    epoch number. The latest file (the file containing the highest epoch
+    number) is returned.
+
+    If only one file is present matching the alias characteristics, then it is
+    returned.
+
+    Parameters
+    ----------
+    path
+        Folder which may contain checkpoints.
+
+    alias
+        Can be one of "best" or "periodic".
+
+    Returns
+    -------
+        Path to the requested checkpoint.
+
+    Raises
+    ------
+    FileNotFoundError
+        In case it cannot find any file on the provided path matching the
+        given specifications.
+    """
+
+    template = path / (CHECKPOINT_ALIASES[alias] + CHECKPOINT_EXTENSION)
+
+    if template.exists():
+        return template
+
+    # otherwise, we see if we are looking for a template instead, in which
+    # case we must pick the latest file
+    assert "{epoch}" in str(
+        template
+    ), f"Template `{str(template)}` does not contain the keyword `{{epoch}}`"
+
+    pattern = re.compile(
+        template.name.replace("{epoch}", r"epoch=(?P<epoch>\d+)")
+    )
+    highest = -1
+    for f in template.parent.iterdir():
+        match = pattern.match(f.name)
+        if match is not None:
+            value = int(match.group("epoch"))
+            if value > highest:
+                highest = value
+
+    if highest != -1:
+        return template.with_name(
+            template.name.replace("{epoch}", f"epoch={highest}")
+        )
+
+    raise FileNotFoundError(
+        f"A file matching `{str(template)}` specifications was not found"
+    )
+
+
+def get_checkpoint_to_resume_training(
+    path: pathlib.Path,
+):
+    """Returns the best checkpoint file path to resume training from.
+
+    Parameters
+    ----------
+    path
+        The base directory containing the "periodic" checkpoint to resume the
+        training session from.
+
+    Returns
+    -------
+        Path to a checkpoint file that exists on disk.
+
+    Raises
+    ------
+    FileNotFoundError
+        If none of the checkpoints can be found on the provided directory.
+    """
+
+    return _get_checkpoint_from_alias(path, "periodic")
+
+
+def get_checkpoint_to_run_inference(
+    path: pathlib.Path,
+):
+    """Returns the best checkpoint file path to run inference with.
+
+    Parameters
+    ----------
+    path
+        The base directory containing either the "best" or the "periodic"
+        checkpoint to run the inference session with.
+
+    Returns
+    -------
+        Path to a checkpoint file that exists on disk.
+
+    Raises
+    ------
+    FileNotFoundError
+        If none of the checkpoints can be found on the provided directory.
+    """
+
+    try:
+        return _get_checkpoint_from_alias(path, "best")
+    except FileNotFoundError:
+        logger.error(
+            "Did not find lowest-validation-loss model to run inference "
+            "from. Trying to search for the last periodically saved model..."
+        )
+
+    return _get_checkpoint_from_alias(path, "periodic")
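To make the new naming scheme concrete, the sketch below fabricates the kind of files lightning produces from the alias templates (the "{epoch}" placeholder expands to "epoch=<N>") and shows how the helpers above resolve them. The folder contents are invented for the example:

    import pathlib
    import tempfile

    from ptbench.utils.checkpointer import (
        CHECKPOINT_EXTENSION,
        _get_checkpoint_from_alias,
        get_checkpoint_to_run_inference,
    )

    with tempfile.TemporaryDirectory() as d:
        path = pathlib.Path(d)

        # fabricate periodic checkpoints for epochs 0, 5 and 10
        for n in (0, 5, 10):
            (path / f"model-at-epoch={n}{CHECKPOINT_EXTENSION}").touch()

        # the "periodic" alias resolves to the highest epoch number found
        last = _get_checkpoint_from_alias(path, "periodic")
        assert last.name == "model-at-epoch=10.ckpt"

        # no lowest-validation-loss ("best") checkpoint exists here, so
        # inference falls back to the latest periodic one (an error is logged)
        assert get_checkpoint_to_run_inference(path) == last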
@@ -193,11 +193,15 @@ def test_compare_vis_help():
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_train_pasa_montgomery(temporary_basedir):
     from ptbench.scripts.train import train
+    from ptbench.utils.checkpointer import (
+        CHECKPOINT_EXTENSION,
+        _get_checkpoint_from_alias,
+    )

     runner = CliRunner()

     with stdout_logging() as buf:
-        output_folder = str(temporary_basedir / "results")
+        output_folder = temporary_basedir / "results"
         result = runner.invoke(
             train,
             [
@@ -206,17 +210,17 @@ def test_train_pasa_montgomery(temporary_basedir):
                 "-vv",
                 "--epochs=1",
                 "--batch-size=1",
-                f"--output-folder={output_folder}",
+                f"--output-folder={str(output_folder)}",
             ],
         )
         _assert_exit_0(result)

-        assert os.path.exists(
-            os.path.join(output_folder, "model_final_epoch.ckpt")
-        )
-        assert os.path.exists(
-            os.path.join(output_folder, "model_lowest_valid_loss.ckpt")
-        )
+        # asserts checkpoints are there, or raises FileNotFoundError
+        last = _get_checkpoint_from_alias(output_folder, "periodic")
+        assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
+        best = _get_checkpoint_from_alias(output_folder, "best")
+        assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
+
         assert os.path.exists(os.path.join(output_folder, "constants.csv"))

         assert (
             len(
@@ -254,10 +258,14 @@ def test_train_pasa_montgomery(temporary_basedir):
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
     from ptbench.scripts.train import train
+    from ptbench.utils.checkpointer import (
+        CHECKPOINT_EXTENSION,
+        _get_checkpoint_from_alias,
+    )

     runner = CliRunner()

-    output_folder = str(temporary_basedir / "results/pasa_checkpoint")
+    output_folder = temporary_basedir / "results" / "pasa_checkpoint"
     result0 = runner.invoke(
         train,
         [
@@ -266,15 +274,17 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
             "-vv",
             "--epochs=1",
             "--batch-size=1",
-            f"--output-folder={output_folder}",
+            f"--output-folder={str(output_folder)}",
         ],
     )
     _assert_exit_0(result0)

-    assert os.path.exists(os.path.join(output_folder, "model_final_epoch.ckpt"))
-    assert os.path.exists(
-        os.path.join(output_folder, "model_lowest_valid_loss.ckpt")
-    )
+    # asserts checkpoints are there, or raises FileNotFoundError
+    last = _get_checkpoint_from_alias(output_folder, "periodic")
+    assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
+    best = _get_checkpoint_from_alias(output_folder, "best")
+    assert best.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
+
     assert os.path.exists(os.path.join(output_folder, "constants.csv"))

     assert (
         len(
@@ -301,12 +311,11 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
     )
     _assert_exit_0(result)

-    assert os.path.exists(
-        os.path.join(output_folder, "model_final_epoch.ckpt")
-    )
-    assert os.path.exists(
-        os.path.join(output_folder, "model_lowest_valid_loss.ckpt")
-    )
+    # asserts checkpoints are there, or raises FileNotFoundError
+    last = _get_checkpoint_from_alias(output_folder, "periodic")
+    assert last.name.endswith("epoch=1" + CHECKPOINT_EXTENSION)
+    best = _get_checkpoint_from_alias(output_folder, "best")
+
     assert os.path.exists(os.path.join(output_folder, "constants.csv"))

     assert (
@@ -348,11 +357,19 @@ def test_train_pasa_montgomery_from_checkpoint(temporary_basedir):
 @pytest.mark.skip_if_rc_var_not_set("datadir.montgomery")
 def test_predict_pasa_montgomery(temporary_basedir, datadir):
     from ptbench.scripts.predict import predict
+    from ptbench.utils.checkpointer import (
+        CHECKPOINT_EXTENSION,
+        _get_checkpoint_from_alias,
+    )

     runner = CliRunner()

     with stdout_logging() as buf:
-        output = str(temporary_basedir / "predictions")
+        output = temporary_basedir / "predictions"
+        last = _get_checkpoint_from_alias(
+            temporary_basedir / "results", "periodic"
+        )
+        assert last.name.endswith("epoch=0" + CHECKPOINT_EXTENSION)
         result = runner.invoke(
             predict,
             [
@@ -360,7 +377,7 @@ def test_predict_pasa_montgomery(temporary_basedir, datadir):
                 "montgomery",
                 "-vv",
                 "--batch-size=1",
-                f"--weight={str(temporary_basedir / 'results' / 'model_final_epoch.ckpt')}",
+                f"--weight={str(last)}",
                 f"--output={output}",
             ],
         )