Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import logging
import os
# One logger per module, keyed by the module's import name.
logger = logging.getLogger(name=__name__)
def get_checkpoint(output_folder, resume_from):
    """Resolve which checkpoint file, if any, training should resume from.

    Args:
        output_folder: Directory in which the symbolic checkpoints
            ``model_final_epoch.ckpt`` and ``model_lowest_valid_loss.ckpt``
            are looked up.
        resume_from: Selects the checkpoint:

            - ``"last"``: ``<output_folder>/model_final_epoch.ckpt``
            - ``"best"``: ``<output_folder>/model_lowest_valid_loss.ckpt``
            - ``None``: do not resume
            - anything else: treated as an explicit path to a checkpoint file

    Returns:
        The path of the checkpoint file to resume from, or ``None`` when
        ``resume_from`` is ``None``.

    Raises:
        FileNotFoundError: If the selected checkpoint is not an existing file.
    """
    if resume_from is None:
        return None

    # Symbolic names map onto fixed file names inside ``output_folder``.
    # A single table keeps "last" and "best" from drifting apart (they used
    # to be copy-pasted branches, and "best" returned the "last" path).
    named_checkpoints = {
        "last": os.path.join(output_folder, "model_final_epoch.ckpt"),
        "best": os.path.join(output_folder, "model_lowest_valid_loss.ckpt"),
    }

    if resume_from in named_checkpoints:
        checkpoint_file = named_checkpoints[resume_from]
        if not os.path.isfile(checkpoint_file):
            raise FileNotFoundError(f"Could not find checkpoint {checkpoint_file}")
        logger.info("Resuming training from %s checkpoint", resume_from)
        return checkpoint_file

    # Any other value is taken as an explicit path to a checkpoint file.
    if not os.path.isfile(resume_from):
        raise FileNotFoundError(f"Could not find checkpoint {resume_from}")
    logger.info("Resuming training from checkpoint %s", resume_from)
    return resume_from