Skip to content
Snippets Groups Projects
Commit fc9cb8e4 authored by Daniel CARRON, committed by André Anjos
Browse files

Removed TBDataset, using Runtime or Cached datasets instead

parent 1729b307
No related branches found
No related tags found
1 merge request: !6 "Making use of LightningDataModule and simplification of data loading"
...@@ -367,50 +367,6 @@ class RuntimeDataset(torch.utils.data.Dataset): ...@@ -367,50 +367,6 @@ class RuntimeDataset(torch.utils.data.Dataset):
return len(self._samples) return len(self._samples)
class TBDataset(torch.utils.data.Dataset):
    """Dataset serving the samples of one subset of a JSON protocol.

    Each sample is a dict taken from
    ``json_protocol.subsets(protocol)[subset]``.  On construction every
    sample gains a ``"name"`` entry holding the original value of its
    ``"data"`` entry (the relative path to the file).  The ``"data"`` entry
    handed out by :meth:`__getitem__` is
    ``transforms(raw_data_loader(<original "data" value>))``.

    Parameters
    ----------
    json_protocol
        Object exposing ``subsets(protocol)``, which must return a mapping
        from subset name to a sequence of sample dicts.
    protocol
        Name of the protocol, forwarded to ``json_protocol.subsets``.
    subset
        Key of the subset to serve (e.g. ``"train"``).
    raw_data_loader
        Callable applied to a sample's original ``"data"`` value.
    transforms
        Callable applied to the output of ``raw_data_loader``.
    cache_samples
        If true, every sample is loaded and transformed once in the
        constructor and kept in memory; otherwise this happens on every
        item access.
    """

    def __init__(
        self,
        json_protocol,
        protocol,
        subset,
        raw_data_loader,
        transforms,
        cache_samples=False,
    ):
        self.json_protocol = json_protocol
        self.subset = subset
        self.raw_data_loader = raw_data_loader
        self.transforms = transforms
        self.cache_samples = cache_samples

        # Work on private shallow copies: the entries added or overwritten
        # below must not leak into the dicts owned by the protocol object,
        # which may be shared with other datasets built from it.
        self._samples = [
            entry.copy()
            for entry in json_protocol.subsets(protocol)[self.subset]
        ]

        # Dict entry with relative path to files
        for entry in self._samples:
            entry["name"] = entry["data"]

        if self.cache_samples:
            logger.info(f"Caching {self.subset} samples")
            for sample in tqdm(self._samples):
                sample["data"] = self._load(sample["data"])

    def _load(self, reference):
        """Load the raw data behind ``reference`` and apply the transforms."""
        return self.transforms(self.raw_data_loader(reference))

    def __getitem__(self, idx):
        """Return a copy of sample ``idx`` with its ``"data"`` entry loaded.

        A shallow copy is returned in both modes, so callers that modify
        the returned dict cannot alter the stored (possibly cached) sample.
        """
        sample = self._samples[idx].copy()
        if not self.cache_samples:
            sample["data"] = self._load(sample["data"])
        return sample

    def __len__(self):
        """Return the number of samples in the subset."""
        return len(self._samples)
def get_samples_weights(dataset): def get_samples_weights(dataset):
"""Compute the weights of all the samples of the dataset to balance it """Compute the weights of all the samples of the dataset to balance it
using the sampler of the dataloader. using the sampler of the dataloader.
......
...@@ -14,7 +14,7 @@ from clapper.logging import setup ...@@ -14,7 +14,7 @@ from clapper.logging import setup
from torchvision import transforms from torchvision import transforms
from ..base_datamodule import BaseDataModule from ..base_datamodule import BaseDataModule
from ..dataset import JSONProtocol, TBDataset from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
from ..shenzhen import _protocols, _raw_data_loader from ..shenzhen import _protocols, _raw_data_loader
from ..transforms import ElasticDeformation, RemoveBlackBorders from ..transforms import ElasticDeformation, RemoveBlackBorders
...@@ -59,43 +59,44 @@ class DefaultModule(BaseDataModule): ...@@ -59,43 +59,44 @@ class DefaultModule(BaseDataModule):
fieldnames=("data", "label"), fieldnames=("data", "label"),
) )
if self._cache_samples:
dataset = CachedDataset
else:
dataset = RuntimeDataset
if not self._has_setup_fit and stage == "fit": if not self._has_setup_fit and stage == "fit":
self.train_dataset = TBDataset( self.train_dataset = dataset(
json_protocol, json_protocol,
self._protocol, self._protocol,
"train", "train",
_raw_data_loader, _raw_data_loader,
self._build_transforms(is_train=True), self._build_transforms(is_train=True),
cache_samples=self._cache_samples,
) )
self.validation_dataset = TBDataset( self.validation_dataset = dataset(
json_protocol, json_protocol,
self._protocol, self._protocol,
"validation", "validation",
_raw_data_loader, _raw_data_loader,
self._build_transforms(is_train=False), self._build_transforms(is_train=False),
cache_samples=self._cache_samples,
) )
self._has_setup_fit = True self._has_setup_fit = True
if not self._has_setup_predict and stage == "predict": if not self._has_setup_predict and stage == "predict":
self.train_dataset = TBDataset( self.train_dataset = dataset(
json_protocol, json_protocol,
self._protocol, self._protocol,
"train", "train",
_raw_data_loader, _raw_data_loader,
self._build_transforms(is_train=False), self._build_transforms(is_train=False),
cache_samples=self._cache_samples,
) )
self.validation_dataset = TBDataset( self.validation_dataset = dataset(
json_protocol, json_protocol,
self._protocol, self._protocol,
"validation", "validation",
_raw_data_loader, _raw_data_loader,
self._build_transforms(is_train=False), self._build_transforms(is_train=False),
cache_samples=self._cache_samples,
) )
self._has_setup_predict = True self._has_setup_predict = True
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment