diff --git a/src/ptbench/utils/image.py b/src/ptbench/utils/image.py new file mode 100644 index 0000000000000000000000000000000000000000..363a8309f4581dfdda42124c0562d3bf840904af --- /dev/null +++ b/src/ptbench/utils/image.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch> +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import os + +from typing import Union + +import torch + +from PIL.Image import Image +from torchvision import transforms + + +def save_image(img: Union[torch.Tensor, Image], filepath: str) -> None: + """Saves a PIL image or a tensor as an image at the specified destination. + + Parameters + ---------- + + img: + A torch.Tensor or PIL.Image to save + + filepath: + The file in which to save the image. The format is inferred from the file extension, or defaults to png if not specified. + """ + + if isinstance(img, torch.Tensor): + img = transforms.ToPILImage()(img) + + root, ext = os.path.splitext(filepath) + + if len(ext) == 0: + filepath = filepath + ".png" + + img.save(filepath)