Skip to content
Snippets Groups Projects

Making use of LightningDataModule and simplification of data loading

Merged Daniel CARRON requested to merge add-datamodule into main
2 files
+ 10
53
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -14,7 +14,7 @@ from clapper.logging import setup
from torchvision import transforms
from ..base_datamodule import BaseDataModule
from ..dataset import JSONProtocol, TBDataset
from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
from ..shenzhen import _protocols, _raw_data_loader
from ..transforms import ElasticDeformation, RemoveBlackBorders
@@ -59,43 +59,44 @@ class DefaultModule(BaseDataModule):
fieldnames=("data", "label"),
)
if self._cache_samples:
dataset = CachedDataset
else:
dataset = RuntimeDataset
if not self._has_setup_fit and stage == "fit":
self.train_dataset = TBDataset(
self.train_dataset = dataset(
json_protocol,
self._protocol,
"train",
_raw_data_loader,
self._build_transforms(is_train=True),
cache_samples=self._cache_samples,
)
self.validation_dataset = TBDataset(
self.validation_dataset = dataset(
json_protocol,
self._protocol,
"validation",
_raw_data_loader,
self._build_transforms(is_train=False),
cache_samples=self._cache_samples,
)
self._has_setup_fit = True
if not self._has_setup_predict and stage == "predict":
self.train_dataset = TBDataset(
self.train_dataset = dataset(
json_protocol,
self._protocol,
"train",
_raw_data_loader,
self._build_transforms(is_train=False),
cache_samples=self._cache_samples,
)
self.validation_dataset = TBDataset(
self.validation_dataset = dataset(
json_protocol,
self._protocol,
"validation",
_raw_data_loader,
self._build_transforms(is_train=False),
cache_samples=self._cache_samples,
)
self._has_setup_predict = True
Loading