diff --git a/src/ptbench/configs/datasets/shenzhen/default.py b/src/ptbench/configs/datasets/shenzhen/default.py index 967d40e9358c8f946f876199cac56a921a89606f..09aa54748d382122fd1f67a3b3e1871f3f0aa132 100644 --- a/src/ptbench/configs/datasets/shenzhen/default.py +++ b/src/ptbench/configs/datasets/shenzhen/default.py @@ -14,7 +14,7 @@ from clapper.logging import setup from ....data import return_subsets from ....data.base_datamodule import BaseDataModule -from ....data.dataset import JSONDataset +from ....data.dataset import JSONProtocol from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -52,7 +52,7 @@ class DefaultModule(BaseDataModule): ) samples_loader = _delayed_loader - json_dataset = JSONDataset( + json_protocol = JSONProtocol( protocols=_protocols, fieldnames=("data", "label"), loader=samples_loader, @@ -63,7 +63,7 @@ class DefaultModule(BaseDataModule): self.train_dataset, self.validation_dataset, self.extra_validation_datasets, - ) = return_subsets(json_dataset, "default", stage) + ) = return_subsets(json_protocol, "default", stage) self.has_setup_fit = True if not self.has_setup_predict and stage == "predict": @@ -71,7 +71,7 @@ class DefaultModule(BaseDataModule): self.train_dataset, self.validation_dataset, self.extra_validation_datasets, - ) = return_subsets(json_dataset, "default", stage) + ) = return_subsets(json_protocol, "default", stage) self.has_setup_predict = True diff --git a/src/ptbench/configs/datasets/shenzhen/rgb.py b/src/ptbench/configs/datasets/shenzhen/rgb.py index 2506e79da82008b636160f88f353080627ec9e00..7cf77faa1406e6b34b0db06415002c2e23de49ab 100644 --- a/src/ptbench/configs/datasets/shenzhen/rgb.py +++ b/src/ptbench/configs/datasets/shenzhen/rgb.py @@ -14,7 +14,7 @@ from torchvision import transforms from ....data import return_subsets from ....data.base_datamodule import BaseDataModule -from ....data.dataset import JSONDataset +from ....data.dataset import JSONProtocol from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") @@ -57,7 +57,7 @@ class DefaultModule(BaseDataModule): ) samples_loader = _delayed_loader - self.json_dataset = JSONDataset( + self.json_protocol = JSONProtocol( protocols=_protocols, fieldnames=("data", "label"), loader=samples_loader, @@ -69,7 +69,7 @@ class DefaultModule(BaseDataModule): self.train_dataset, self.validation_dataset, self.extra_validation_datasets, - ) = return_subsets(self.json_dataset, "default", stage) + ) = return_subsets(self.json_protocol, "default", stage) self.has_setup_fit = True diff --git a/src/ptbench/data/dataset.py b/src/ptbench/data/dataset.py index 5966f392a7c66f60aeb089f92f8d88b03738a384..b568e78a225dc515c5e0b53d7d6b9a5b64f95cd1 100644 --- a/src/ptbench/data/dataset.py +++ b/src/ptbench/data/dataset.py @@ -20,7 +20,7 @@ RANDOM_ROTATION = [RandomRotation(15)] logger = logging.getLogger(__name__) -class JSONDataset: +class JSONProtocol: """Generic multi-protocol/subset filelist dataset that yields samples. To create a new dataset, you need to provide one or more JSON formatted