Commit 755e6cdf authored by Amir MOHAMMADI's avatar Amir MOHAMMADI
Browse files

[CheckpointWrapper] Use correct extension during atomic writing

parent 647cd8ad
Pipeline #53133 passed with stage
in 16 minutes and 16 seconds
......@@ -4,6 +4,7 @@ import os
import tempfile
from functools import partial
from pathlib import Path
import cloudpickle
import dask.bag
......@@ -394,11 +395,15 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin):
to_save = getattr(sample, self.sample_attribute)
for _ in range(self.attempts):
try:
dirname = os.path.dirname(path)
os.makedirs(dirname, exist_ok=True)
# Atomic writing
with tempfile.NamedTemporaryFile(dir=dirname, delete=False) as f:
extension = "".join(Path(path).suffixes)
with tempfile.NamedTemporaryFile(
dir=dirname, delete=False, suffix=extension
) as f:
self.save_func(to_save, f.name)
os.replace(f.name, path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment