Implemented a mechanism in the Checkpoint wrapper that asserts if data was...

Implemented a mechanism in the Checkpoint wrapper that asserts if data was properlly written in the disk
parent 0b743b85
Pipeline #50314 failed with stage
in 3 minutes and 32 seconds
......@@ -141,8 +141,7 @@ class SampleWrapper(BaseWrapper, TransformerMixin):
if isinstance(samples[0], SampleSet):
return [
SampleSet(
self._samples_transform(sset.samples, method_name),
parent=sset,
self._samples_transform(sset.samples, method_name), parent=sset,
)
for sset in samples
]
......@@ -238,6 +237,11 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
This is useful when is desirable file directories with less than
a certain number of files.
attempts
Number of checkpoint attempts. Sometimes, because of network/disk issues
files can't be saved. This argument sets the maximum number of attempts
to checkpoint a sample.
"""
def __init__(
......@@ -250,6 +254,7 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
load_func=None,
sample_attribute="data",
hash_fn=None,
attempts=5,
**kwargs,
):
super().__init__(**kwargs)
......@@ -269,6 +274,7 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
)
self.sample_attribute = sample_attribute
self.hash_fn = hash_fn
self.attempts = attempts
if model_path is None and features_dir is None:
logger.warning(
"Both model_path and features_dir are None. "
......@@ -370,10 +376,17 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
os.makedirs(os.path.dirname(path), exist_ok=True)
# Gets sample.data or sample.<sample_attribute> if specified
to_save = getattr(sample, self.sample_attribute)
try:
self.save_func(to_save, path)
except Exception as e:
raise RuntimeError(f"Could not save {to_save} duing {self}.save") from e
for _ in range(self.attempts):
try:
self.save_func(to_save, path)
# test loading
loaded = self.load_func(path)
break
except:
pass
else:
raise RuntimeError(f"Could not save {to_save} doing {self}.save")
def load(self, sample, path):
# because we are checkpointing, we return a DelayedSample
......@@ -424,11 +437,7 @@ class DaskWrapper(BaseWrapper, TransformerMixin):
"""
def __init__(
self,
estimator,
fit_tag=None,
transform_tag=None,
**kwargs,
self, estimator, fit_tag=None, transform_tag=None, **kwargs,
):
super().__init__(**kwargs)
self.estimator = estimator
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment