Skip to content
Snippets Groups Projects

epoch checkpointing fix

Merged Özgür Acar Güler requested to merge checkpointing_fix into main
4 unresolved threads
2 files
+ 82
− 38
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -288,13 +288,6 @@ def set_reproducible_cuda():
default=-1,
cls=ResourceOption,
)
@click.option(
"--weight",
"-w",
help="Path or URL to pretrained model file (.pth extension)",
required=False,
cls=ResourceOption,
)
@click.option(
"--normalization",
"-n",
@@ -337,7 +330,6 @@ def train(
device,
seed,
parallel,
weight,
normalization,
monitoring_interval,
**_,
@@ -363,7 +355,6 @@ def train(
from ..configs.datasets import get_positive_weights, get_samples_weights
from ..engine.trainer import run
from ..utils.checkpointer import Checkpointer
from ..utils.download import download_to_tempfile
device = setup_pytorch_device(device)
@@ -528,29 +519,7 @@ def train(
# Initialize epoch information
arguments = {}
arguments["epoch"] = 0
# Load pretrained weights if needed
if weight is not None:
if checkpointer.has_checkpoint():
logger.warning(
"Weights are being ignored because a checkpoint already exists. "
"Weights from checkpoint will be loaded instead."
)
extra_checkpoint_data = checkpointer.load()
else:
if weight.startswith("http"):
logger.info(f"Temporarily downloading '{weight}'...")
f = download_to_tempfile(weight, progress=True)
weight_fullpath = os.path.abspath(f.name)
else:
weight_fullpath = os.path.abspath(weight)
extra_checkpoint_data = checkpointer.load(
weight_fullpath, strict=False
)
else:
extra_checkpoint_data = checkpointer.load()
# Update epoch information with checkpoint data
extra_checkpoint_data = checkpointer.load()
arguments.update(extra_checkpoint_data)
arguments["max_epoch"] = epochs
Loading