diff --git a/bob/pipelines/wrappers.py b/bob/pipelines/wrappers.py index 3d6500d89677634e1bac36f3a0a231be49cc29f6..bb9941f8f5e380279e1267757b1991b943e03d3f 100644 --- a/bob/pipelines/wrappers.py +++ b/bob/pipelines/wrappers.py @@ -1,8 +1,10 @@ """Scikit-learn Estimator Wrappers.""" import logging import os +import tempfile from functools import partial +from pathlib import Path import cloudpickle import dask.bag @@ -393,9 +395,17 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin): to_save = getattr(sample, self.sample_attribute) for _ in range(self.attempts): try: - os.makedirs(os.path.dirname(path), exist_ok=True) - self.save_func(to_save, path) + dirname = os.path.dirname(path) + os.makedirs(dirname, exist_ok=True) + + # Atomic writing + extension = "".join(Path(path).suffixes) + with tempfile.NamedTemporaryFile( + dir=dirname, delete=False, suffix=extension + ) as f: + self.save_func(to_save, f.name) + os.replace(f.name, path) # test loading self.load_func(path)