def _unzip(zip_file, directory):
    """Extract every member of a zip archive into ``directory``.

    Parameters
    ----------
    zip_file : str
        Path to the zip archive.
    directory : str
        Folder the archive members are extracted into.
    """
    with zipfile.ZipFile(zip_file) as myzip:
        myzip.extractall(directory)


def _unbz2(bz2_file):
    """Decompress a bz2 file next to itself.

    The output path is ``bz2_file`` with its last extension removed
    (``data.txt.bz2`` -> ``data.txt``).

    Parameters
    ----------
    bz2_file : str
        Path to the bz2-compressed file.
    """
    out_file = os.path.splitext(bz2_file)[0]
    # Both handles are closed deterministically, and the payload is streamed
    # in chunks rather than read into memory in one piece.
    with bz2.BZ2File(bz2_file) as src, open(out_file, "wb") as dst:
        copyfileobj(src, dst)


def list_folders(base_path):
    """List the names of the top-level folders found in ``base_path``.

    Parameters
    ----------
    base_path : str or pathlib.Path
        Either a directory or a tarball.

    Returns
    -------
    list of str
        Sorted folder names. For a directory, these are its immediate
        sub-directories. For a tarball, these are the directory members whose
        parent is the common path shared by all members of the archive. An
        empty tarball yields an empty list.

    Raises
    ------
    ValueError
        If ``base_path`` is neither a directory nor a tarball.
    """
    base_dir = Path(base_path)
    if base_dir.is_dir():
        return sorted(x.name for x in base_dir.iterdir() if x.is_dir())

    if tarfile.is_tarfile(base_path):
        with tarfile.open(base_path, mode="r") as t:
            tar_infos = t.getmembers()

        # os.path.commonpath rejects an empty sequence, so an archive without
        # members is answered here instead of failing with an unrelated error.
        if not tar_infos:
            return []

        # Archives commonly wrap everything in one root folder; the folders of
        # interest are the ones sitting directly below that shared prefix.
        commonpath = Path(os.path.commonpath([info.name for info in tar_infos]))
        return sorted(
            Path(info.name).name
            for info in tar_infos
            if info.isdir() and Path(info.name).parent == commonpath
        )

    raise ValueError(
        f"The provided path: `{base_path}` should be a directory or a tarball."
    )