From 4ebe5d5b8923d65e0c42bf5f1250fb8e1b7982eb Mon Sep 17 00:00:00 2001
From: Yannick DAYER <yannick.dayer@idiap.ch>
Date: Sun, 5 Feb 2023 03:35:05 +0100
Subject: [PATCH] [py] Add support for a category sub directory.

---
 src/bob/pipelines/dataset/database.py |  4 ++--
 tests/test_database.py                | 29 ++++++++++++++++++++++++---
 2 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/src/bob/pipelines/dataset/database.py b/src/bob/pipelines/dataset/database.py
index 2ea4653..93da39b 100644
--- a/src/bob/pipelines/dataset/database.py
+++ b/src/bob/pipelines/dataset/database.py
@@ -245,8 +245,8 @@ class FileListDatabase:
 
         # Save to bob_data/protocols, or if present, in a category sub directory.
         subdir = Path("protocols")
-        if hasattr(cls, "category"):
-            subdir = subdir / getattr(cls, "category")
+        if hasattr(cls, "dataset_protocols_category"):
+            subdir = subdir / getattr(cls, "dataset_protocols_category")
 
         # Retrieve the file from the server (or use the local version).
         return retrieve_protocols(
diff --git a/tests/test_database.py b/tests/test_database.py
index 96c6def..f4f68d7 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -101,7 +101,7 @@ def test_filelist_class(monkeypatch):
         monkeypatch.setenv("HOME", tmp_home.as_posix())
         desired_name = "atnt_filename.tar.gz"
 
-        class DBDownloadCustomFileName(FileListDatabase):
+        class DBDownloadCustomFilename(FileListDatabase):
             name = "atnt"
             dataset_protocols_checksum = "f529acef"
             dataset_protocols_urls = [
@@ -109,8 +109,31 @@ def test_filelist_class(monkeypatch):
             ]
             dataset_protocols_name = desired_name
 
-        assert DBDownloadCustomFileName.protocols() == ["idiap_protocol"]
+        assert DBDownloadCustomFilename.protocols() == ["idiap_protocol"]
         assert (
-            DBDownloadCustomFileName.retrieve_dataset_protocols()
+            DBDownloadCustomFilename.retrieve_dataset_protocols()
             == tmp_home / "bob_data" / "protocols" / desired_name
         )
+
+    with TemporaryDirectory(prefix="bobtest_") as tmpdir:
+        tmp_home = Path(tmpdir)
+        monkeypatch.setenv("HOME", tmp_home.as_posix())
+        desired_category = "custom_category"
+
+        class DBDownloadCustomCategory(FileListDatabase):
+            name = "atnt"
+            dataset_protocols_checksum = "f529acef"
+            dataset_protocols_urls = [
+                "https://www.idiap.ch/software/bob/databases/latest/base/atnt-f529acef.tar.gz"
+            ]
+            dataset_protocols_category = desired_category
+
+        assert DBDownloadCustomCategory.protocols() == ["idiap_protocol"]
+        assert (
+            DBDownloadCustomCategory.retrieve_dataset_protocols()
+            == tmp_home
+            / "bob_data"
+            / "protocols"
+            / desired_category
+            / "atnt-f529acef.tar.gz"
+        )
-- 
GitLab