Commit 6d198499 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira

Merge branch 'attempts' into 'master'

Implemented a mechanism in the Checkpoint wrapper that asserts if data was...

Closes #31

See merge request !63
parents 0b743b85 2230ecdd
Pipeline #50741 passed with stages
in 8 minutes and 52 seconds
......@@ -238,6 +238,11 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
This is useful when is desirable file directories with less than
a certain number of files.
attempts
Number of checkpoint attempts. Sometimes, because of network/disk issues
files can't be saved. This argument sets the maximum number of attempts
to checkpoint a sample.
"""
def __init__(
......@@ -250,6 +255,7 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
load_func=None,
sample_attribute="data",
hash_fn=None,
attempts=5,
**kwargs,
):
super().__init__(**kwargs)
......@@ -269,6 +275,7 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
)
self.sample_attribute = sample_attribute
self.hash_fn = hash_fn
self.attempts = attempts
if model_path is None and features_dir is None:
logger.warning(
"Both model_path and features_dir are None. "
......@@ -370,10 +377,17 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
os.makedirs(os.path.dirname(path), exist_ok=True)
# Gets sample.data or sample.<sample_attribute> if specified
to_save = getattr(sample, self.sample_attribute)
try:
self.save_func(to_save, path)
except Exception as e:
raise RuntimeError(f"Could not save {to_save} duing {self}.save") from e
for _ in range(self.attempts):
try:
self.save_func(to_save, path)
# test loading
self.load_func(path)
break
except Exception:
pass
else:
raise RuntimeError(f"Could not save {to_save} doing {self}.save")
def load(self, sample, path):
# because we are checkpointing, we return a DelayedSample
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment