diff --git a/src/bob/pipelines/dataset/protocols/retrieve.py b/src/bob/pipelines/dataset/protocols/retrieve.py index 8bab0d1995e84d9722f0773ceae9b70dbe2180a4..713a68c3c2d1b1c47d6004b9b2724a64dd458af3 100644 --- a/src/bob/pipelines/dataset/protocols/retrieve.py +++ b/src/bob/pipelines/dataset/protocols/retrieve.py @@ -171,9 +171,9 @@ def get_protocol_path( for protocol_path in protocol_paths: if archive.is_archive(protocol_path): _base, inner = archive.path_and_subdir(protocol_path) - if inner.stem == protocol: + if inner.name == protocol: return protocol_path - elif protocol_path.stem == protocol: + elif protocol_path.name == protocol: return protocol_path logger.warning(f"Protocol {protocol} not found in {database_name}.") return None @@ -232,14 +232,14 @@ def list_protocol_names( # Handle a database archive having database_name as top-level directory if len(top_level_dirs) == 1 and top_level_dirs[0].name == database_name: return [ - p.stem + p.name for p in archive.list_dirs( final_path, inner_dir=database_name, show_files=False ) ] - return [p.stem for p in top_level_dirs] + return [p.name for p in top_level_dirs] # Not an archive: list the dirs - return [p.stem for p in final_path.iterdir() if p.is_dir()] + return [p.name for p in final_path.iterdir() if p.is_dir()] def open_definition_file( @@ -316,7 +316,8 @@ def list_group_names( subdir=subdir, database_filename=database_filename, ) - return [p.stem for p in paths] + # Supports groups as files or dirs + return [p.stem for p in paths] # ! This means group can't include a '.' def download_protocol_definition(