diff --git a/src/bob/pipelines/dataset/database.py b/src/bob/pipelines/dataset/database.py index 2ea4653cdf410d24ec7d73d804e0a7788974090d..93da39b6934468e8a6191f17d1eef595530cb987 100644 --- a/src/bob/pipelines/dataset/database.py +++ b/src/bob/pipelines/dataset/database.py @@ -245,8 +245,8 @@ class FileListDatabase: # Save to bob_data/protocols, or if present, in a category sub directory. subdir = Path("protocols") - if hasattr(cls, "category"): - subdir = subdir / getattr(cls, "category") + if hasattr(cls, "dataset_protocols_category"): + subdir = subdir / getattr(cls, "dataset_protocols_category") # Retrieve the file from the server (or use the local version). return retrieve_protocols( diff --git a/tests/test_database.py b/tests/test_database.py index 96c6def5f14c59888f98201e86eef205624aaf0e..f4f68d72f5073a7282b56f567ca89da157e2252c 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -101,7 +101,7 @@ def test_filelist_class(monkeypatch): monkeypatch.setenv("HOME", tmp_home.as_posix()) desired_name = "atnt_filename.tar.gz" - class DBDownloadCustomFileName(FileListDatabase): + class DBDownloadCustomFilename(FileListDatabase): name = "atnt" dataset_protocols_checksum = "f529acef" dataset_protocols_urls = [ @@ -109,8 +109,31 @@ def test_filelist_class(monkeypatch): ] dataset_protocols_name = desired_name - assert DBDownloadCustomFileName.protocols() == ["idiap_protocol"] + assert DBDownloadCustomFilename.protocols() == ["idiap_protocol"] assert ( - DBDownloadCustomFileName.retrieve_dataset_protocols() + DBDownloadCustomFilename.retrieve_dataset_protocols() == tmp_home / "bob_data" / "protocols" / desired_name ) + + with TemporaryDirectory(prefix="bobtest_") as tmpdir: + tmp_home = Path(tmpdir) + monkeypatch.setenv("HOME", tmp_home.as_posix()) + desired_category = "custom_category" + + class DBDownloadCustomCategory(FileListDatabase): + name = "atnt" + dataset_protocols_checksum = "f529acef" + dataset_protocols_urls = [ + "https://www.idiap.ch/software/bob/databases/latest/base/atnt-f529acef.tar.gz" + ] + dataset_protocols_category = desired_category + + assert DBDownloadCustomCategory.protocols() == ["idiap_protocol"] + assert ( + DBDownloadCustomCategory.retrieve_dataset_protocols() + == tmp_home + / "bob_data" + / "protocols" + / desired_category + / "atnt-f529acef.tar.gz" + )