From 647cd8ad0f5238b23bd2a5bd1d400d6457582c2a Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI Date: Thu, 12 Aug 2021 14:55:29 +0200 Subject: [PATCH 1/2] [CheckpointWrapper] Use atomic writing when saving features --- bob/pipelines/wrappers.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/bob/pipelines/wrappers.py b/bob/pipelines/wrappers.py index 3d6500d..bdea5a8 100644 --- a/bob/pipelines/wrappers.py +++ b/bob/pipelines/wrappers.py @@ -1,6 +1,7 @@ """Scikit-learn Estimator Wrappers.""" import logging import os +import tempfile from functools import partial @@ -393,9 +394,13 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin): to_save = getattr(sample, self.sample_attribute) for _ in range(self.attempts): try: - os.makedirs(os.path.dirname(path), exist_ok=True) + dirname = os.path.dirname(path) + os.makedirs(dirname, exist_ok=True) - self.save_func(to_save, path) + # Atomic writing + with tempfile.NamedTemporaryFile(dir=dirname, delete=False) as f: + self.save_func(to_save, f.name) + os.replace(f.name, path) # test loading self.load_func(path) -- GitLab From 755e6cdfb877a49899f53f774283f160f697afe0 Mon Sep 17 00:00:00 2001 From: Amir MOHAMMADI Date: Thu, 12 Aug 2021 15:18:11 +0200 Subject: [PATCH 2/2] [CheckpointWrapper] Use correct extension during atomic writing --- bob/pipelines/wrappers.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/bob/pipelines/wrappers.py b/bob/pipelines/wrappers.py index bdea5a8..bb9941f 100644 --- a/bob/pipelines/wrappers.py +++ b/bob/pipelines/wrappers.py @@ -4,6 +4,7 @@ import os import tempfile from functools import partial +from pathlib import Path import cloudpickle import dask.bag @@ -394,11 +395,15 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin): to_save = getattr(sample, self.sample_attribute) for _ in range(self.attempts): try: + dirname = os.path.dirname(path) os.makedirs(dirname, exist_ok=True) # Atomic writing - with tempfile.NamedTemporaryFile(dir=dirname, delete=False) as f: + extension = "".join(Path(path).suffixes) + with tempfile.NamedTemporaryFile( + dir=dirname, delete=False, suffix=extension + ) as f: self.save_func(to_save, f.name) os.replace(f.name, path) -- GitLab