
Implemented a mechanism in the Checkpoint wrapper that asserts if data was...

Merged Tiago de Freitas Pereira requested to merge attempts into master
+18 −4
@@ -238,6 +238,11 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
        This is useful when it is desirable to have file directories with
        fewer than a certain number of files.
    attempts
        Number of checkpoint attempts. Sometimes, because of network/disk
        issues, files can't be saved. This argument sets the maximum number
        of attempts to checkpoint a sample.
    """
    def __init__(
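For context, here is a minimal sketch of how the new parameter would be passed when wrapping an estimator. The transformer, the directory, and the assumption that `CheckpointWrapper` takes the wrapped estimator as its first positional argument are illustrative, not taken from this diff:

```python
# Hypothetical usage sketch; the estimator and directory are placeholders.
from sklearn.preprocessing import FunctionTransformer

transformer = CheckpointWrapper(
    FunctionTransformer(),            # any scikit-learn transformer
    features_dir="/tmp/checkpoints",  # where transformed samples are saved
    attempts=5,                       # retry each sample's save up to 5 times
)
```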
@@ -250,6 +255,7 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
        load_func=None,
        sample_attribute="data",
        hash_fn=None,
        attempts=5,
        **kwargs,
    ):
        super().__init__(**kwargs)
@@ -269,6 +275,7 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
        )
        self.sample_attribute = sample_attribute
        self.hash_fn = hash_fn
        self.attempts = attempts
        if model_path is None and features_dir is None:
            logger.warning(
                "Both model_path and features_dir are None. "
@@ -370,10 +377,17 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        # Gets sample.data or sample.<sample_attribute> if specified
        to_save = getattr(sample, self.sample_attribute)
        try:
            self.save_func(to_save, path)
        except Exception as e:
            raise RuntimeError(f"Could not save {to_save} duing {self}.save") from e
        last_exception = None
        for _ in range(self.attempts):
            try:
                self.save_func(to_save, path)
                # verify the checkpoint by loading it back
                self.load_func(path)
                break
            except Exception as e:  # e.g. transient network/disk failures
                last_exception = e
        else:
            raise RuntimeError(
                f"Could not save {to_save} during {self}.save"
            ) from last_exception
    def load(self, sample, path):
        # because we are checkpointing, we return a DelayedSample
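The retry-and-verify pattern this hunk introduces can also be sketched in isolation. Everything below — the function name, the pickle-based `save_func`/`load_func`, and the paths — is illustrative and assumed, not taken from the wrapper itself:

```python
import pickle


def checkpoint_with_retries(value, path, save_func, load_func, attempts=5):
    """Sketch of the retry-and-verify logic added to ``save`` in this MR."""
    last_exception = None
    for _ in range(attempts):
        try:
            save_func(value, path)
            load_func(path)  # a save only counts if it can be read back
            return
        except Exception as e:  # e.g. transient network/disk failures
            last_exception = e
    raise RuntimeError(f"Could not save {value} to {path}") from last_exception


def save_pickle(value, path):
    with open(path, "wb") as f:
        pickle.dump(value, f)


def load_pickle(path):
    with open(path, "rb") as f:
        return pickle.load(f)


checkpoint_with_retries([1, 2, 3], "/tmp/sample.pkl", save_pickle, load_pickle)
```

The diff itself uses `for`/`else` instead of an early `return`: the `else` branch of a `for` loop runs only when the loop completes without hitting `break`, i.e. when every attempt failed.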