diff --git a/src/bob/bio/base/database/utils.py b/src/bob/bio/base/database/utils.py
index 17e6bfd471920d880433143d4207ba6891e11019..eddb4663adc75ba30a37c4979052aae139c58eb9 100644
--- a/src/bob/bio/base/database/utils.py
+++ b/src/bob/bio/base/database/utils.py
@@ -131,7 +131,9 @@ def search_in_archive_and_open(
                 break
         else:
             logger.debug(
-                f"No file matching '{pattern}' were found in '{archive_path}'."
+                "No file matching '%s' were found in '%s'.",
+                pattern,
+                archive_path,
             )
             return None
 
@@ -146,7 +148,9 @@ def search_in_archive_and_open(
                 break
         else:
             logger.debug(
-                f"No file matching '{pattern}' were found in '{archive_path}'."
+                "No file matching '%s' were found in '%s'.",
+                pattern,
+                archive_path,
             )
 
     return zip_arch.open(name)
@@ -274,7 +278,8 @@ def extract_archive(
     elif ".bz2" == archive_path.suffix:
         if sub_dir is not None:
             warnings.warn(
-                f"Ignored sub directory ({sub_dir}). Not supported for `.bz2` files.",
+                f"Ignored sub directory ({sub_dir}). Not supported for `.bz2` "
+                "files.",
                 RuntimeWarning,
             )
         extracted_file = destination / Path(archive_path.stem)
@@ -298,8 +303,8 @@ def search_and_open(
 ) -> Union[IO[bytes], TextIO, None]:
     """Searches for a matching file recursively in a directory.
 
-    If ``base_dir`` points to an archive, the pattern will be searched inside that
-    archive.
+    If ``base_dir`` points to an archive, the pattern will be searched inside
+    that archive.
 
     Wildcards (``*``, ``?``, and ``**``) are supported (using
     :meth:`pathlib.Path.glob`).
@@ -466,7 +471,8 @@ def _infer_filename_from_urls(urls=Union[list[str], str]) -> str:
     names = [u.split("/")[-1] for u in urls]
     if not all(n == names[0] for n in names):
         raise ValueError(
-            f"Cannot infer file name when urls point to different files ({names=})."
+            "Cannot infer file name when urls point to different files "
+            f"({names=})."
         )
 
     return urls[0].split("/")[-1]
@@ -481,6 +487,7 @@ def download_file(
     force: bool = False,
     extract: bool = False,
     makedirs: bool = True,
+    checksum_mismatch_download_attempts: int = 2,
 ) -> Path:
     """Downloads a remote file locally.
 
@@ -489,18 +496,18 @@ def download_file(
     Parameters
     ----------
     urls
-        The remote location of the server. If multiple addresses are given, we will try
-        to download from them in order until one succeeds.
+        The remote location of the server. If multiple addresses are given, we
+        will try to download from them in order until one succeeds.
     destination_directory
-        A path to a local directory where the file will be saved. If omitted, the file
-        will be saved in the folder pointed by the ``wdr.local_directory`` key in the
-        user configuration.
+        A path to a local directory where the file will be saved. If omitted,
+        the file will be saved in the folder pointed by the ``bob_data_dir`` key
+        in the user configuration.
     destination_sub_directory
-        An additional layer added to the destination directory (useful when using
-        ``destination_directory=None``).
+        An additional layer added to the destination directory (useful when
+        using ``destination_directory=None``).
     destination_filename
-        The final name of the local file. If omitted, the file will keep the name of
-        the remote file.
+        The final name of the local file. If omitted, the file will keep the
+        name of the remote file.
     checksum
         When provided, will compute the file's checksum and compare to this.
     force
@@ -510,19 +517,24 @@ def download_file(
         If this is set, the parent directory path will be returned.
     makedirs
         Automatically make the parent directories of the new local file.
+    checksum_mismatch_download_attempts
+        Number of download attempts when the checksum does not match after
+        downloading, must be 1 or more.
 
     Returns
     -------
-    The path to the new local file (or the parent directory if ``extract`` is True).
+    The path to the new local file (or the parent directory if ``extract`` is
+    True).
 
     Raises
     ------
     RuntimeError
         When the URLs provided are all invalid.
     ValueError
-        When ``destination_filename`` is omitted and URLs point to files with different
-        names.
-        When the checksum of the file does not correspond to the provided ``checksum``.
+        - When ``destination_filename`` is omitted and URLs point to files with
+          different names.
+        - When the checksum of the file does not correspond to the provided
+          ``checksum``.
     """
 
     if destination_filename is None:
@@ -538,65 +550,92 @@ def download_file(
             destination_directory / destination_sub_directory
         )
 
+    if checksum_mismatch_download_attempts < 1:
+        logger.warning(
+            "'Checksum_mismatch_download_attempts' must be greater than 0 "
+            "(got %d). Setting it to 1.",
+            checksum_mismatch_download_attempts,
+        )
+        checksum_mismatch_download_attempts = 1
+
     local_file = destination_directory / destination_filename
     needs_download = True
 
     if not force and local_file.is_file():
         logger.info(
-            f"File {local_file} already exists, skipping download ({force=})."
+            "File %s already exists, skipping download (force=%s).",
+            local_file,
+            force,
         )
         needs_download = False
 
     if needs_download:
-        if isinstance(urls, str):
-            urls = [urls]
-
-        for tries, url in enumerate(urls):
-            logger.debug(f"Retrieving file from '{url}'.")
-            try:
-                response = requests.get(url=url, timeout=10)
-            except requests.exceptions.ConnectionError as e:
-                if tries < len(urls) - 1:
+        for current_download_try in range(checksum_mismatch_download_attempts):
+            if isinstance(urls, str):
+                urls = [urls]
+
+            for tries, url in enumerate(urls):
+                logger.debug("Retrieving file from '%s'.", url)
+                try:
+                    response = requests.get(url=url, timeout=10)
+                except requests.exceptions.ConnectionError as e:
+                    if tries < len(urls) - 1:
+                        logger.info(
+                            "Could not connect to %s. Trying other URLs.",
+                            url,
+                        )
+                        logger.debug(e)
+                        continue
+
+                logger.debug(
+                    "http response: '%d: %s'.",
+                    response.status_code,
+                    response.reason,
+                )
+
+                if response.ok:
+                    logger.debug("Got file from %s.", url)
+                    break
+                elif tries < len(urls) - 1:
                     logger.info(
-                        f"Could not connect to {url}. Trying other URLs."
+                        "Failed to get file from %s, trying other URLs.", url
                     )
-                    logger.debug(e)
-                    continue
+                    logger.debug("requests.response was:\n%s", response)
+                else:
+                    raise RuntimeError(
+                        "Could not retrieve file from any of the provided URLs! "
+                        f"({urls=})"
                     )
-            logger.debug(
-                f"http response: '{response.status_code}: {response.reason}'."
+            if makedirs:
+                local_file.parent.mkdir(parents=True, exist_ok=True)
+
+            with local_file.open("wb") as f:
+                f.write(response.content)
+
+            # Check the created file integrity, re-download if needed
+            if checksum is None or verify_file(
+                local_file, checksum, hash_fct=checksum_fct
+            ):
+                break  # Exit the re-download loop
+            logger.warning(
+                "Downloading %s created a file with a wrong checksum. Retry %d",
+                url,
+                current_download_try + 1,
             )
-
-            if response.ok:
-                logger.debug(f"Got file from {url}.")
-                break
-            elif tries < len(urls) - 1:
-                logger.info(
-                    f"Failed to get file from {url}, trying other URLs."
+            if current_download_try >= checksum_mismatch_download_attempts - 1:
+                raise ValueError(
+                    "The downloaded file hash "
+                    f"({compute_crc(local_file, hash_fct=checksum_fct)}) for "
+                    f"'{url}' does not correspond to '{checksum}', even after "
+                    f"{checksum_mismatch_download_attempts} retries."
                 )
-                logger.debug(f"requests.response was\n{response}")
-        else:
-            raise RuntimeError(
-                f"Could not retrieve file from any of the provided URLs! ({urls=})"
-            )
-
-    if makedirs:
-        local_file.parent.mkdir(parents=True, exist_ok=True)
-
-    with local_file.open("wb") as f:
-        f.write(response.content)
-
-    if checksum is not None:
+    elif checksum is not None:
         if not verify_file(local_file, checksum, hash_fct=checksum_fct):
-            if not needs_download:
-                raise ValueError(
-                    f"The local file hash does not correspond to '{checksum}' "
-                    f"and {force=} prevents overwriting."
-                )
             raise ValueError(
-                "The downloaded file hash ('"
-                f"{compute_crc(local_file, hash_fct=checksum_fct)}') does not "
-                f"correspond to '{checksum}'."
+                f"The local file hash does not correspond to '{checksum}' and "
+                f"{force=} prevents overwriting."
             )
 
     if extract:
diff --git a/tests/test_download.py b/tests/test_download.py
new file mode 100644
index 0000000000000000000000000000000000000000..949122b115dcb450e6708bf525cdab08c69d9ff7
--- /dev/null
+++ b/tests/test_download.py
@@ -0,0 +1,183 @@
+import tempfile
+
+from pathlib import Path
+
+import pytest
+
+from clapper.rc import UserDefaults
+
+from bob.bio.base.database.utils import download_file
+
+RESOURCE_URL = "https://www.idiap.ch/software/bob/databases/latest/base/atnt-f529acef.tar.gz"
+RESOURCE_NAME = "atnt-f529acef.tar.gz"
+RESOURCE_EXTRACTED_NAME = "atnt"
+RESOURCE_CHECKSUM = "f529acef"
+INVALID_URL_VALID_NAME = (
+    "https://localhost/ysnctp/not/a/valid/path/atnt-f529acef.tar.gz"
+)
+INVALID_URL_INVALID_NAME = "https://localhost/ysnctp/not/a/valid/path"
+
+
+def _create_custom_rc(rc_path: Path, **kwargs):
+    """This creates a config file dynamically, with the content of kwargs."""
+    rc_path.parent.mkdir(exist_ok=True)
+    rc = UserDefaults(rc_path)
+    for k, v in kwargs.items():
+        rc[k] = v
+    rc.write()
+
+
+def test_download_file_defaults(monkeypatch: pytest.MonkeyPatch):
+    "Downloads to bob_data_dir, with all default settings."
+    with tempfile.TemporaryDirectory(prefix="test_download_") as tmp_dir:
+        dir_path = Path(tmp_dir)
+        data_path = dir_path / "bob_data"
+        monkeypatch.setenv("HOME", dir_path.as_posix())
+        expected_result = data_path / RESOURCE_NAME
+        local_filename = download_file(urls=RESOURCE_URL)
+        assert local_filename == expected_result
+        assert local_filename.is_file()
+
+
+def test_download_file_custom_data_dir_no_subdir(
+    monkeypatch: pytest.MonkeyPatch,
+):
+    "Downloads to a custom bob_data_dir, with all default settings."
+    with tempfile.TemporaryDirectory(prefix="test_download_") as tmp_dir:
+        dir_path = Path(tmp_dir)
+        data_path = dir_path / "custom_bob_data"
+        rc_path = dir_path / ".config" / "bobrc.toml"
+        _create_custom_rc(rc_path=rc_path, bob_data_dir=data_path.as_posix())
+        monkeypatch.setenv("HOME", dir_path.as_posix())
+        expected_result = data_path / RESOURCE_NAME
+        local_filename = download_file(urls=RESOURCE_URL)
+        assert local_filename == expected_result
+        assert local_filename.is_file()
+
+
+def test_download_file_custom_data_dir_and_subdir(
+    monkeypatch: pytest.MonkeyPatch,
+):
+    "Downloads to a custom bob_data_dir, into a custom sub-directory."
+    with tempfile.TemporaryDirectory(prefix="test_download_") as tmp_dir:
+        dir_path = Path(tmp_dir)
+        data_path = dir_path / "custom_bob_data"
+        rc_path = dir_path / ".config" / "bobrc.toml"
+        _create_custom_rc(rc_path=rc_path, bob_data_dir=data_path.as_posix())
+        monkeypatch.setenv("HOME", dir_path.as_posix())
+        subdir = Path("download") / "subdir"
+        expected_result = data_path / subdir / RESOURCE_NAME
+        local_filename = download_file(
+            urls=RESOURCE_URL, destination_sub_directory=subdir
+        )
+        assert local_filename == expected_result
+        assert local_filename.is_file()
+
+
+def test_download_file_to_dir_no_subdir():
+    with tempfile.TemporaryDirectory(prefix="test_download_") as tmp_dir:
+        destination = Path(tmp_dir) / "download_dir"
+        expected_result = destination / RESOURCE_NAME
+        local_filename = download_file(
+            urls=RESOURCE_URL,
+            destination_directory=destination,
+        )
+        assert local_filename == expected_result
+        assert local_filename.is_file()
+
+
+def test_download_file_to_dir_and_subdir():
+    with tempfile.TemporaryDirectory(prefix="test_download_") as tmp_dir:
+        destination = Path(tmp_dir)
+        subdir = Path("download") / "subdir"
+        expected_result = destination / subdir / RESOURCE_NAME
+        local_filename = download_file(
+            urls=RESOURCE_URL,
+            destination_directory=destination,
+            destination_sub_directory=subdir,
+        )
+        assert local_filename == expected_result
+        assert local_filename.is_file()
+
+
+def test_download_file_rename():
+    with tempfile.TemporaryDirectory(prefix="test_download_") as tmp_dir:
+        destination = Path(tmp_dir)
+        subdir = Path("download") / "subdir"
+        new_name = "custom_name.tar.gz"
+        expected_result = destination / subdir / new_name
+        local_filename = download_file(
+            urls=RESOURCE_URL,
+            destination_directory=destination,
+            destination_sub_directory=subdir,
+            destination_filename=new_name,
+        )
+        assert local_filename == expected_result
+        assert local_filename.is_file()
+
+
+def test_download_file_with_checksum():
+    with tempfile.TemporaryDirectory(prefix="test_download_") as tmp_dir:
+        destination = Path(tmp_dir)
+        expected_result = destination / RESOURCE_NAME
+        local_filename = download_file(
+            urls=RESOURCE_URL,
+            destination_directory=destination,
+            checksum=RESOURCE_CHECKSUM,
+        )
+        assert local_filename == expected_result
+        assert local_filename.is_file()
+
+
+def test_download_file_multi_url_valid_names():
+    with tempfile.TemporaryDirectory(prefix="test_download_") as tmp_dir:
+        destination = Path(tmp_dir)
+        expected_result = destination / RESOURCE_NAME
+        local_filename = download_file(
+            urls=[INVALID_URL_VALID_NAME, RESOURCE_URL],
+            destination_directory=destination,
+            checksum=RESOURCE_CHECKSUM,
+        )
+        assert local_filename == expected_result
+        assert local_filename.is_file()
+
+
+def test_download_file_multi_url_invalid_names():
+    with tempfile.TemporaryDirectory(prefix="test_download_") as tmp_dir:
+        destination = Path(tmp_dir)
+        with pytest.raises(ValueError):
+            download_file(
+                urls=[RESOURCE_URL, INVALID_URL_INVALID_NAME],
+                destination_directory=destination,
+                checksum=RESOURCE_CHECKSUM,
+            )
+
+
+def test_download_file_extract_no_subdir():
+    with tempfile.TemporaryDirectory(prefix="test_download_") as tmp_dir:
+        destination = Path(tmp_dir)
+        expected_result = destination
+        local_filename = download_file(
+            urls=RESOURCE_URL,
+            destination_directory=destination,
+            checksum=RESOURCE_CHECKSUM,
+            extract=True,
+        )
+        assert local_filename == expected_result
+        assert (local_filename / RESOURCE_EXTRACTED_NAME).is_dir()
+
+
+def test_download_file_extract_with_subdir():
+    with tempfile.TemporaryDirectory(prefix="test_download_") as tmp_dir:
+        destination = Path(tmp_dir)
+        subdir = Path("download") / "subdir"
+        expected_result = destination / subdir
+        local_filename = download_file(
+            urls=RESOURCE_URL,
+            destination_directory=destination,
+            destination_sub_directory=subdir,
+            checksum=RESOURCE_CHECKSUM,
+            extract=True,
+        )
+        assert local_filename == expected_result
+        assert (local_filename / RESOURCE_EXTRACTED_NAME).is_dir()
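
Usage note (not part of the diff): a minimal sketch of how the changed
`download_file` API could be called, assuming the signature shown in the diff
above. The URL and checksum are the ones used by the new tests; the
destination directory is a hypothetical example.

    from pathlib import Path

    from bob.bio.base.database.utils import download_file

    # Download the AT&T archive, verify its checksum, extract it, and retry
    # the download up to 3 times if the downloaded file's checksum mismatches.
    data_dir = download_file(
        urls="https://www.idiap.ch/software/bob/databases/latest/base/atnt-f529acef.tar.gz",
        destination_directory=Path("/tmp/bob_data"),  # hypothetical location
        checksum="f529acef",
        extract=True,
        checksum_mismatch_download_attempts=3,
    )
    # With extract=True, the returned path is the destination directory, which
    # should now contain the extracted "atnt" folder (see the tests above).
    print(data_dir / "atnt")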