Skip to content
Snippets Groups Projects

Add protocols as classmethod for FileListDatabase

Merged Yannick DAYER requested to merge add-protocols into master
@@ -18,7 +18,7 @@ from typing import Any, Optional, TextIO
import sklearn.pipeline
from bob.extension.download import list_dir, search_file
from bob.extension.download import get_file, list_dir, search_file
from .sample import Sample
from .utils import check_parameter_for_validity, check_parameters_for_validity
@@ -117,8 +117,8 @@ class FileListDatabase:
def __init__(
self,
dataset_protocols_path: str,
protocol: str,
dataset_protocols_path: Optional[str] = None,
reader_cls: Iterable = CSVToSamples,
transformer: Optional[sklearn.pipeline.Pipeline] = None,
**kwargs,
@@ -141,6 +141,8 @@ class FileListDatabase:
ValueError
If the dataset_protocols_path does not exist.
"""
if dataset_protocols_path is None:
dataset_protocols_path = self.retrieve_dataset_protocols()
if not os.path.exists(dataset_protocols_path):
raise ValueError(
f"The path `{dataset_protocols_path}` was not found"
@@ -151,6 +153,8 @@ class FileListDatabase:
self.readers = dict()
self._protocol = None
self.protocol = protocol
# Tricksy trick to make protocols non-classmethod when instantiated
self.protocols = self._instance_protocols
super().__init__(**kwargs)
@property
@@ -182,10 +186,64 @@ class FileListDatabase:
names = [os.path.splitext(n)[0] for n in names]
return names
def protocols(self) -> list[str]:
def _instance_protocols(self) -> list[str]:
"""Returns all the available protocols."""
return list_dir(self.dataset_protocols_path, files=False)
@classmethod
def protocols(cls) -> list[str]:
return list_dir(cls.retrieve_dataset_protocols())
@classmethod
def retrieve_dataset_protocols(
cls,
name: Optional[str] = None,
urls: Optional[list[str]] = None,
hash: Optional[str] = None,
category: Optional[str] = None,
) -> str:
"""Return a path to the protocols definition files.
If the files are not present locally in ``bob_data/datasets``, they will be
downloaded.
The class inheriting from CSVDatabase must have a ``name`` and an
``dataset_protocols_urls`` attributes if those parameters are not provided.
A ``hash`` attribute can be used to verify the file and ensure the correct
version is used.
Parameters
----------
name
File name created locally. If not provided, will try to use
``cls.dataset_protocols_name``.
urls
Possible addresses to retrieve the definition file from. If not provided,
will try to use ``cls.dataset_protocols_urls``.
hash
hash of the downloaded file. If not provided, will try to use
``cls.dataset_protocol_hash``.
category
Used to specify a sub directory in ``{bob_data}/datasets/{category}``. If
not provided, will try to use ``cls.category``.
"""
# Save to bob_data/datasets, or if present, in a category sub directory.
subdir = "datasets"
if category or hasattr(cls, "category"):
subdir = os.path.join(subdir, category or getattr(cls, "category"))
# put an os.makedirs(exist_ok=True) here if needed (needs bob_data path)
# Retrieve the file from the server (or use the local version).
return get_file(
filename=name or cls.dataset_protocols_name,
urls=urls or cls.dataset_protocols_urls,
cache_subdir=subdir,
file_hash=hash or getattr(cls, "dataset_protocols_hash", None),
)
def list_file(self, group: str) -> TextIO:
"""Returns the corresponding definition file of a group."""
list_file = search_file(
Loading