Commit bd79cd2f authored by Daniel CARRON, committed by André Anjos

Removed reliance on make_dataset and added method to cache samples for default shenzhen

parent 9a03daba
Merge request !6: Making use of LightningDataModule and simplification of data loading
@@ -117,7 +117,7 @@ montgomery_rs_f7 = "ptbench.configs.datasets.montgomery_RS.fold_7"
 montgomery_rs_f8 = "ptbench.configs.datasets.montgomery_RS.fold_8"
 montgomery_rs_f9 = "ptbench.configs.datasets.montgomery_RS.fold_9"
 # shenzhen dataset (and cross-validation folds)
-shenzhen = "ptbench.data.shenzhen.default"
+shenzhen = "ptbench.configs.datasets.shenzhen.default"
 shenzhen_rgb = "ptbench.data.shenzhen.rgb"
 shenzhen_f0 = "ptbench.data.shenzhen.fold_0"
 shenzhen_f1 = "ptbench.data.shenzhen.fold_1"
...
@@ -267,7 +267,6 @@ def get_positive_weights(dataset):
         the positive weight of each class in the dataset given as input
     """
     targets = []
-
     if isinstance(dataset, torch.utils.data.ConcatDataset):
         for ds in dataset.datasets:
             for s in ds._samples:
...
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
@@ -12,9 +12,10 @@
 from clapper.logging import setup
 
-from .. import return_subsets
-from ..base_datamodule import BaseDataModule
-from . import _maker
+from ....data import return_subsets
+from ....data.base_datamodule import BaseDataModule
+from ....data.dataset import JSONDataset
+from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@@ -25,6 +26,7 @@ class DefaultModule(BaseDataModule):
         train_batch_size=1,
         predict_batch_size=1,
         drop_incomplete_batch=False,
+        cache_samples=False,
         multiproc_kwargs=None,
     ):
         super().__init__(
@@ -34,14 +36,31 @@ class DefaultModule(BaseDataModule):
             multiproc_kwargs=multiproc_kwargs,
         )
 
+        self.cache_samples = cache_samples
+
     def setup(self, stage: str):
-        self.dataset = _maker("default")
+        if self.cache_samples:
+            logger.info(
+                "Argument cache_samples set to True. Samples will be loaded in memory."
+            )
+            samples_loader = _cached_loader
+        else:
+            logger.info(
+                "Argument cache_samples set to False. Samples will be loaded at runtime."
+            )
+            samples_loader = _delayed_loader
+
+        self.json_dataset = JSONDataset(
+            protocols=_protocols,
+            fieldnames=("data", "label"),
+            loader=samples_loader,
+        )
+
         (
             self.train_dataset,
             self.validation_dataset,
             self.extra_validation_datasets,
             self.predict_dataset,
-        ) = return_subsets(self.dataset)
+        ) = return_subsets(self.json_dataset, "default")
 
 
 datamodule = DefaultModule
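The datamodule now builds its JSONDataset directly instead of going through _maker, with the new cache_samples switch selecting between the eager and delayed loaders. A minimal sketch of how this might be driven (constructor arguments follow the diff; the exact Lightning call sequence is an assumption):

    from ptbench.configs.datasets.shenzhen.default import DefaultModule

    dm = DefaultModule(
        train_batch_size=4,
        cache_samples=True,  # use _cached_loader: every image is decoded up-front
    )
    dm.prepare_data()
    dm.setup(stage="fit")  # builds the JSONDataset and unpacks the subsets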
@@ -13,7 +13,9 @@ Reference: [PASA-2019]_
 
 from torch import empty
 from torch.nn import BCEWithLogitsLoss
+from torchvision import transforms
 
+from ...data.transforms import ElasticDeformation
 from ...models.pasa import PASA
 
 # config
@@ -26,5 +28,9 @@ optimizer = "Adam"
 criterion = BCEWithLogitsLoss(pos_weight=empty(1))
 criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
 
+train_transforms = transforms.Compose([ElasticDeformation(p=0.8)])
+
 # model
-model = PASA(criterion, criterion_valid, optimizer, optimizer_configs)
+model = PASA(
+    train_transforms, criterion, criterion_valid, optimizer, optimizer_configs
+)
@@ -6,6 +6,8 @@ import torch
 
 from clapper.logging import setup
 
+from .utils import SampleListDataset
+
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@@ -301,41 +303,45 @@ def get_positive_weights(dataset):
     return positive_weights
 
 
-def return_subsets(dataset):
+def return_subsets(dataset, protocol):
     train_dataset = None
     validation_dataset = None
     extra_validation_datasets = None
     predict_dataset = None
 
-    if "__train__" in dataset:
-        logger.info("Found (dedicated) '__train__' set for training")
-        train_dataset = dataset["__train__"]
-    else:
-        train_dataset = dataset["train"]
+    subsets = dataset.subsets(protocol)
 
-    if "__valid__" in dataset:
-        logger.info("Found (dedicated) '__valid__' set for validation")
-        validation_dataset = dataset["__valid__"]
+    if "train" in subsets.keys():
+        train_dataset = SampleListDataset(subsets["train"], [])
 
-    if "__extra_valid__" in dataset:
-        if not isinstance(dataset["__extra_valid__"], list):
+    if "validation" in subsets.keys():
+        validation_dataset = SampleListDataset(subsets["validation"], [])
+    else:
+        logger.warning(
+            "No validation dataset found, using training set instead."
+        )
+        validation_dataset = train_dataset
+
+    if "__extra_valid__" in subsets.keys():
+        if not isinstance(subsets["__extra_valid__"], list):
             raise RuntimeError(
                 f"If present, dataset['__extra_valid__'] must be a list, "
-                f"but you passed a {type(dataset['__extra_valid__'])}, "
+                f"but you passed a {type(subsets['__extra_valid__'])}, "
                 f"which is invalid."
             )
         logger.info(
-            f"Found {len(dataset['__extra_valid__'])} extra validation "
+            f"Found {len(subsets['__extra_valid__'])} extra validation "
             f"set(s) to be tracked during training"
         )
         logger.info(
            "Extra validation sets are NOT used for model checkpointing!"
        )
-        extra_validation_datasets = dataset["__extra_valid__"]
+        extra_validation_datasets = SampleListDataset(
+            subsets["__extra_valid__"], []
+        )
     else:
         extra_validation_datasets = None
 
-    predict_dataset = dataset
+    predict_dataset = subsets
 
     return (
         train_dataset,
...
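return_subsets now receives the raw JSONDataset plus a protocol name and resolves the splits itself, wrapping each one in a SampleListDataset and falling back to the training set when no validation split exists. A sketch of the new calling convention (json_dataset is assumed to be built as in the datamodule above; the "default" protocol name matches it):

    from ptbench.data import return_subsets

    train, valid, extra_valid, predict = return_subsets(json_dataset, "default")
    # protocols without a "validation" subset log a warning and alias valid to train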
@@ -10,6 +10,7 @@ import pathlib
 import random
 
 import torch
+import tqdm
 
 from torchvision.transforms import RandomRotation
 
@@ -169,12 +170,14 @@ class JSONDataset:
         retval = {}
 
         for subset, samples in data.items():
+            logger.info(f"Loading subset {subset} samples.")
+
             retval[subset] = [
                 self._loader(
                     dict(protocol=protocol, subset=subset, order=n),
                     dict(zip(self.fieldnames, k)),
                 )
-                for n, k in enumerate(samples)
+                for n, k in tqdm.tqdm(enumerate(samples))
             ]
 
         return retval
...
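One small caveat: tqdm.tqdm(enumerate(samples)) cannot infer the iterable's length, so the progress bar shows only a bare counter with no percentage or ETA. If samples is a sized sequence (as a JSON-decoded list would be), passing the total explicitly restores the full bar — a suggested variant, not what the commit does:

    for n, k in tqdm.tqdm(enumerate(samples), total=len(samples))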
@@ -10,7 +10,7 @@ import functools
 
 import PIL.Image
 
-from .sample import DelayedSample
+from .sample import DelayedSample, Sample
 
 
 def load_pil(path):
@@ -70,6 +70,14 @@ def load_pil_rgb(path):
     return load_pil(path).convert("RGB")
 
 
+def make_cached(sample, loader, key=None):
+    return Sample(
+        loader(sample),
+        key=key or sample["data"],
+        label=sample["label"],
+    )
+
+
 def make_delayed(sample, loader, key=None):
     """Returns a delayed-loading Sample object.
...
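The two factories trade memory for I/O: make_cached invokes the loader immediately and stores the decoded data on an eager Sample, while make_delayed (unchanged here) defers the load until the sample is first accessed. Schematically (semantics inferred from the names and from how the shenzhen module uses them below):

    eager = make_cached(sample, _raw_data_loader)   # disk read + transforms happen now
    lazy = make_delayed(sample, _raw_data_loader)   # disk read happens on first access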
@@ -19,16 +19,15 @@ the daily routine using Philips DR Digital Diagnose systems.
 * Validation samples: 16% of TB and healthy CXR (including labels)
 * Test samples: 20% of TB and healthy CXR (including labels)
 """
 import importlib.resources
 import os
 
 from clapper.logging import setup
+from torchvision import transforms
 
 from ...utils.rc import load_rc
-from .. import make_dataset
-from ..dataset import JSONDataset
-from ..loader import load_pil_baw, make_delayed
+from ..loader import load_pil_baw, make_cached, make_delayed
+from ..transforms import RemoveBlackBorders
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
@@ -49,45 +48,32 @@ _protocols = [
 _datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
 
+_resize_size = 512
+_cc_size = 512
+
+_data_transforms = transforms.Compose(
+    [
+        RemoveBlackBorders(),
+        transforms.Resize(_resize_size),
+        transforms.CenterCrop(_cc_size),
+        transforms.ToTensor(),
+    ]
+)
+
 
 def _raw_data_loader(sample):
+    raw_data = load_pil_baw(os.path.join(_datadir, sample["data"]))
     return dict(
-        data=load_pil_baw(os.path.join(_datadir, sample["data"])),
+        data=_data_transforms(raw_data),
         label=sample["label"],
     )
 
 
-def _loader(context, sample):
+def _cached_loader(context, sample):
+    return make_cached(sample, _raw_data_loader)
+
+
+def _delayed_loader(context, sample):
     # "context" is ignored in this case - database is homogeneous
     # we returned delayed samples to avoid loading all images at once
     return make_delayed(sample, _raw_data_loader)
-
-
-json_dataset = JSONDataset(
-    protocols=_protocols, fieldnames=("data", "label"), loader=_loader
-)
-"""Shenzhen dataset object."""
-
-
-def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
-    from torchvision import transforms
-
-    from ..transforms import ElasticDeformation, RemoveBlackBorders
-
-    post_transforms = []
-    if RGB:
-        post_transforms = [
-            transforms.Lambda(lambda x: x.convert("RGB")),
-            transforms.ToTensor(),
-        ]
-
-    return make_dataset(
-        [json_dataset.subsets(protocol)],
-        [
-            RemoveBlackBorders(),
-            transforms.Resize(resize_size),
-            transforms.CenterCrop(cc_size),
-        ],
-        [ElasticDeformation(p=0.8)],
-        post_transforms,
-    )
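Note that _data_transforms now ends with ToTensor(), so _cached_loader stores fully preprocessed tensors rather than raw PIL images. Assuming single-channel float32 tensors, each cached sample occupies 512 × 512 × 4 bytes = 1 MiB, so caching all 662 Shenzhen images costs roughly 662 MiB of RAM in exchange for skipping the decode/crop/resize work on every epoch.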
@@ -19,6 +19,7 @@ import numpy
 import PIL.Image
 
 from scipy.ndimage import gaussian_filter, map_coordinates
+from torchvision import transforms
 
 
 class SingleAutoLevel16to8:
@@ -76,6 +77,8 @@ class ElasticDeformation:
         self.random_state = random_state
         self.p = p
 
+        self.tensor_transform = transforms.Compose([transforms.ToTensor()])
+
     def __call__(self, img):
         if random.random() < self.p:
             img = numpy.asarray(img)
@@ -114,6 +117,6 @@ class ElasticDeformation:
             result[:, :] = map_coordinates(
                 img[:, :], indices, order=self.spline_order, mode=self.mode
             ).reshape(shape)
-            return PIL.Image.fromarray(result)
+            return self.tensor_transform(PIL.Image.fromarray(result))
         else:
             return img
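With this change, ElasticDeformation returns a tensor on the deformed path (via ToTensor) but passes its input through untouched otherwise, so callers should feed it tensors — which the new ToTensor-terminated loading pipeline guarantees. A minimal usage sketch (p=1.0 is chosen only to force the deformation branch; the constructor signature is assumed from the fields above):

    import torch
    from ptbench.data.transforms import ElasticDeformation

    deform = ElasticDeformation(p=1.0)  # always deform, for demonstration
    image = torch.rand(512, 512)        # single-channel image, channel dim already squeezed
    deformed = deform(image)            # numpy -> map_coordinates -> PIL -> ToTensor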
@@ -17,16 +17,23 @@ class PASA(pl.LightningModule):
     """
 
     def __init__(
-        self, criterion, criterion_valid, optimizer, optimizer_configs
+        self,
+        train_transforms,
+        criterion,
+        criterion_valid,
+        optimizer,
+        optimizer_configs,
     ):
         super().__init__()
 
-        self.save_hyperparameters()
+        self.save_hyperparameters(ignore=["train_transforms"])
 
         self.name = "pasa"
 
         self.normalizer = TorchVisionNormalizer(nb_channels=1)
 
+        self.train_transforms = train_transforms
+
         # First convolution block
         self.fc1 = nn.Conv2d(1, 4, (3, 3), (2, 2), (1, 1))
         self.fc2 = nn.Conv2d(4, 16, (3, 3), (2, 2), (1, 1))
@@ -126,6 +133,10 @@ class PASA(pl.LightningModule):
     def training_step(self, batch, batch_idx):
         images = batch[1]
         labels = batch[2]
 
+        for img in images:
+            img = torch.unsqueeze(
+                self.train_transforms(torch.squeeze(img, 0)), 0
+            )
+
         # Increase label dimension if too low
         # Allows single and multiclass usage
...
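As committed, the augmentation loop has no effect: img = torch.unsqueeze(...) rebinds the loop variable to a new tensor and immediately discards it, so images reaches the network untransformed. One possible fix rebuilds the batch explicitly — a sketch only, assuming ElasticDeformation returns a [1, H, W] tensor when it fires and the untouched [H, W] input otherwise:

    # hypothetical fix: collect the transformed images and re-stack the batch;
    # reshape normalizes per-sample shapes, since the transform output differs
    # between the deformed and pass-through branches
    transformed = [self.train_transforms(torch.squeeze(img, 0)) for img in images]
    images = torch.stack([t.reshape(1, *t.shape[-2:]) for t in transformed])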
@@ -165,6 +165,14 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     default="cpu",
     cls=ResourceOption,
 )
+@click.option(
+    "--cache-samples",
+    help="If set to True, loads the samples into memory, otherwise loads them at runtime.",
+    required=True,
+    show_default=True,
+    default=False,
+    cls=ResourceOption,
+)
 @click.option(
     "--seed",
     "-s",
@@ -235,6 +243,7 @@ def train(
     datamodule,
     checkpoint_period,
     accelerator,
+    cache_samples,
     seed,
     parallel,
     normalization,
@@ -293,6 +302,7 @@ def train(
         train_batch_size=batch_chunk_size,
         drop_incomplete_batch=drop_incomplete_batch,
         multiproc_kwargs=multiproc_kwargs,
+        cache_samples=cache_samples,
     )
 
     # Manually calling these as we need to access some values to reweight the criterion
     datamodule.prepare_data()
...
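For illustration, the new option might be exercised like this (the ptbench train entry point and the pasa/shenzhen resource names follow the diffs above; the exact command-line spelling is an assumption):

    ptbench train pasa shenzhen --cache-samples true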