Skip to content
Snippets Groups Projects
Commit 2fcec25b authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Removed TBDataset, using Runtime or Cached datasets instead

parent d5b173b1
No related branches found
No related tags found
No related merge requests found
Pipeline #75312 canceled
This commit is part of merge request !6. Comments created here will be created in the context of that merge request.
......@@ -367,50 +367,6 @@ class RuntimeDataset(torch.utils.data.Dataset):
return len(self._samples)
class TBDataset(torch.utils.data.Dataset):
    """Dataset of TB samples loaded from a JSON protocol description.

    Each sample is a dict whose ``"name"`` key holds the original relative
    path and whose ``"data"`` key holds the transformed, loaded payload.
    With ``cache_samples=True`` every sample is loaded and transformed once
    up front; otherwise loading happens lazily on each ``__getitem__`` call.
    """

    def __init__(
        self,
        json_protocol,
        protocol,
        subset,
        raw_data_loader,
        transforms,
        cache_samples=False,
    ):
        self.json_protocol = json_protocol
        self.subset = subset
        self.raw_data_loader = raw_data_loader
        self.transforms = transforms
        self.cache_samples = cache_samples
        self._samples = json_protocol.subsets(protocol)[self.subset]
        # Preserve the relative file path under "name" before "data" is
        # (possibly) overwritten by the cached, transformed payload below.
        for entry in self._samples:
            entry["name"] = entry["data"]
        if self.cache_samples:
            logger.info(f"Caching {self.subset} samples")
            for entry in tqdm(self._samples):
                entry["data"] = self.transforms(
                    self.raw_data_loader(entry["data"])
                )

    def __getitem__(self, idx):
        # Cached mode: samples were fully materialized in __init__.
        if self.cache_samples:
            return self._samples[idx]
        # Lazy mode: shallow-copy the record so the stored sample keeps its
        # original path, then load and transform on demand.
        record = self._samples[idx].copy()
        record["data"] = self.transforms(self.raw_data_loader(record["data"]))
        return record

    def __len__(self):
        """Return the number of samples in the selected subset."""
        return len(self._samples)
def get_samples_weights(dataset):
"""Compute the weights of all the samples of the dataset to balance it
using the sampler of the dataloader.
......
......@@ -14,7 +14,7 @@ from clapper.logging import setup
from torchvision import transforms
from ..base_datamodule import BaseDataModule
from ..dataset import JSONProtocol, TBDataset
from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
from ..shenzhen import _protocols, _raw_data_loader
from ..transforms import ElasticDeformation, RemoveBlackBorders
......@@ -59,43 +59,44 @@ class DefaultModule(BaseDataModule):
fieldnames=("data", "label"),
)
if self._cache_samples:
dataset = CachedDataset
else:
dataset = RuntimeDataset
if not self._has_setup_fit and stage == "fit":
self.train_dataset = TBDataset(
self.train_dataset = dataset(
json_protocol,
self._protocol,
"train",
_raw_data_loader,
self._build_transforms(is_train=True),
cache_samples=self._cache_samples,
)
self.validation_dataset = TBDataset(
self.validation_dataset = dataset(
json_protocol,
self._protocol,
"validation",
_raw_data_loader,
self._build_transforms(is_train=False),
cache_samples=self._cache_samples,
)
self._has_setup_fit = True
if not self._has_setup_predict and stage == "predict":
self.train_dataset = TBDataset(
self.train_dataset = dataset(
json_protocol,
self._protocol,
"train",
_raw_data_loader,
self._build_transforms(is_train=False),
cache_samples=self._cache_samples,
)
self.validation_dataset = TBDataset(
self.validation_dataset = dataset(
json_protocol,
self._protocol,
"validation",
_raw_data_loader,
self._build_transforms(is_train=False),
cache_samples=self._cache_samples,
)
self._has_setup_predict = True
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment