diff --git a/bob/extension/data/example_csv_filelist.tar.gz b/bob/extension/data/example_csv_filelist.tar.gz index 556fe6903f54188c757f085ca4bcc1bfcfc54e1b..ebbf2c0c93fd65d57d29449667375525f14117d3 100644 Binary files a/bob/extension/data/example_csv_filelist.tar.gz and b/bob/extension/data/example_csv_filelist.tar.gz differ diff --git a/bob/extension/download.py b/bob/extension/download.py index e4e43826d32def7144899a243b82e3bcab41af8c..e510edaf389f811f7cfaae58097008aca3b26272 100644 --- a/bob/extension/download.py +++ b/bob/extension/download.py @@ -2,6 +2,7 @@ # vim: set fileencoding=utf-8 : import bz2 +import glob import hashlib import io import logging @@ -347,21 +348,22 @@ def find_element_in_tarball(filename, target_path, open_as_stream=False): """ f = tarfile.open(filename) - for member in f.getmembers(): - if member.isdir(): + # iterate over the members of the tarball + while True: + member = f.next() + if member is None: + return None + + if not member.isfile(): continue - if ( - member.isfile() - and target_path in member.name - and os.path.split(target_path)[-1] == os.path.split(member.name)[-1] - ): - if open_as_stream: - return io.BufferedReader(f.extractfile(member)).read() - else: - return io.TextIOWrapper(f.extractfile(member), encoding="utf-8") - else: - return None + if not member.name.endswith(target_path): + continue + + if open_as_stream: + return io.BufferedReader(f.extractfile(member)).read() + else: + return io.TextIOWrapper(f.extractfile(member), encoding="utf-8") def search_file(base_path, options): @@ -372,16 +374,18 @@ def search_file(base_path, options): ---------- base_path: str - Base path to start the search, or the tarball to be searched + Base folder to start the search, or the tarball to be searched options: list - Files to be searched. This function will return the first occurency + Files to be searched. This function will return the first occurrence. + The option can be an incomplete relative path. For example, if you have + a file called ``"/a/b/c/d.txt"``, and base_path is ``"/a/b"``, then + options can be ``["d.txt"]``. Returns ------- object It returns an opened file - """ if not isinstance(options, list): @@ -389,26 +393,13 @@ def search_file(base_path, options): # If the input is a directory if os.path.isdir(base_path): - - def get_fs(): - fs = [] - for root, _, files in os.walk(base_path, topdown=False): - for name in files: - fs.append(os.path.join(root, name)) - return fs - - def search_in_list(o, lst): - for i, l in enumerate(lst): - if o in l: - return i - else: - return -1 - - fs = get_fs() for o in options: - index = search_in_list(o, fs) - if index >= 0: - return open(fs[index]) + # we append './' to o because o might start with / + pattern = os.path.join(base_path, "**", f"./{o}") + for path in glob.iglob(pattern, recursive=True): + if not os.path.isfile(path): + continue + return open(path) else: return None else: diff --git a/bob/extension/rc_config.py b/bob/extension/rc_config.py index b318f21abd41312c8884563262a5f621d31356ad..536a4211edbf328c397ca16b43503b2998d4b7a6 100644 --- a/bob/extension/rc_config.py +++ b/bob/extension/rc_config.py @@ -97,3 +97,4 @@ def _saverc(context): path = _get_rc_path() with open(path, "wt") as f: f.write(_rc_to_str(context)) + f.write("\n") diff --git a/bob/extension/test_download.py b/bob/extension/test_download.py index 43ed52638bca86640dec38498e35d51f373d62e2..cf82f5d698d95e2e513e12dc4f840da9b11df848 100644 --- a/bob/extension/test_download.py +++ b/bob/extension/test_download.py @@ -97,27 +97,37 @@ def test_search_file(): filename = pkg_resources.resource_filename( __name__, "data/example_csv_filelist.tar.gz" ) - # Search in the tarball - assert ( - search_file(filename, "protocol_dev_eval/norm/train_world.csv") - is not None - ) - assert search_file(filename, "protocol_dev_eval/norm/xuxa.csv") is None - - # Search in a file structure - final_path = "./test_search_file" - - pass - - _untar(filename, final_path, ".gz") - - assert ( - search_file(final_path, "protocol_dev_eval/norm/train_world.csv") - is not None - ) - assert search_file(final_path, "protocol_dev_eval/norm/xuxa.csv") is None - shutil.rmtree(final_path) + with tempfile.TemporaryDirectory(suffix="_extracted") as tmpdir: + + _untar(filename, tmpdir, ".gz") + + # Search in the tarball and in its extracted folder + for final_path in (filename, tmpdir): + in_extracted_folder = final_path.endswith("_extracted") + all_files = list_dir(final_path) + + output_file = search_file( + final_path, "protocol_dev_eval/norm/train_world.csv" + ) + assert output_file is not None, all_files + + # test to see if using / we can force an exact match + output_file = search_file( + final_path, "/protocol_dev_eval/norm/train_world.csv" + ) + assert output_file is not None, all_files + assert "my_data" not in output_file.read() + if in_extracted_folder: + assert "my_protocol" not in output_file.name + + assert ( + search_file(final_path, "norm/train_world.csv") is not None + ), all_files + assert ( + search_file(final_path, "protocol_dev_eval/norm/xuxa.csv") + is None + ), all_files def test_list_dir(): diff --git a/bob/extension/test_rc.py b/bob/extension/test_rc.py index b20f342ba11da437f6fe2ec8e0ea66340c35ecbc..74c57c04deb62a59d838829b8053d701a9164dac 100644 --- a/bob/extension/test_rc.py +++ b/bob/extension/test_rc.py @@ -74,8 +74,14 @@ def test_bob_config(): assert expected_output == result.output, result.output # test config unset (with starting substring) - result = runner.invoke(main_cli, ["config", "unset", "bob.db.atnt"]) - result = runner.invoke(main_cli, ["config", "get", "bob.db.atnt"]) + result = runner.invoke( + main_cli, + ["config", "unset", "bob.db.atnt"], + env={ENVNAME: bobrcfile}, + ) + result = runner.invoke( + main_cli, ["config", "get", "bob.db.atnt"], env={ENVNAME: bobrcfile} + ) assert_click_runner_result(result, 1) # test config unset (with substring contained) @@ -91,7 +97,11 @@ def test_bob_config(): env={ENVNAME: bobrcfile}, ) result = runner.invoke( - main_cli, ["config", "unset", "--contain", "atnt"] + main_cli, + ["config", "unset", "--contain", "atnt"], + env={ENVNAME: bobrcfile}, + ) + result = runner.invoke( + main_cli, ["config", "get", "bob.db.atnt"], env={ENVNAME: bobrcfile} ) - result = runner.invoke(main_cli, ["config", "get", "bob.db.atnt"]) assert_click_runner_result(result, 1)