Skip to content
Snippets Groups Projects
Commit 98c43682 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[tests] Add test for cached dataset

parent d3df91c4
No related branches found
No related tags found
No related merge requests found
Pipeline #91311 failed
......@@ -208,6 +208,8 @@ class _CachedDataset(Dataset):
Which implementation of the multiprocessing context to use. Options are
defined at :py:mod:`multiprocessing`. If set to ``None``, use the
default for the current platform.
disable_pbar
If set, disables progress bars.
"""
def __init__(
......@@ -217,6 +219,7 @@ class _CachedDataset(Dataset):
transforms: TransformSequence = [],
parallel: int = -1,
multiprocessing_context: str | None = None,
disable_pbar: bool = False,
):
self.loader = functools.partial(
_apply_loader_and_transforms,
......@@ -233,7 +236,10 @@ class _CachedDataset(Dataset):
parallel = -1
if parallel < 0:
self.data = [self.loader(k) for k in tqdm.tqdm(raw_dataset, unit="sample")]
self.data = [
self.loader(k)
for k in tqdm.tqdm(raw_dataset, unit="sample", disable=disable_pbar)
]
else:
instances = parallel or multiprocessing.cpu_count()
logger.info(f"Caching dataset using {instances} processes...")
......@@ -244,6 +250,7 @@ class _CachedDataset(Dataset):
tqdm.tqdm(
p.imap(self.loader, raw_dataset),
total=len(raw_dataset),
disable=disable_pbar,
),
)
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Test code for datasets."""
"""Test code for database splits."""
from mednet.data.split import JSONDatabaseSplit
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Test code for datasets."""
import typing
import pytest
import torch
from torchvision import tv_tensors
import mednet.data.classify.typing
import mednet.data.typing
# Number of synthetic raw samples generated for the caching tests.
_NUM_SAMPLES = 1000

# Synthetic raw dataset: one ``(name, target, metadata)`` tuple per sample.
# The ``:3d`` format pads the index with spaces to a width of 3, so the first
# entry is named ``"sample-  0"``.
_raw_dataset = [
    (f"sample-{index:3d}", index, f"metadata-{index:3d}")
    for index in range(_NUM_SAMPLES)
]
class _RawDataLoader(mednet.data.classify.typing.RawDataLoader):
    """Raw-data loader producing synthetic samples for the dataset tests.

    Raw samples are indexed as ``(name, target, metadata)``: element 0 is used
    as the name, element 1 as the target value and element 2 as the metadata.
    """

    def sample(
        self, sample: tuple[str, int, typing.Any | None]
    ) -> mednet.data.typing.Sample:
        """Build a loaded sample from a raw sample tuple.

        Parameters
        ----------
        sample
            The raw sample to load.

        Returns
        -------
            A dictionary with the keys ``image`` (a random ``[1, 128, 128]``
            tensor wrapped as a ``tv_tensors.Image``), ``name``, ``target``
            and ``metadata``.
        """
        random_image = tv_tensors.Image(torch.rand([1, 128, 128]))
        return {
            "image": random_image,
            "name": sample[0],
            "target": self.target(sample),
            "metadata": sample[2],
        }

    def target(self, sample: typing.Any) -> torch.Tensor:
        """Return the target of a raw sample as a one-element float tensor.

        Parameters
        ----------
        sample
            The raw sample; its element at index 1 is used as the target.

        Returns
        -------
            A ``torch.FloatTensor`` holding the single target value.
        """
        target_value = sample[1]
        return torch.FloatTensor([target_value])
def id_function(val):
    """Return a printable identifier for a parametrized test value.

    Dictionaries are rendered with ``str()``; any other value is rendered
    with ``repr()``.
    """
    render = str if isinstance(val, dict) else repr
    return render(val)
def _assert_sample_matches(loaded_sample, raw_sample):
    """Assert that a loaded sample carries the fields of its raw counterpart.

    Parameters
    ----------
    loaded_sample
        Mapping with the keys ``name``, ``target`` and ``metadata``.
    raw_sample
        Raw ``(name, target, metadata)`` sequence it was loaded from.
    """
    assert loaded_sample["name"] == raw_sample[0]
    assert loaded_sample["target"].item() == raw_sample[1]
    assert loaded_sample["metadata"] == raw_sample[2]


@pytest.mark.parametrize(
    "parallel,multiprocessing_context",
    [
        (-1, None),
        (1, None),
        (2, None),
        (4, None),
        (1, "spawn"),
        (2, "spawn"),
        (4, "spawn"),
    ],
    ids=id_function,  # just changes how pytest prints it
)
def test_cached_dataset(parallel, multiprocessing_context):
    """Check that a cached dataset keeps every sample, in the original order.

    Parameters
    ----------
    parallel
        Value forwarded to the ``parallel`` argument of ``_CachedDataset``.
    multiprocessing_context
        Value forwarded to the ``multiprocessing_context`` argument of
        ``_CachedDataset``.
    """
    from mednet.data.datamodule import _CachedDataset

    dataset = _CachedDataset(
        raw_dataset=_raw_dataset,
        loader=_RawDataLoader(),
        parallel=parallel,
        multiprocessing_context=multiprocessing_context,
        disable_pbar=True,
    )

    # tests targets
    assert len(dataset.targets()) == _NUM_SAMPLES

    # checks __len__
    assert len(dataset) == _NUM_SAMPLES

    # checks __iter__ works and returns in due order; strict=True makes the
    # test fail (instead of silently truncating) if iteration yields a
    # different number of samples than the raw dataset holds
    for loaded_sample, raw_sample in zip(dataset, _raw_dataset, strict=True):
        _assert_sample_matches(loaded_sample, raw_sample)

    # checks __getitem__
    for index, raw_sample in enumerate(_raw_dataset):
        _assert_sample_matches(dataset[index], raw_sample)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment