From cf3507983a4b2562c13ce6eb4af299ce02001698 Mon Sep 17 00:00:00 2001
From: Amir MOHAMMADI <amir.mohammadi@idiap.ch>
Date: Mon, 20 Jun 2022 19:43:11 +0200
Subject: [PATCH] Improves the search_file function Speeds up searching and
 fixes #88

---
 .../data/example_csv_filelist.tar.gz          | Bin 551 -> 739 bytes
 bob/extension/download.py                     |  61 ++++++++----------
 bob/extension/test_download.py                |  50 ++++++++------
 3 files changed, 56 insertions(+), 55 deletions(-)

diff --git a/bob/extension/data/example_csv_filelist.tar.gz b/bob/extension/data/example_csv_filelist.tar.gz
index 556fe6903f54188c757f085ca4bcc1bfcfc54e1b..ebbf2c0c93fd65d57d29449667375525f14117d3 100644
GIT binary patch
literal 739
zcmV<90v!DxiwFP!000001MQmKZqq;zg>%hQ<N++|nf)=nL=Xfw2vyu$rg4QOJC5Q|
z;O((p+O*W2>~8C=M9+7VpRv4Yj?eLQC*`lA+16!#-d*MwRbAFqw?A1tEoo0?%<!nv
zjILf=JwE<rL7I&wR!Z4LN`@6Jmgb2+$0R?xeX(0E7e&*4FPf%USJTH;>#Eqs_&{+f
zr}W?M+I@T8*7>@;%*)H7e%uw|q`%56S!-DL?{j}RL;n~!>c44sP4GqFtUoEmS$`Z;
zMj7;vfsg+CT~TfFpY5(*zwVEBUSkm0YW<-8n(0b^CT&Okb^l_qOlSPNZ+PMQe|h`$
zV{*3ousU0Pez(d$y??bX_QgrluFIMz@5-*L+D-ptyi6z|vCc>8>VY~>pP%oq&tjdA
z)YSv^!}|Q(f<(sA|NWEj9>6RAjT-XbkVXE-z(fCwc9-{+Z)N9|{to$1xANZ%{hwBu
zM*hb@%71kokLo5Kl~iWAtftF)^5pvd+FB~JTvpTN!`k|}g?J`B{Xc$vv-)3O`QO6-
zQE>7<$^IwV|0MgLW&e}xf0F%Avi}+Wck}<FHi5kQKbsBxpB?W1=>O)&NB__U@L2v^
zlOFQlj`?r;e*=*J5%A1^Bj>+ZR*9AK-^%%)$@wpq)pYr=_8Y0p$*ELU38hk9C&W_S
zBs6*JyRO^*{B!B=>;I4j@R<J%-SU5yj{aBZ|Dpi@JNTbu|I>{AoSRAiH^1|z90B^w
ze<>aJ|FlB?7X^3y!$*Mr`m;_k|BHgB{y`%^r}^I<{f*WJ_x~7R{)hRWoBzW{fC2f>
zgx~)}!B79N9U!Ou|GW90TI~Ns0rNlX|NHshe+MYU{|5U%QNa8U{&(|#_$E+5|EGrg
zKl;D<@zg)40d$)G-OYd2`2BAbu>X(wpQr!58$hS|-#`9m?Eggp{(sr?KL7v#00000
V000000093de*pT28o2;a007>ztX%*A

literal 551
zcmV+?0@(c@iwFRyYR6vy1MQemZ<{a_h4b8Bksp9s-|K6$Jx*Ju_GOx?y(3m!35kRV
zXxg7&L)WHK$qh~u=VAAI0dXu$ayTcB-BZ3?S1!A2H`#qzxvFfttJ!f;N?~fuaI~rF
zk1Dk^9RD!`om#NgAo=S+sUJ2sM}Ce`9@;K%=JQ#%c8x2WeAl?LDD!oU2a1au^MBpc
zU42(qS>ZOB+vL^f-iToTQBi8e{}@d4e^ocjgI|Kl{)5(jp!Xk$wD=!`5dXU-FIU-P
z-BiV;A3u4H-e60#(|<yE@!$Jy2mi?^@4q_X=YHcq_y6~|KiA3a;^*Ra@$35{yZLcZ
z<XwKXtP58GWUXu4vR?U<;WdE-SRME1dWVkV_s5U(Bdg;cUGLDh^W(V<iOj_O`;)`h
z073n4bl?9-DE%LU0RQjnCiBW)*9Nu!g#L$j{crm9pTA<I|6{=QzkaPp{Z@~flaZJ8
zxJ;urr;l?JCnGQGad|s8o!f{r5t#qs>pN-w=_UV4@*jms`3FA#z~>+M{3D-#;PVfB
z{(;XwO8%$i|5KYlLF+%84*&mA_Wz;~;{Txwz<vF1joS9V9s0lD29o}dKw$qHzW-TS
zCzkJj%lCiE_dhG^ad|uU7f#0HD<|s&6sJ=H$mu45(VxCo-TM2V%Ky;(AF=@4mwyBA
p^55$||3=n-QIPyk$UpG;hZE*g5ClOG1VIpv$ul%Im^A<>001BGFXaFL

diff --git a/bob/extension/download.py b/bob/extension/download.py
index e4e4382..e510eda 100644
--- a/bob/extension/download.py
+++ b/bob/extension/download.py
@@ -2,6 +2,7 @@
 # vim: set fileencoding=utf-8 :
 
 import bz2
+import glob
 import hashlib
 import io
 import logging
@@ -347,21 +348,22 @@ def find_element_in_tarball(filename, target_path, open_as_stream=False):
     """
 
     f = tarfile.open(filename)
-    for member in f.getmembers():
-        if member.isdir():
+    # iterate over the members of the tarball
+    while True:
+        member = f.next()
+        if member is None:
+            return None
+
+        if not member.isfile():
             continue
 
-        if (
-            member.isfile()
-            and target_path in member.name
-            and os.path.split(target_path)[-1] == os.path.split(member.name)[-1]
-        ):
-            if open_as_stream:
-                return io.BufferedReader(f.extractfile(member)).read()
-            else:
-                return io.TextIOWrapper(f.extractfile(member), encoding="utf-8")
-    else:
-        return None
+        if not member.name.endswith(target_path):
+            continue
+
+        if open_as_stream:
+            return io.BufferedReader(f.extractfile(member)).read()
+        else:
+            return io.TextIOWrapper(f.extractfile(member), encoding="utf-8")
 
 
 def search_file(base_path, options):
@@ -372,16 +374,18 @@ def search_file(base_path, options):
     ----------
 
     base_path: str
-       Base path to start the search, or the tarball to be searched
+        Base folder to start the search, or the tarball to be searched
 
     options: list
-       Files to be searched. This function will return the first occurency
+        Files to be searched. This function will return the first occurrence.
+        The option can be an incomplete relative path. For example, if you have
+        a file called ``"/a/b/c/d.txt"``, and base_path is ``"/a/b"``, then
+        options can be ``["d.txt"]``.
 
     Returns
     -------
     object
         It returns an opened file
-
     """
 
     if not isinstance(options, list):
@@ -389,26 +393,13 @@ def search_file(base_path, options):
 
     # If the input is a directory
     if os.path.isdir(base_path):
-
-        def get_fs():
-            fs = []
-            for root, _, files in os.walk(base_path, topdown=False):
-                for name in files:
-                    fs.append(os.path.join(root, name))
-            return fs
-
-        def search_in_list(o, lst):
-            for i, l in enumerate(lst):
-                if o in l:
-                    return i
-            else:
-                return -1
-
-        fs = get_fs()
         for o in options:
-            index = search_in_list(o, fs)
-            if index >= 0:
-                return open(fs[index])
+            # we append './' to o because o might start with /
+            pattern = os.path.join(base_path, "**", f"./{o}")
+            for path in glob.iglob(pattern, recursive=True):
+                if not os.path.isfile(path):
+                    continue
+                return open(path)
         else:
             return None
     else:
diff --git a/bob/extension/test_download.py b/bob/extension/test_download.py
index 43ed526..cf82f5d 100644
--- a/bob/extension/test_download.py
+++ b/bob/extension/test_download.py
@@ -97,27 +97,37 @@ def test_search_file():
     filename = pkg_resources.resource_filename(
         __name__, "data/example_csv_filelist.tar.gz"
     )
-    # Search in the tarball
-    assert (
-        search_file(filename, "protocol_dev_eval/norm/train_world.csv")
-        is not None
-    )
-    assert search_file(filename, "protocol_dev_eval/norm/xuxa.csv") is None
-
-    # Search in a file structure
-    final_path = "./test_search_file"
-
-    pass
-
-    _untar(filename, final_path, ".gz")
-
-    assert (
-        search_file(final_path, "protocol_dev_eval/norm/train_world.csv")
-        is not None
-    )
-    assert search_file(final_path, "protocol_dev_eval/norm/xuxa.csv") is None
 
-    shutil.rmtree(final_path)
+    with tempfile.TemporaryDirectory(suffix="_extracted") as tmpdir:
+
+        _untar(filename, tmpdir, ".gz")
+
+        # Search in the tarball and in its extracted folder
+        for final_path in (filename, tmpdir):
+            in_extracted_folder = final_path.endswith("_extracted")
+            all_files = list_dir(final_path)
+
+            output_file = search_file(
+                final_path, "protocol_dev_eval/norm/train_world.csv"
+            )
+            assert output_file is not None, all_files
+
+            # test to see if using / we can force an exact match
+            output_file = search_file(
+                final_path, "/protocol_dev_eval/norm/train_world.csv"
+            )
+            assert output_file is not None, all_files
+            assert "my_data" not in output_file.read()
+            if in_extracted_folder:
+                assert "my_protocol" not in output_file.name
+
+            assert (
+                search_file(final_path, "norm/train_world.csv") is not None
+            ), all_files
+            assert (
+                search_file(final_path, "protocol_dev_eval/norm/xuxa.csv")
+                is None
+            ), all_files
 
 
 def test_list_dir():
-- 
GitLab