diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b4d1d1ca583eb211cc4b0a5ab9b2243c715bb45e..49337196592599a11942ecd8fc24a4230c5a196d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,20 +2,20 @@ # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/timothycrosley/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort args: [--settings-path, "pyproject.toml"] - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 23.1.0 hooks: - id: black - repo: https://github.com/pycqa/flake8 - rev: 3.9.2 + rev: 6.0.0 hooks: - id: flake8 - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.2.0 + rev: v4.4.0 hooks: - id: check-ast - id: check-case-conflict diff --git a/conda/meta.yaml b/conda/meta.yaml index adce2656f3f554cf6b4a5a5e914e57aef3ccfda1..2a3fb6df52b27678fb8c18fc9c7884f0d17b6465 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -21,9 +21,8 @@ requirements: - python {{ python }} - setuptools {{ setuptools }} - pip {{ pip }} - - exposed + - clapp # bob dependencies - - bob.extension - bob.io.base # other libraries - numpy {{ numpy }} @@ -32,14 +31,14 @@ requirements: - distributed {{ distributed }} - scikit-learn {{ scikit_learn }} - xarray {{ xarray }} - - h5py {{h5py}} + - h5py {{ h5py }} + - requests {{ requests }} # optional dependencies - dask-ml {{ dask_ml }} run: - python - setuptools - - exposed - - bob.extension + - clapp - bob.io.base - {{ pin_compatible('numpy') }} - {{ pin_compatible('dask') }} @@ -48,6 +47,7 @@ requirements: - {{ pin_compatible('scikit-learn') }} - {{ pin_compatible('xarray') }} - {{ pin_compatible('h5py') }} + - requests run_constrained: - {{ pin_compatible('dask-ml') }} diff --git a/doc/catalog.json b/doc/catalog.json index f04cb4241f037bf69c07aa0059257f047e6c3f2a..f1e5d9bf2bf4806b2cbb303fa8a725a13d983aaa 100644 --- a/doc/catalog.json +++ b/doc/catalog.json @@ -1,32 +1,10 @@ { - "bob.io.base": { - "versions": { - "latest": "https://www.idiap.ch/software/bob/docs/bob/bob.io.base/master/sphinx" - }, - "sources": {} + "bob.io.base": { + "versions": { + "5.0.3b1": "https://www.idiap.ch/software/bob/docs/bob/bob.io.base/master/sphinx/" }, - "dask-ml": { - "versions": { - "latest": "https://ml.dask.org" - }, - "sources": { - "readthedocs": "dask-ml" - } - }, - "scikit-learn": { - "versions": { - "latest": "https://scikit-learn.org/stable/" - }, - "sources": { - "readthedocs": "scikit-learn" - } - }, - "xarray": { - "versions": { - "latest": "https://docs.xarray.dev/en/stable/" - }, - "sources": { - "readthedocs": "xarray" - } + "sources": { + "environment": "bob.io.base" } + } } diff --git a/doc/datasets.rst b/doc/datasets.rst index 242086a8dc6d05ea8b8e8e2637100a529cbdba80..fbce62d4f81afcc50c4b9162bcb6cbe03999c71f 100644 --- a/doc/datasets.rst +++ b/doc/datasets.rst @@ -67,6 +67,7 @@ As you can see there is only one protocol called ``default`` and two groups >>> import bob.pipelines >>> dataset_protocols_path = "tests/data/iris_database" >>> database = bob.pipelines.FileListDatabase( + ... name="iris", ... protocol="default", ... dataset_protocols_path=dataset_protocols_path, ... ) @@ -100,6 +101,7 @@ to all samples: ... return [bob.pipelines.Sample(prepare_data(sample), parent=sample) for sample in samples] >>> database = bob.pipelines.FileListDatabase( + ... name="iris", ... protocol="default", ... dataset_protocols_path=dataset_protocols_path, ... 
transformer=FunctionTransformer(prepare_iris_samples), diff --git a/pyproject.toml b/pyproject.toml index 38a80a5e42d74fc3c828f7dc3ce810ada61c84fd..7be64e5edb290322230928833163d1b6b95e62b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,8 +28,7 @@ dependencies = [ "setuptools", "numpy", - "bob.extension", - "exposed", + "clapp", "bob.io.base", "scikit-learn", "dask", @@ -37,6 +36,7 @@ "dask-jobqueue", "xarray", "h5py", + "requests", ] [project.urls] diff --git a/src/bob/pipelines/__init__.py b/src/bob/pipelines/__init__.py index af0083a113b130496f70e2c7d449bae34d64efa1..8069e0bfa1332579d741087924e40ca8d62c8a06 100644 --- a/src/bob/pipelines/__init__.py +++ b/src/bob/pipelines/__init__.py @@ -34,7 +34,7 @@ from .wrappers import ( # noqa: F401 is_instance_nested, is_pipeline_wrapped, ) -from .datasets import FileListToSamples, CSVToSamples, FileListDatabase +from .dataset import FileListToSamples, CSVToSamples, FileListDatabase def __appropriate__(*args): @@ -45,7 +45,7 @@ def __appropriate__(*args): Parameters ---------- *args - The objects that you want sphinx to beleive that are defined here. + The objects that you want sphinx to believe that are defined here. Resolves `Sphinx referencing issues <https//github.com/sphinx- doc/sphinx/issues/3048>` diff --git a/src/bob/pipelines/dataset/__init__.py b/src/bob/pipelines/dataset/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3c487b2964ebb6c914767b62cb15d4bc47ec1d0d --- /dev/null +++ b/src/bob/pipelines/dataset/__init__.py @@ -0,0 +1,34 @@ +"""Functionalities related to datasets processing.""" + +from .database import CSVToSamples, FileListDatabase, FileListToSamples +from .protocols import ( # noqa: F401 + download_protocol_definition, + list_group_names, + list_protocol_names, + open_definition_file, +) + + +def __appropriate__(*args): + """Says object was actually declared here, and not in the import module. + Fixing sphinx warnings of not being able to find classes, when path is + shortened. + + Parameters + ---------- + *args + The objects that you want sphinx to believe that are defined here. 
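For reference, a minimal usage sketch of the reworked constructor, mirroring the doc/datasets.rst change above: all arguments are now keyword-only and ``name`` is required. The path points at the repository's iris test data and is illustrative.

    import bob.pipelines

    database = bob.pipelines.FileListDatabase(
        name="iris",
        protocol="default",
        dataset_protocols_path="tests/data/iris_database",
    )
    print(database.protocols())  # ["default"]
    print(database.groups())     # ["test", "train"]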
+ + Resolves `Sphinx referencing issues <https//github.com/sphinx- + doc/sphinx/issues/3048>` + """ + + for obj in args: + obj.__module__ = __name__ + + +__appropriate__( + FileListToSamples, + CSVToSamples, + FileListDatabase, +) diff --git a/src/bob/pipelines/datasets.py b/src/bob/pipelines/dataset/database.py similarity index 69% rename from src/bob/pipelines/datasets.py rename to src/bob/pipelines/dataset/database.py index a0f21dc3799556d887244db7853abcdd70c4d09c..fe43dee9bb016b9f2f6c0a20fcc4e4d8aa55b845 100644 --- a/src/bob/pipelines/datasets.py +++ b/src/bob/pipelines/dataset/database.py @@ -11,21 +11,26 @@ The principles of this module are: import csv import itertools import os -import pathlib from collections.abc import Iterable -from typing import Any, Optional, TextIO +from pathlib import Path +from typing import Any, Optional, TextIO, Union import sklearn.pipeline -from bob.extension.download import get_file, list_dir, search_file +from bob.pipelines.dataset.protocols.retrieve import ( + list_group_names, + list_protocol_names, + open_definition_file, + retrieve_protocols, +) -from .sample import Sample -from .utils import check_parameter_for_validity, check_parameters_for_validity +from ..sample import Sample +from ..utils import check_parameter_for_validity, check_parameters_for_validity def _maybe_open_file(path, **kwargs): - if isinstance(path, (str, bytes, pathlib.Path)): + if isinstance(path, (str, bytes, Path)): path = open(path, **kwargs) return path @@ -117,8 +122,10 @@ class FileListDatabase: def __init__( self, + *, + name: str, protocol: str, - dataset_protocols_path: Optional[str] = None, + dataset_protocols_path: Union[os.PathLike[str], str, None] = None, reader_cls: Iterable = CSVToSamples, transformer: Optional[sklearn.pipeline.Pipeline] = None, **kwargs, @@ -130,30 +137,37 @@ class FileListDatabase: Path to a folder or a tarball where the csv protocol files are located. protocol The name of the protocol to be used for samples. If None, the first - protocol will be used. + protocol found will be used. reader_cls - A callable that will initialize the CSVToSamples reader, by default CSVToSamples TODO update + An iterable that returns created Sample objects from a list file. transformer - A scikit-learn transformer that further changes the samples + A scikit-learn transformer that further changes the samples. Raises ------ ValueError If the dataset_protocols_path does not exist. """ + + # Tricksy trick to make protocols non-classmethod when instantiated + self.protocols = self._instance_protocols + + if getattr(self, "name", None) is None: + self.name = name + if dataset_protocols_path is None: dataset_protocols_path = self.retrieve_dataset_protocols() - if not os.path.exists(dataset_protocols_path): + + self.dataset_protocols_path = Path(dataset_protocols_path) + + if len(self.protocols()) < 1: raise ValueError( - f"The path `{dataset_protocols_path}` was not found" + f"No protocols found at `{dataset_protocols_path}`!" 
) - self.dataset_protocols_path = dataset_protocols_path self.reader_cls = reader_cls self._transformer = transformer - self.readers = dict() + self.readers: dict[str, Iterable] = {} self._protocol = None - # Tricksy trick to make protocols non-classmethod when instantiated - self.protocols = self._instance_protocols self.protocol = protocol super().__init__(**kwargs) @@ -180,75 +194,78 @@ class FileListDatabase: def groups(self) -> list[str]: """Returns all the available groups.""" - names = list_dir( - self.dataset_protocols_path, self.protocol, folders=False + return list_group_names( + database_name=self.name, + protocol=self.protocol, + database_filename=self.dataset_protocols_path.name, + base_dir=self.dataset_protocols_path.parent, + subdir=".", ) - names = [os.path.splitext(n)[0] for n in names] - return names def _instance_protocols(self) -> list[str]: """Returns all the available protocols.""" - return list_dir(self.dataset_protocols_path, files=False) + return list_protocol_names( + database_name=self.name, + database_filename=self.dataset_protocols_path.name, + base_dir=self.dataset_protocols_path.parent, + subdir=".", + ) @classmethod - def protocols(cls) -> list[str]: - return list_dir(cls.retrieve_dataset_protocols()) + def protocols(cls) -> list[str]: # pylint: disable=method-hidden + """Returns all the available protocols.""" + # Ensure the definition file exists locally + loc = cls.retrieve_dataset_protocols() + if not hasattr(cls, "name"): + raise ValueError(f"{cls} has no attribute 'name'.") + return list_protocol_names( + database_name=getattr(cls, "name"), + database_filename=loc.name, + base_dir=loc.parent, + subdir=".", + ) @classmethod - def retrieve_dataset_protocols( - cls, - name: Optional[str] = None, - urls: Optional[list[str]] = None, - hash: Optional[str] = None, - category: Optional[str] = None, - ) -> str: + def retrieve_dataset_protocols(cls) -> Path: """Return a path to the protocols definition files. - If the files are not present locally in ``bob_data/datasets``, they will be - downloaded. + If the files are not present locally in ``bob_data/<subdir>/<category>``, they + will be downloaded. The class inheriting from CSVDatabase must have a ``name`` and an - ``dataset_protocols_urls`` attributes if those parameters are not provided. + ``dataset_protocols_urls`` attributes. - A ``hash`` attribute can be used to verify the file and ensure the correct + A ``checksum`` attribute can be used to verify the file and ensure the correct version is used. - - Parameters - ---------- - name - File name created locally. If not provided, will try to use - ``cls.dataset_protocols_name``. - urls - Possible addresses to retrieve the definition file from. If not provided, - will try to use ``cls.dataset_protocols_urls``. - hash - hash of the downloaded file. If not provided, will try to use - ``cls.dataset_protocol_hash``. - category - Used to specify a sub directory in ``{bob_data}/datasets/{category}``. If - not provided, will try to use ``cls.category``. - """ - # Save to bob_data/datasets, or if present, in a category sub directory. - subdir = "datasets" - if category or hasattr(cls, "category"): - subdir = os.path.join(subdir, category or getattr(cls, "category")) - # put an os.makedirs(exist_ok=True) here if needed (needs bob_data path) + # When the path is specified, just return it. 
+ if getattr(cls, "dataset_protocols_path", None) is not None: + return getattr(cls, "dataset_protocols_path") + + # Save to bob_data/protocols, or if present, in a category sub directory. + subdir = Path("protocols") + if hasattr(cls, "category"): + subdir = subdir / getattr(cls, "category") # Retrieve the file from the server (or use the local version). - return get_file( - filename=name or cls.dataset_protocols_name, - urls=urls or cls.dataset_protocols_urls, - cache_subdir=subdir, - file_hash=hash or getattr(cls, "dataset_protocols_hash", None), + return retrieve_protocols( + urls=getattr(cls, "dataset_protocols_urls"), + destination_filename=getattr(cls, "dataset_protocols_name", None), + base_dir=None, + subdir=subdir, + checksum=getattr(cls, "dataset_protocols_checksum", None), ) def list_file(self, group: str) -> TextIO: """Returns the corresponding definition file of a group.""" - list_file = search_file( - self.dataset_protocols_path, - os.path.join(self.protocol, group + ".csv"), + list_file = open_definition_file( + search_pattern=group + ".csv", + database_name=self.name, + protocol=self.protocol, + database_filename=self.dataset_protocols_path.name, + base_dir=self.dataset_protocols_path.parent, + subdir=".", ) return list_file @@ -282,7 +299,6 @@ class FileListDatabase: ) all_samples = [] for grp in groups: - for sample in self.get_reader(grp): all_samples.append(sample) diff --git a/src/bob/pipelines/dataset/protocols/__init__.py b/src/bob/pipelines/dataset/protocols/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f4a10d6067df3810363c70ae4fe7b084e44a2479 --- /dev/null +++ b/src/bob/pipelines/dataset/protocols/__init__.py @@ -0,0 +1,8 @@ +"""Adds functionalities to retrieve and process protocol definition files.""" + +from .retrieve import ( # noqa: F401 + download_protocol_definition, + list_group_names, + list_protocol_names, + open_definition_file, +) diff --git a/src/bob/pipelines/dataset/protocols/archive.py b/src/bob/pipelines/dataset/protocols/archive.py new file mode 100644 index 0000000000000000000000000000000000000000..d45ede2651d934ad4266e859d3ea83d29ed90630 --- /dev/null +++ b/src/bob/pipelines/dataset/protocols/archive.py @@ -0,0 +1,203 @@ +"""Archives (tar, zip) operations like searching for files and extracting.""" + +import bz2 +import io +import logging +import os +import tarfile +import zipfile + +from fnmatch import fnmatch +from pathlib import Path +from typing import IO, TextIO, Union + +logger = logging.getLogger(__name__) + + +def path_and_subdir( + archive_path: Union[str, os.PathLike], +) -> tuple[Path, Union[Path, None]]: + """Splits an archive's path from a sub directory (separated by ``:``).""" + archive_path_str = Path(archive_path).as_posix() + if ":" in archive_path_str: + archive, sub_dir = archive_path_str.rsplit(":", 1) + return Path(archive), Path(sub_dir) + return Path(archive_path), None + + +def _is_bz2(path: Union[str, os.PathLike]) -> bool: + try: + with bz2.BZ2File(path) as f: + f.read(1024) + return True + except (OSError, EOFError): + return False + + +def is_archive(path: Union[str, os.PathLike]) -> bool: + """Returns whether the path points in an archive. + + Any path pointing to a valid tar or zip archive or to a valid bz2 + file will return ``True``. 
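Since retrieve_dataset_protocols() is now driven purely by class attributes, a downloading database can be declared as below. This is a sketch: the subclass name is made up, while the URL and checksum are the ones exercised in tests/test_database.py further down.

    from bob.pipelines import FileListDatabase

    class AtntDatabase(FileListDatabase):
        name = "atnt"
        dataset_protocols_checksum = "f529acef"
        dataset_protocols_urls = [
            "https://www.idiap.ch/software/bob/databases/latest/base/atnt-f529acef.tar.gz"
        ]

    # The first call fetches the definition file into <bob_data_dir>/protocols/.
    print(AtntDatabase.protocols())  # ["idiap_protocol"]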
+ """ + archive = path_and_subdir(path)[0] + try: + return any( + tester(path_and_subdir(archive)[0]) + for tester in (tarfile.is_tarfile, zipfile.is_zipfile, _is_bz2) + ) + except (FileNotFoundError, IsADirectoryError): + return False + + +def search_and_open( + search_pattern: str, + archive_path: Union[str, os.PathLike], + inner_dir: Union[os.PathLike, None] = None, + open_as_binary: bool = False, +) -> Union[IO[bytes], TextIO, None]: + """Returns a read-only stream of a file matching a pattern in an archive. + + Wildcards (``*``, ``?``, and ``**``) are supported (using + :meth:`pathlib.Path.glob`). + + The first matching file will be open and returned. + + examples: + + .. code-block: text + + archive.tar.gz + + subdir1 + | + file1.txt + | + file2.txt + | + + subdir2 + + file1.txt + + ``search_and_open("archive.tar.gz", "file1.txt")`` + opens``archive.tar.gz/subdir1/file1.txt`` + + ``search_and_open("archive.tar.gz:subdir2", "file1.txt")`` + opens ``archive.tar.gz/subdir2/file1.txt`` + + ``search_and_open("archive.tar.gz", "*.txt")`` + opens ``archive.tar.gz/subdir1/file1.txt`` + + + Parameters + ---------- + archive_path + The ``.tar.gz`` archive file containing the wanted file. To match + ``search_pattern`` in a sub path in that archive, append the sub path + to ``archive_path`` with a ``:`` (e.g. + ``/path/to/archive.tar.gz:sub/dir/``). + search_pattern + A string to match to the file. Wildcards are supported (Unix pattern + matching). + + Returns + ------- + io.TextIOBase or io.BytesIO + A read-only file stream. + """ + + archive_path = Path(archive_path) + + if inner_dir is None: + archive_path, inner_dir = path_and_subdir(archive_path) + + if inner_dir is not None: + pattern = (Path("/") / inner_dir / search_pattern).as_posix() + else: + pattern = (Path("/") / search_pattern).as_posix() + + if ".tar" in archive_path.suffixes: + tar_arch = tarfile.open(archive_path) # TODO File not closed + for member in tar_arch: + if member.isfile() and fnmatch("/" + member.name, pattern): + break + else: + logger.debug( + f"No file matching '{pattern}' were found in '{archive_path}'." + ) + return None + + if open_as_binary: + return tar_arch.extractfile(member) + return io.TextIOWrapper(tar_arch.extractfile(member), encoding="utf-8") + + elif archive_path.suffix == ".zip": + zip_arch = zipfile.ZipFile(archive_path) + for name in zip_arch.namelist(): + if fnmatch("/" + name, pattern): + break + else: + logger.debug( + f"No file matching '{pattern}' were found in '{archive_path}'." + ) + return zip_arch.open(name) + + raise ValueError( + f"Unknown file extension '{''.join(archive_path.suffixes)}'" + ) + + +def list_dirs( + archive_path: Union[str, os.PathLike], + inner_dir: Union[os.PathLike, None] = None, + show_dirs: bool = True, + show_files: bool = True, +) -> list[Path]: + """Returns a list of all the elements in an archive or inner directory. + + Parameters + ---------- + archive_path + A path to an archive, or an inner directory of an archive (appended + with a ``:``). + inner_dir + A path inside the archive with its root at the archive's root. + show_dirs + Returns directories. + show_files + Returns files. + """ + + archive_path, arch_inner_dir = path_and_subdir(archive_path) + inner_dir = Path(inner_dir or arch_inner_dir or Path(".")) + + results = [] + # Read the archive info and iterate over the paths. Return the ones we want. 
+ if ".tar" in archive_path.suffixes: + with tarfile.open(archive_path) as arch: + for info in arch.getmembers(): + path = Path(info.name) + if path.parent != inner_dir: + continue + if info.isdir() and show_dirs: + results.append(Path("/") / path) + if info.isfile() and show_files: + results.append(Path("/") / path) + elif archive_path.suffix == ".zip": + with zipfile.ZipFile(archive_path) as arch: + for zip_info in arch.infolist(): + zip_path = zipfile.Path(archive_path, zip_info.filename) + if Path(zip_info.filename).parent != inner_dir: + continue + if zip_path.is_dir() and show_dirs: + results.append(Path("/") / zip_info.filename) + if not zip_path.is_dir() and show_files: + results.append(Path("/") / zip_info.filename) + elif archive_path.suffix == ".bz2": + if inner_dir != Path("."): + raise ValueError( + ".bz2 files don't have an inner structure (tried to access " + f"'{archive_path}:{inner_dir}')." + ) + results.extend([Path(archive_path.stem)] if show_files else []) + else: + raise ValueError( + f"Unsupported archive extension '{''.join(archive_path.suffixes)}'." + ) + return sorted(results) # Fixes inconsistent file ordering across platforms diff --git a/src/bob/pipelines/dataset/protocols/hashing.py b/src/bob/pipelines/dataset/protocols/hashing.py new file mode 100644 index 0000000000000000000000000000000000000000..322cfaf10248e46dcc0f0bd45a946c3e3be2d149 --- /dev/null +++ b/src/bob/pipelines/dataset/protocols/hashing.py @@ -0,0 +1,65 @@ +"""Hashing functionalities for verifying files and computing CRCs.""" + + +import hashlib +import os + +from pathlib import Path +from typing import Any, Callable, Union + + +def md5_hash(readable: Any, chunk_size: int = 65535) -> str: + """Computes the md5 hash of any object with a read method.""" + hasher = hashlib.md5() + for chunk in iter(lambda: readable.read(chunk_size), b""): + hasher.update(chunk) + return hasher.hexdigest() + + +def sha256_hash(readable: Any, chunk_size: int = 65535) -> str: + """Computes the SHA256 hash of any object with a read method.""" + hasher = hashlib.sha256() + for chunk in iter(lambda: readable.read(chunk_size), b""): + hasher.update(chunk) + return hasher.hexdigest() + + +def verify_file( + file_path: Union[str, os.PathLike], + file_hash: str, + hash_fct: Callable[[Any, int], str] = sha256_hash, + full_match: bool = False, +) -> bool: + """Returns True if the file computed hash corresponds to `file_hash`. + + For comfort, we allow ``file_hash`` to match with the first + characters of the digest, allowing storing only e.g. the first 8 + char. + + Parameters + ---------- + file_path + The path to the file needing verification. + file_hash + The expected file hash digest. + hash_fct + A function taking a path and returning a digest. Defaults to SHA256. + full_match + If set to False, allows ``file_hash`` to match the first characters of + the files digest (this allows storing e.g. 8 chars of a digest instead + of the whole 64 characters of SHA256, and still matching.) 
+ """ + file_path = Path(file_path) + with file_path.open("rb") as f: + digest = hash_fct(f, 65535) + return digest == file_hash if full_match else digest.startswith(file_hash) + + +def compute_crc( + file_path: Union[str, os.PathLike], + hash_fct: Callable[[Any, int], str] = sha256_hash, +) -> str: + """Returns the CRC of a file.""" + file_path = Path(file_path) + with file_path.open("rb") as f: + return hash_fct(f, 65535) diff --git a/src/bob/pipelines/dataset/protocols/retrieve.py b/src/bob/pipelines/dataset/protocols/retrieve.py new file mode 100644 index 0000000000000000000000000000000000000000..14e716d9e97b45467e68af39736e0d6155473869 --- /dev/null +++ b/src/bob/pipelines/dataset/protocols/retrieve.py @@ -0,0 +1,451 @@ +"""Allows to find a protocol definition file locally, or download it if needed. + + +Expected protocol structure: + +``base_dir / subdir / database_filename / protocol_name / group_name`` + + +By default, ``base_dir`` will be pointed by the ``bob_data_dir`` config. +``subdir`` is provided as a way to use a directory inside ``base_dir`` when +using its default. + +Here are some valid example paths (``bob_data_dir=/home/username/bob_data``): + +In a "raw" directory (not an archive): + +``/home/username/bob_data/protocols/my_db/my_protocol/my_group`` + +In an archive: + +``/home/username/bob_data/protocols/my_db.tar.gz/my_protocol/my_group`` + +In an archive with the database name as top-level (some legacy db used that): + +``/home/username/bob_data/protocols/my_db.tar.gz/my_db/my_protocol/my_group`` + +""" + +import glob + +from logging import getLogger +from os import PathLike +from pathlib import Path +from typing import Any, Callable, Optional, TextIO, Union + +import requests + +from clapp.rc import UserDefaults + +from bob.pipelines.dataset.protocols import archive, hashing + +logger = getLogger(__name__) + + +def _get_local_data_directory() -> Path: + """Returns the local directory for data (``bob_data_dir`` config).""" + user_config = UserDefaults("bob.toml") + return Path( + user_config.get("bob_data_dir", default=Path.home() / "bob_data") + ) + + +def _infer_filename_from_urls(urls=Union[list[str], str]) -> str: + """Retrieves the remote filename from the URLs. + + Parameters + ---------- + urls + One or multiple URLs pointing to files with the same name. + + Returns + ------- + The remote file name. + + Raises + ------ + ValueError + When urls point to files with different names. + """ + if isinstance(urls, str): + return urls.split("/")[-1] + + # Check that all urls point to the same file name + names = [u.split("/")[-1] for u in urls] + if not all(n == names[0] for n in names): + raise ValueError( + f"Cannot infer file name when urls point to different files ({names=})." 
+ ) + return urls[0].split("/")[-1] + + +def retrieve_protocols( + urls: list[str], + destination_filename: Optional[str] = None, + base_dir: Union[PathLike[str], str, None] = None, + subdir: Union[PathLike[str], str] = "protocol", + checksum: Union[str, None] = None, +) -> Path: + """Automatically downloads the necessary protocol definition files.""" + if base_dir is None: + base_dir = _get_local_data_directory() + + remote_filename = _infer_filename_from_urls(urls) + if destination_filename is None: + destination_filename = remote_filename + elif Path(remote_filename).suffixes != Path(destination_filename).suffixes: + raise ValueError( + "Local dataset protocol definition files must have the same " + f"extension as the remote ones ({remote_filename=})" + ) + + return download_protocol_definition( + urls=urls, + destination_base_dir=base_dir, + destination_subdir=subdir, + destination_filename=destination_filename, + checksum=checksum, + force=False, + ) + + +def list_protocol_paths( + database_name: str, + base_dir: Union[PathLike[str], str, None] = None, + subdir: Union[PathLike[str], str] = "protocol", + database_filename: Union[str, None] = None, +) -> list[Path]: + """Returns the paths of each protocol in a database definition file.""" + if base_dir is None: + base_dir = _get_local_data_directory() + final_dir = Path(base_dir) / subdir + final_dir /= ( + database_name if database_filename is None else database_filename + ) + + if archive.is_archive(final_dir): + protocols = archive.list_dirs(final_dir, show_files=False) + if len(protocols) == 1 and protocols[0].name == database_name: + protocols = archive.list_dirs( + final_dir, database_name, show_files=False + ) + + archive_path, inner_dir = archive.path_and_subdir(final_dir) + if inner_dir is None: + return [ + Path(f"{archive_path.as_posix()}:{p.as_posix().lstrip('/')}") + for p in protocols + ] + + return [ + Path(f"{archive_path.as_posix()}:{inner_dir.as_posix()}/{p.name}") + for p in protocols + ] + + # Not an archive + return final_dir.iterdir() + + +def get_protocol_path( + database_name: str, + protocol: str, + base_dir: Union[PathLike[str], str, None] = None, + subdir: Union[PathLike[str], str] = "protocols", + database_filename: Optional[str] = None, +) -> Union[Path, None]: + """Returns the path of a specific protocol. + + Will look for ``protocol`` in ``base_dir / subdir / database_(file)name``. + + Returns + ------- + Path + The required protocol's path for the database. + """ + protocol_paths = list_protocol_paths( + database_name=database_name, + base_dir=base_dir, + subdir=subdir, + database_filename=database_filename, + ) + for protocol_path in protocol_paths: + if archive.is_archive(protocol_path): + _base, inner = archive.path_and_subdir(protocol_path) + if inner.name == protocol: + return protocol_path + elif protocol_path.name == protocol: + return protocol_path + logger.warning(f"Protocol {protocol} not found in {database_name}.") + return None + + +def list_protocol_names( + database_name: str, + base_dir: Union[PathLike[str], str, None] = None, + subdir: Union[PathLike[str], str, None] = "protocols", + database_filename: Union[str, None] = None, +) -> list[str]: + """Returns the paths of the protocol directories for a given database. + + Archives are also accepted, either if the file name is the same as + ``database_name`` with a ``.tar.gz`` extension or by specifying the filename + in ``database_filename``. 
+ + This will look in ``base_dir/subdir`` for ``database_filename``, then + ``database_name``, then ``database_name+".tar.gz"``. + + Parameters + ---------- + database_name + The database name used to infer ``database_filename`` if not specified. + base_dir + The base path of data files (defaults to the ``bob_data_dir`` config, or + ``~/bob_data`` if not configured). + subdir + A sub directory for the protocols in ``base_dir``. + database_filename + If the file/directory name of the protocols is not the same as the + name of the database, this can be set to look in the correct file. + + Returns + ------- + A list of protocol names + The different protocols available for that database. + """ + + if base_dir is None: + base_dir = _get_local_data_directory() + + if subdir is None: + subdir = "." + + if database_filename is None: + database_filename = database_name + final_path: Path = Path(base_dir) / subdir / database_filename + if not final_path.is_dir(): + database_filename = database_name + ".tar.gz" + + final_path: Path = Path(base_dir) / subdir / database_filename + + if archive.is_archive(final_path): + top_level_dirs = archive.list_dirs(final_path, show_files=False) + # Handle a database archive having database_name as top-level directory + if len(top_level_dirs) == 1 and top_level_dirs[0].name == database_name: + return [ + p.name + for p in archive.list_dirs( + final_path, inner_dir=database_name, show_files=False + ) + ] + return [p.name for p in top_level_dirs] + # Not an archive: list the dirs + return [p.name for p in final_path.iterdir() if p.is_dir()] + + +def open_definition_file( + search_pattern: Union[PathLike[str], str], + database_name: str, + protocol: str, + base_dir: Union[PathLike[str], str, None] = None, + subdir: Union[PathLike[str], str, None] = "protocols", + database_filename: Optional[str] = None, +) -> Union[TextIO, None]: + """Opens a protocol definition file inside a protocol directory. + + Also handles protocols inside an archive. 
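A sketch of how these lookups compose, following the layout documented at the top of retrieve.py; the database and file names are illustrative.

    from bob.pipelines.dataset.protocols.retrieve import (
        list_protocol_names,
        open_definition_file,
    )

    # With the default base_dir (bob_data_dir) and subdir="protocols", this looks
    # for my_db/, then my_db.tar.gz, unless database_filename is given explicitly.
    protocols = list_protocol_names(database_name="my_db")

    # Opens e.g. protocols/my_db/<protocol>/dev.csv; archives are handled too.
    dev_list = open_definition_file(
        search_pattern="dev.csv",
        database_name="my_db",
        protocol=protocols[0],
    )
    if dev_list is not None:
        print(dev_list.readline())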
+ """ + search_path = get_protocol_path( + database_name, protocol, base_dir, subdir, database_filename + ) + + if archive.is_archive(search_path): + return archive.search_and_open( + search_pattern=search_pattern, + archive_path=search_path, + ) + + search_pattern = Path(search_pattern) + + # we prepend './' to search_pattern because it might start with '/' + pattern = search_path / "**" / f"./{search_pattern.as_posix()}" + for path in glob.iglob(pattern.as_posix(), recursive=True): + if not Path(path).is_file(): + continue + return open(path, mode="rt") + logger.info(f"Unable to locate and open a file that matches '{pattern}'.") + return None + + +def list_group_paths( + database_name: str, + protocol: str, + base_dir: Union[PathLike[str], str, None] = None, + subdir: Union[PathLike[str], str] = "protocols", + database_filename: Optional[str] = None, +) -> list[Path]: + """Returns the file paths of the groups in protocol""" + protocol_path = get_protocol_path( + database_name=database_name, + protocol=protocol, + base_dir=base_dir, + subdir=subdir, + database_filename=database_filename, + ) + if archive.is_archive(protocol_path): + groups_inner = archive.list_dirs(protocol_path) + archive_path, inner_path = archive.path_and_subdir(protocol_path) + return [ + Path(f"{archive_path.as_posix()}:{inner_path.as_posix()}/{g}") + for g in groups_inner + ] + return protocol_path.iterdir() + + +def list_group_names( + database_name: str, + protocol: str, + base_dir: Union[PathLike[str], str, None] = None, + subdir: Union[PathLike[str], str] = "protocols", + database_filename: Optional[str] = None, +) -> list[str]: + """Returns the group names of a protocol.""" + paths = list_group_paths( + database_name=database_name, + protocol=protocol, + base_dir=base_dir, + subdir=subdir, + database_filename=database_filename, + ) + # Supports groups as files or dirs + return [p.stem for p in paths] # ! This means group can't include a '.' + + +def download_protocol_definition( + urls: Union[list[str], str], + destination_base_dir: Union[PathLike, None] = None, + destination_subdir: Union[str, None] = None, + destination_filename: Union[str, None] = None, + checksum: Union[str, None] = None, + checksum_fct: Callable[[Any, int], str] = hashing.sha256_hash, + force: bool = False, + makedirs: bool = True, +) -> Path: + """Downloads a remote file locally. + + Parameters + ---------- + urls + The remote location of the server. If multiple addresses are given, we will try + to download from them in order until one succeeds. + destination_basedir + A path to a local directory where the file will be saved. If omitted, the file + will be saved in the folder pointed by the ``wdr.local_directory`` key in the + user configuration. + destination_subdir + An additional layer added to the destination directory (useful when using + ``destination_directory=None``). + destination_filename + The final name of the local file. If omitted, the file will keep the name of + the remote file. + checksum + When provided, will compute the file's checksum and compare to this. + checksum_fct + A callable that takes a ``reader`` and returns a hash. + force + Re-download and overwrite any existing file with the same name. + makedirs + Automatically make the parent directories of the new local file. + + Returns + ------- + The path to the new local file. + + Raises + ------ + RuntimeError + When the URLs provided are all invalid. + ValueError + When ``destination_filename`` is omitted and URLs point to files with different + names. 
+ When the checksum of the file does not correspond to the provided ``checksum``. + """ + + if destination_filename is None: + destination_filename = _infer_filename_from_urls(urls=urls) + + if destination_base_dir is None: + destination_base_dir = _get_local_data_directory() + + destination_base_dir = Path(destination_base_dir) + + if destination_subdir is not None: + destination_base_dir = destination_base_dir / destination_subdir + + local_file = destination_base_dir / destination_filename + needs_download = True + + if not force and local_file.is_file(): + if checksum is None: + logger.info( + f"File {local_file} already exists, skipping download ({force=})." + ) + needs_download = False + elif hashing.verify_file(local_file, checksum, checksum_fct): + logger.info( + f"File {local_file} already exists and checksum is valid." + ) + needs_download = False + + if needs_download: + if isinstance(urls, str): + urls = [urls] + + for tries, url in enumerate(urls): + logger.debug(f"Retrieving file from '{url}'.") + try: + response = requests.get(url=url, timeout=10) + except requests.exceptions.ConnectionError as e: + if tries < len(urls) - 1: + logger.info( + f"Could not connect to {url}. Trying other URLs." + ) + logger.debug(e) + continue + + logger.debug( + f"http response: '{response.status_code}: {response.reason}'." + ) + + if response.ok: + logger.debug(f"Got file from {url}.") + break + if tries < len(urls) - 1: + logger.info( + f"Failed to get file from {url}, trying other URLs." + ) + logger.debug(f"requests.response was\n{response}") + else: + raise RuntimeError( + f"Could not retrieve file from any of the provided URLs! ({urls=})" + ) + + if makedirs: + local_file.parent.mkdir(parents=True, exist_ok=True) + + with local_file.open("wb") as f: + f.write(response.content) + + if checksum is not None: + if not hashing.verify_file(local_file, checksum, hash_fct=checksum_fct): + if not needs_download: + raise ValueError( + f"The local file hash does not correspond to '{checksum}' " + f"and {force=} prevents overwriting." + ) + raise ValueError( + "The downloaded file hash ('" + f"{hashing.compute_crc(local_file, hash_fct=checksum_fct)}') does " + f"not correspond to '{checksum}'." 
+ ) + + return local_file diff --git a/src/bob/pipelines/distributed/sge.py b/src/bob/pipelines/distributed/sge.py index 83c678da0f7a67c67d59a68be8033532970b7e46..cc5bb278fdb0255cab960cdc7c8570927a32d7b4 100644 --- a/src/bob/pipelines/distributed/sge.py +++ b/src/bob/pipelines/distributed/sge.py @@ -7,10 +7,10 @@ import sys import dask +from clapp.rc import UserDefaults from dask_jobqueue.core import Job, JobQueueCluster from distributed.deploy import Adaptive from distributed.scheduler import Scheduler -from exposed.rc import UserDefaults from .sge_queues import QUEUE_DEFAULT @@ -42,7 +42,6 @@ class SGEIdiapJob(Job): config_name="sge", **kwargs, ): - if queue is None: queue = dask.config.get("jobqueue.%s.queue" % config_name) if project is None: @@ -244,7 +243,6 @@ class SGEMultipleQueuesCluster(JobQueueCluster): project=rc.get("sge.project"), **kwargs, ): - # Defining the job launcher self.job_cls = SGEIdiapJob self.sge_job_spec = sge_job_spec diff --git a/src/bob/pipelines/wrappers.py b/src/bob/pipelines/wrappers.py index 182dec239247bcf821b256f254154745115370a4..fe3326f61528a7b468945e7b03533b1b16b1ca92 100644 --- a/src/bob/pipelines/wrappers.py +++ b/src/bob/pipelines/wrappers.py @@ -566,7 +566,6 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin): return self.estimator.score(samples) def fit(self, samples, y=None, **kwargs): - if not estimator_requires_fit(self.estimator): return self @@ -582,7 +581,6 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin): return self.save_model() def make_path(self, sample): - if self.features_dir is None: return None @@ -605,7 +603,6 @@ class CheckpointWrapper(BaseWrapper, TransformerMixin): to_save = getattr(sample, self.sample_attribute) for _ in range(self.attempts): try: - dirname = os.path.dirname(path) os.makedirs(dirname, exist_ok=True) @@ -697,7 +694,6 @@ def _shape_samples(samples): def _array_from_sample_bags(X: dask.bag.Bag, attribute: str, ndim: int = 2): - if ndim not in (1, 2): raise NotImplementedError(f"ndim must be 1 or 2. 
Got: {ndim}") @@ -1028,7 +1024,6 @@ def wrap(bases, estimator=None, **kwargs): if isinstance(estimator, Pipeline): # wrap inner steps for idx, name, trans in estimator._iter(): - # when checkpointing a pipeline, checkpoint each transformer in its own folder new_kwargs = dict(kwargs) features_dir, model_path = ( diff --git a/src/bob/pipelines/xarray.py b/src/bob/pipelines/xarray.py index 6b35d1c0df4902d38d03d579eeafa80211b1ff0b..1a13b8367f3c9ea74e3089f58441a4f2ddae7bee 100644 --- a/src/bob/pipelines/xarray.py +++ b/src/bob/pipelines/xarray.py @@ -313,7 +313,6 @@ def _get_dask_args_from_ds(ds, columns): def _blockwise_with_block_args(args, block, method_name=None): - meta = [] for _ in range(1, block.output_ndim): meta = [meta] diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 0000000000000000000000000000000000000000..b840546bac0c3a7ec7a93e8030213dd13a314dc8 --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python +# coding=utf-8 + +"""Test code for datasets""" + +from pathlib import Path +from tempfile import TemporaryDirectory + +import numpy as np +import pytest + +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import FunctionTransformer + +from bob.pipelines import FileListDatabase +from bob.pipelines.transformers import Str_To_Types + +DATA_PATH = Path(__file__).parent / "data" + + +def iris_data_transform(samples): + for s in samples: + data = np.array( + [s.sepal_length, s.sepal_width, s.petal_length, s.petal_width] + ) + s.data = data + return samples + + +def test_iris_list_database(): + protocols_path = DATA_PATH / "iris_database" + + database = FileListDatabase( + name="iris", protocol=None, dataset_protocols_path=protocols_path + ) + assert database.name == "iris" + assert database.protocol == "default" + assert database.protocols() == ["default"] + assert database.groups() == ["test", "train"] + with pytest.raises(ValueError): + database.protocol = "none" + + samples = database.samples() + assert len(samples) == 150 + assert samples[0].data is None + assert samples[0].sepal_length == "5" + assert samples[0].petal_width == "0.2" + assert samples[0].target == "Iris-setosa" + + with pytest.raises(ValueError): + database.samples(groups="random") + + database.transformer = make_pipeline( + Str_To_Types( + fieldtypes=dict( + sepal_length=float, + sepal_width=float, + petal_length=float, + petal_width=float, + ) + ), + FunctionTransformer(iris_data_transform), + ) + samples = database.samples(groups="train") + assert len(samples) == 75 + np.testing.assert_allclose(samples[0].data, [5.1, 3.5, 1.4, 0.2]) + assert samples[0].sepal_length == 5.1 + assert samples[0].petal_width == 0.2 + assert samples[0].target == "Iris-setosa" + + +def test_filelist_class(monkeypatch): + protocols_path = Path(DATA_PATH / "iris_database") + + class DBLocal(FileListDatabase): + name = "iris" + dataset_protocols_path = protocols_path + + assert DBLocal.protocols() == ["default"] + assert DBLocal.retrieve_dataset_protocols() == protocols_path + + with TemporaryDirectory(prefix="bobtest_") as tmpdir: + tmp_home = Path(tmpdir) + monkeypatch.setenv("HOME", tmp_home.as_posix()) + + class DBDownloadDefault(FileListDatabase): + name = "atnt" + dataset_protocols_checksum = "f529acef" + dataset_protocols_urls = [ + "https://www.idiap.ch/software/bob/databases/latest/base/atnt-f529acef.tar.gz" + ] + + assert DBDownloadDefault.protocols() == ["idiap_protocol"] + assert ( + DBDownloadDefault.retrieve_dataset_protocols() + == tmp_home 
/ "bob_data" / "protocols" / "atnt-f529acef.tar.gz" + ) + + with TemporaryDirectory(prefix="bobtest_") as tmpdir: + tmp_home = Path(tmpdir) + monkeypatch.setenv("HOME", tmp_home.as_posix()) + desired_name = "atnt_filename.tar.gz" + + class DBDownloadCustomFilename(FileListDatabase): + name = "atnt" + dataset_protocols_checksum = "f529acef" + dataset_protocols_urls = [ + "https://www.idiap.ch/software/bob/databases/latest/base/atnt-f529acef.tar.gz" + ] + dataset_protocols_name = desired_name + + assert DBDownloadCustomFilename.protocols() == ["idiap_protocol"] + assert ( + DBDownloadCustomFilename.retrieve_dataset_protocols() + == tmp_home / "bob_data" / "protocols" / desired_name + ) + + with TemporaryDirectory(prefix="bobtest_") as tmpdir: + tmp_home = Path(tmpdir) + monkeypatch.setenv("HOME", tmp_home.as_posix()) + desired_category = "custom_category" + + class DBDownloadCustomCategory(FileListDatabase): + name = "atnt" + category = desired_category + dataset_protocols_checksum = "f529acef" + dataset_protocols_urls = [ + "https://www.idiap.ch/software/bob/databases/latest/base/atnt-f529acef.tar.gz" + ] + + assert DBDownloadCustomCategory.protocols() == ["idiap_protocol"] + assert ( + DBDownloadCustomCategory.retrieve_dataset_protocols() + == tmp_home + / "bob_data" + / "protocols" + / desired_category + / "atnt-f529acef.tar.gz" + ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py deleted file mode 100644 index 48f7bca0faad69f7d906fc5a133493be0920528d..0000000000000000000000000000000000000000 --- a/tests/test_datasets.py +++ /dev/null @@ -1,68 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 - -"""Test code for datasets""" - -import os - -import numpy as np -import pkg_resources -import pytest - -from sklearn.pipeline import make_pipeline -from sklearn.preprocessing import FunctionTransformer - -from bob.pipelines.datasets import FileListDatabase -from bob.pipelines.transformers import Str_To_Types - - -def iris_data_transform(samples): - for s in samples: - data = np.array( - [s.sepal_length, s.sepal_width, s.petal_length, s.petal_width] - ) - s.data = data - return samples - - -def test_iris_list_database(): - dataset_protocols_path = pkg_resources.resource_filename( - __name__, os.path.join("data", "iris_database") - ) - - database = FileListDatabase( - protocol=None, dataset_protocols_path=dataset_protocols_path - ) - assert database.protocol == "default" - assert database.protocols() == ["default"] - assert database.groups() == ["test", "train"] - with pytest.raises(ValueError): - database.protocol = "none" - - samples = database.samples() - assert len(samples) == 150 - assert samples[0].data is None - assert samples[0].sepal_length == "5" - assert samples[0].petal_width == "0.2" - assert samples[0].target == "Iris-setosa" - - with pytest.raises(ValueError): - database.samples(groups="random") - - database.transformer = make_pipeline( - Str_To_Types( - fieldtypes=dict( - sepal_length=float, - sepal_width=float, - petal_length=float, - petal_width=float, - ) - ), - FunctionTransformer(iris_data_transform), - ) - samples = database.samples(groups="train") - assert len(samples) == 75 - np.testing.assert_allclose(samples[0].data, [5.1, 3.5, 1.4, 0.2]) - assert samples[0].sepal_length == 5.1 - assert samples[0].petal_width == 0.2 - assert samples[0].target == "Iris-setosa" diff --git a/tests/test_samples.py b/tests/test_samples.py index 4b246a33660152f4e30f6e397a4cfc2594116d86..84b79f7deca77dba05204cbdd0739a4809363c6a 100644 --- a/tests/test_samples.py +++ 
b/tests/test_samples.py @@ -17,7 +17,6 @@ from bob.pipelines import ( def test_sampleset_collection(): - n_samples = 10 X = np.ones(shape=(n_samples, 2), dtype=int) sampleset = SampleSet( @@ -46,7 +45,6 @@ def test_sampleset_collection(): # Testing delayed sampleset with tempfile.TemporaryDirectory() as dir_name: - samples = [Sample(data, key=str(i)) for i, data in enumerate(X)] filename = os.path.join(dir_name, "samples.pkl") with open(filename, "wb") as f: @@ -59,7 +57,6 @@ def test_sampleset_collection(): # Testing delayed sampleset cached with tempfile.TemporaryDirectory() as dir_name: - samples = [Sample(data, key=str(i)) for i, data in enumerate(X)] filename = os.path.join(dir_name, "samples.pkl") with open(filename, "wb") as f: diff --git a/tests/test_utils.py b/tests/test_utils.py index 488c796cb3998a3ff59714dc0dcb3f1494cfe91d..8a335bea7a8e646fc547c739f3ace5e445486af7 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -101,7 +101,6 @@ def test_is_instance_nested(): def test_break_sample_set(): - samplesets = [] n_samples = 10 X = np.ones(shape=(n_samples, 2), dtype=int) @@ -109,7 +108,6 @@ def test_break_sample_set(): # Creating a face list of samplesets for i in range(n_samples): - samplesets.append( SampleSet( [ diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index a90090df473497fb6266e4195e44bdba7c176c09..83c4c76ba3378be256dd6e0d052ea69d393135f6 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -65,7 +65,6 @@ class DummyTransformer(TransformerMixin, BaseEstimator): return self def transform(self, X): - # Input validation X = check_array(X) # Check that the input is of the same shape as the one passed @@ -181,7 +180,6 @@ def test_sklearn_compatible_estimator(): def test_function_sample_transfomer(): - X = np.zeros(shape=(10, 2), dtype=int) samples = [bob.pipelines.Sample(data) for data in X] @@ -200,7 +198,6 @@ def test_function_sample_transfomer(): def test_fittable_sample_transformer(): - X = np.ones(shape=(10, 2), dtype=int) samples = [bob.pipelines.Sample(data) for data in X] @@ -214,7 +211,6 @@ def test_fittable_sample_transformer(): def test_tagged_sample_transformer(): - X = np.ones(shape=(10, 2), dtype=int) samples = [bob.pipelines.Sample(data) for data in X] @@ -227,7 +223,6 @@ def test_tagged_sample_transformer(): def test_tagged_input_sample_transformer(): - X = np.ones(shape=(10, 2), dtype=int) samples = [bob.pipelines.Sample(data) for data in X] @@ -242,7 +237,6 @@ def test_tagged_input_sample_transformer(): def test_dask_tag_transformer(): - X = np.ones(shape=(10, 2), dtype=int) samples = [bob.pipelines.Sample(data) for data in X] sample_bags = bob.pipelines.ToDaskBag().transform(samples) @@ -255,7 +249,6 @@ def test_dask_tag_transformer(): def test_dask_tag_checkpoint_transformer(): - X = np.ones(shape=(10, 2), dtype=int) samples = [bob.pipelines.Sample(data) for data in X] sample_bags = bob.pipelines.ToDaskBag().transform(samples) @@ -279,7 +272,6 @@ def test_dask_tag_checkpoint_transformer(): def test_dask_tag_daskml_estimator(): - X, labels = make_blobs( n_samples=1000, n_features=2, @@ -328,7 +320,6 @@ def test_dask_tag_daskml_estimator(): def test_failing_sample_transformer(): - X = np.zeros(shape=(10, 2)) samples = [bob.pipelines.Sample(data) for i, data in enumerate(X)] expected = np.full_like(X, 2, dtype=object) @@ -371,7 +362,6 @@ def test_failing_sample_transformer(): def test_failing_checkpoint_transformer(): - X = np.zeros(shape=(10, 2)) samples = [bob.pipelines.Sample(data, key=i) for i, data in enumerate(X)] 
expected = np.full_like(X, 2) @@ -470,7 +460,6 @@ def _assert_delayed_samples(samples): def test_checkpoint_function_sample_transfomer(): - X = np.arange(20, dtype=int).reshape(10, 2) samples = [ bob.pipelines.Sample(data, key=str(i)) for i, data in enumerate(X) @@ -576,7 +565,6 @@ def _build_estimator(path, i): def _build_transformer(path, i, force=False): - features_dir = os.path.join(path, f"transformer{i}") estimator = bob.pipelines.wrap( [DummyTransformer, "sample", "checkpoint"], @@ -588,7 +576,6 @@ def _build_transformer(path, i, force=False): def test_checkpoint_fittable_pipeline(): - X = np.ones(shape=(10, 2), dtype=int) samples = [ bob.pipelines.Sample(data, key=str(i)) for i, data in enumerate(X) @@ -613,7 +600,6 @@ def test_checkpoint_fittable_pipeline(): def test_checkpoint_transform_pipeline(): def _run(dask_enabled): - X = np.ones(shape=(10, 2), dtype=int) samples_transform = [ bob.pipelines.Sample(data, key=str(i)) for i, data in enumerate(X) @@ -642,11 +628,9 @@ def test_checkpoint_transform_pipeline(): def test_checkpoint_transform_pipeline_force(): - with tempfile.TemporaryDirectory() as d: def _run(): - X = np.ones(shape=(10, 2), dtype=int) samples_transform = [ bob.pipelines.Sample(data, key=str(i)) @@ -782,7 +766,6 @@ def test_dask_checkpoint_transform_pipeline(): def test_checkpoint_transform_pipeline_with_sampleset(): def _run(dask_enabled): - X = np.ones(shape=(10, 2), dtype=int) samples_transform = bob.pipelines.SampleSet( [ @@ -821,7 +804,6 @@ def test_checkpoint_transform_pipeline_with_sampleset(): def test_estimator_requires_fit(): - all_wraps = [ ["sample"], ["sample", "checkpoint"], diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 6bdab976ca666c0bbcfcfcdf60e5bb04e70e2982..b12646714132c336337fbbf86de03b08f2dae91f 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -67,7 +67,6 @@ def test_delayed_samples_to_dataset(): def _build_iris_dataset(shuffle=False, delayed=False): - iris = datasets.load_iris() X = iris.data @@ -228,7 +227,6 @@ def test_dataset_pipeline_with_failures(): def test_dataset_pipeline_with_dask_ml(): - scaler = dask_ml.preprocessing.StandardScaler() pca = dask_ml.decomposition.PCA(n_components=3, random_state=0) clf = SGDClassifier(random_state=0, loss="log_loss", penalty="l2", tol=1e-3)