Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • medai/software/mednet
1 result
Show changes
Commits on Source (20)
Showing
with 1192 additions and 1067 deletions
...@@ -15,33 +15,33 @@ dynamic = ["readme"] ...@@ -15,33 +15,33 @@ dynamic = ["readme"]
license = { text = "GNU General Public License v3 (GPLv3)" } license = { text = "GNU General Public License v3 (GPLv3)" }
authors = [{ name = "Geoffrey Raposo", email = "geoffrey@raposo.ch" }] authors = [{ name = "Geoffrey Raposo", email = "geoffrey@raposo.ch" }]
maintainers = [ maintainers = [
{ name = "Andre Anjos", email = "andre.anjos@idiap.ch" }, { name = "Andre Anjos", email = "andre.anjos@idiap.ch" },
{ name = "Daniel Carron", email = "daniel.carron@idiap.ch" }, { name = "Daniel Carron", email = "daniel.carron@idiap.ch" },
] ]
classifiers = [ classifiers = [
"Development Status :: 4 - Beta", "Development Status :: 4 - Beta",
"Intended Audience :: Developers", "Intended Audience :: Developers",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)", "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
"Natural Language :: English", "Natural Language :: English",
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Software Development :: Libraries :: Python Modules",
] ]
dependencies = [ dependencies = [
"clapper", "clapper",
"click", "click",
"numpy", "numpy",
"pandas", "pandas",
"scipy", "scipy",
"scikit-learn", "scikit-learn",
"tqdm", "tqdm",
"psutil", "psutil",
"tabulate", "tabulate",
"matplotlib", "matplotlib",
"pillow", "pillow",
"torch>=1.8", "torch>=1.8",
"torchvision>=0.10", "torchvision>=0.10",
"lightning>=2.0.3", "lightning>=2.0.3",
"tensorboard", "tensorboard",
] ]
[project.urls] [project.urls]
...@@ -53,13 +53,13 @@ changelog = "https://gitlab.idiap.ch/biosignal/software/ptbench/-/releases" ...@@ -53,13 +53,13 @@ changelog = "https://gitlab.idiap.ch/biosignal/software/ptbench/-/releases"
[project.optional-dependencies] [project.optional-dependencies]
qa = ["pre-commit"] qa = ["pre-commit"]
doc = [ doc = [
"sphinx", "sphinx",
"furo", "furo",
"sphinx-autodoc-typehints", "sphinx-autodoc-typehints",
"auto-intersphinx", "auto-intersphinx",
"sphinx-copybutton", "sphinx-copybutton",
"sphinx-inline-tabs", "sphinx-inline-tabs",
"sphinx-click", "sphinx-click",
] ]
test = ["pytest", "pytest-cov", "coverage"] test = ["pytest", "pytest-cov", "coverage"]
......
...@@ -6,19 +6,30 @@ ...@@ -6,19 +6,30 @@
from torch import empty from torch import empty
from torch.nn import BCEWithLogitsLoss from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD
from ...models.alexnet import Alexnet from ...models.alexnet import Alexnet
# config # optimizer
optimizer = SGD
optimizer_configs = {"lr": 0.01, "momentum": 0.1} optimizer_configs = {"lr": 0.01, "momentum": 0.1}
# optimizer
optimizer = "SGD"
# criterion # criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model # model
model = Alexnet( model = Alexnet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=False,
augmentation_transforms=augmentation_transforms,
) )
...@@ -6,19 +6,30 @@ ...@@ -6,19 +6,30 @@
from torch import empty from torch import empty
from torch.nn import BCEWithLogitsLoss from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD
from ...models.alexnet import Alexnet from ...models.alexnet import Alexnet
# config
optimizer_configs = {"lr": 0.001, "momentum": 0.1}
# optimizer # optimizer
optimizer = "SGD" optimizer = SGD
optimizer_configs = {"lr": 0.01, "momentum": 0.1}
# criterion # criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model # model
model = Alexnet( model = Alexnet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=True,
augmentation_transforms=augmentation_transforms,
) )
...@@ -6,20 +6,30 @@ ...@@ -6,20 +6,30 @@
from torch import empty from torch import empty
from torch.nn import BCEWithLogitsLoss from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from ...models.densenet import Densenet from ...models.densenet import Densenet
# config
optimizer_configs = {"lr": 0.0001}
# optimizer # optimizer
optimizer = "Adam" optimizer = Adam
optimizer_configs = {"lr": 0.0001}
# criterion # criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model # model
model = Densenet( model = Densenet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=False,
augmentation_transforms=augmentation_transforms,
) )
...@@ -6,20 +6,30 @@ ...@@ -6,20 +6,30 @@
from torch import empty from torch import empty
from torch.nn import BCEWithLogitsLoss from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from ...models.densenet import Densenet from ...models.densenet import Densenet
# config
optimizer_configs = {"lr": 0.01}
# optimizer # optimizer
optimizer = "Adam" optimizer = Adam
optimizer_configs = {"lr": 0.0001}
# criterion # criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model # model
model = Densenet( model = Densenet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=True,
augmentation_transforms=augmentation_transforms,
) )
...@@ -13,18 +13,30 @@ Reference: [PASA-2019]_ ...@@ -13,18 +13,30 @@ Reference: [PASA-2019]_
from torch import empty from torch import empty
from torch.nn import BCEWithLogitsLoss from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from ...models.pasa import PASA from ...models.pasa import PASA
# config
optimizer_configs = {"lr": 8e-5}
# optimizer # optimizer
optimizer = "Adam" optimizer = Adam
optimizer_configs = {"lr": 8e-5}
# criterion # criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1)) criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1)) criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [ElasticDeformation(p=0.8)]
# from torchvision.transforms.v2 import ElasticTransform, InterpolationMode
# augmentation_transforms = [ElasticTransform(alpha=1000.0, sigma=30.0, interpolation=InterpolationMode.NEAREST)]
# model # model
model = PASA(criterion, criterion_valid, optimizer, optimizer_configs) model = PASA(
criterion,
criterion_valid,
optimizer,
optimizer_configs,
augmentation_transforms=augmentation_transforms,
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import multiprocessing
import sys
import lightning as pl
import torch
from clapper.logging import setup
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from .dataset import get_samples_weights
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class BaseDataModule(pl.LightningDataModule):
def __init__(
self,
batch_size=1,
batch_chunk_count=1,
drop_incomplete_batch=False,
parallel=-1,
):
super().__init__()
self.batch_size = batch_size
self.batch_chunk_count = batch_chunk_count
self.train_dataset = None
self.validation_dataset = None
self.extra_validation_datasets = None
self._raw_data_transforms = []
self._model_transforms = []
self._augmentation_transforms = []
self.drop_incomplete_batch = drop_incomplete_batch
self.pin_memory = (
torch.cuda.is_available()
) # should only be true if GPU available and using it
self.parallel = parallel
def setup(self, stage: str):
# Implemented by user
# Must define self.train_dataset, self.validation_dataset, self.extra_validation_datasets and self.predict_dataset
raise NotImplementedError
def train_dataloader(self):
train_samples_weights = get_samples_weights(self.train_dataset)
multiproc_kwargs = self._setup_multiproc(self.parallel)
train_sampler = WeightedRandomSampler(
train_samples_weights, len(train_samples_weights), replacement=True
)
return DataLoader(
self.train_dataset,
batch_size=self._compute_chunk_size(
self.batch_size, self.batch_chunk_count
),
drop_last=self.drop_incomplete_batch,
pin_memory=self.pin_memory,
sampler=train_sampler,
**multiproc_kwargs,
)
def val_dataloader(self):
loaders_dict = {}
multiproc_kwargs = self._setup_multiproc(self.parallel)
val_loader = DataLoader(
dataset=self.validation_dataset,
batch_size=self._compute_chunk_size(
self.batch_size, self.batch_chunk_count
),
shuffle=False,
drop_last=False,
pin_memory=self.pin_memory,
**multiproc_kwargs,
)
loaders_dict["validation_loader"] = val_loader
if self.extra_validation_datasets is not None:
for set_idx, extra_set in enumerate(self.extra_validation_datasets):
extra_val_loader = DataLoader(
dataset=extra_set,
batch_size=self._compute_chunk_size(
self.batch_size, self.batch_chunk_count
),
shuffle=False,
drop_last=False,
pin_memory=self.pin_memory,
**multiproc_kwargs,
)
loaders_dict[
f"extra_validation_loader{set_idx}"
] = extra_val_loader
return loaders_dict
def predict_dataloader(self):
loaders_dict = {}
loaders_dict["train_loader"] = self.train_dataloader()
for k, v in self.val_dataloader().items():
loaders_dict[k] = v
return loaders_dict
def update_module_properties(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
def _compute_chunk_size(self, batch_size, chunk_count):
batch_chunk_size = batch_size
if batch_size % chunk_count != 0:
# batch_size must be divisible by batch_chunk_count.
raise RuntimeError(
f"--batch-size ({batch_size}) must be divisible by "
f"--batch-chunk-size ({chunk_count})."
)
else:
batch_chunk_size = batch_size // chunk_count
return batch_chunk_size
def _setup_multiproc(self, parallel):
multiproc_kwargs = dict()
if parallel < 0:
multiproc_kwargs["num_workers"] = 0
else:
multiproc_kwargs["num_workers"] = (
parallel or multiprocessing.cpu_count()
)
if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
multiproc_kwargs[
"multiprocessing_context"
] = multiprocessing.get_context("spawn")
return multiproc_kwargs
def _build_transforms(self, is_train):
all_transforms = self.raw_data_transforms + self.model_transforms
if is_train:
all_transforms = all_transforms + self.augmentation_transforms
# all_transforms.append(transforms.ToTensor())
return transforms.Compose(all_transforms)
@property
def raw_data_transforms(self):
return self._raw_data_transforms
@raw_data_transforms.setter
def raw_data_transforms(self, transforms):
self._raw_data_transforms = transforms
@property
def model_transforms(self):
return self._model_transforms
@model_transforms.setter
def model_transforms(self, transforms):
self._model_transforms = transforms
@property
def augmentation_transforms(self):
return self._augmentation_transforms
@augmentation_transforms.setter
def augmentation_transforms(self, transforms):
self._augmentation_transforms = transforms
def get_dataset_from_module(module, stage, **module_args):
"""Instantiates a DataModule and retrieves the corresponding dataset.
Useful when combining multiple datasets.
"""
module_instance = module(**module_args)
module_instance.prepare_data()
module_instance.setup(stage=stage)
dataset = module_instance.dataset
return dataset
This diff is collapsed.
...@@ -2,460 +2,19 @@ ...@@ -2,460 +2,19 @@
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
import csv
import json
import logging import logging
import os
import pathlib
import torch import torch
import torch.utils.data
from tqdm import tqdm
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class JSONProtocol: def _get_positive_weights(dataloader):
"""Generic multi-protocol/subset filelist dataset that yields samples.
To create a new dataset, you need to provide one or more JSON formatted
filelists (one per protocol) with the following contents:
.. code-block:: json
{
"subset1": [
[
"value1",
"value2",
"value3"
],
[
"value4",
"value5",
"value6"
]
],
"subset2": [
]
}
Your dataset many contain any number of subsets, but all sample entries
must contain the same number of fields.
Parameters
----------
protocols : list, dict
Paths to one or more JSON formatted files containing the various
protocols to be recognized by this dataset, or a dictionary, mapping
protocol names to paths (or opened file objects) of CSV files.
Internally, we save a dictionary where keys default to the basename of
paths (list input).
fieldnames : list, tuple
An iterable over the field names (strings) to assign to each entry in
the JSON file. It should have as many items as fields in each entry of
the JSON file.
loader : object
A function that receives as input, a context dictionary (with at least
a "protocol" and "subset" keys indicating which protocol and subset are
being served), and a dictionary with ``{fieldname: value}`` entries,
and returns an object with at least 2 attributes:
* ``key``: which must be a unique string for every sample across
subsets in a protocol, and
* ``data``: which contains the data associated witht this sample
"""
def __init__(self, protocols, fieldnames):
if isinstance(protocols, dict):
self._protocols = protocols
else:
self._protocols = {
os.path.basename(
str(k).replace("".join(pathlib.Path(k).suffixes), "")
): k
for k in protocols
}
self.fieldnames = fieldnames
def check(self, limit=0):
"""For each protocol, check if all data can be correctly accessed.
This function assumes each sample has a ``data`` and a ``key``
attribute. The ``key`` attribute should be a string, or representable
as such.
Parameters
----------
limit : int
Maximum number of samples to check (in each protocol/subset
combination) in this dataset. If set to zero, then check
everything.
Returns
-------
errors : int
Number of errors found
"""
logger.info("Checking dataset...")
errors = 0
for proto in self._protocols:
logger.info(f"Checking protocol '{proto}'...")
for name, samples in self.subsets(proto).items():
logger.info(f"Checking subset '{name}'...")
if limit:
logger.info(f"Checking at most first '{limit}' samples...")
samples = samples[:limit]
for pos, sample in enumerate(samples):
try:
sample.data # may trigger data loading
logger.info(f"{sample.key}: OK")
except Exception as e:
logger.error(
f"Found error loading entry {pos} in subset {name} "
f"of protocol {proto} from file "
f"'{self._protocols[proto]}': {e}"
)
errors += 1
return errors
def subsets(self, protocol):
"""Returns all subsets in a protocol.
This method will load JSON information for a given protocol and return
all subsets of the given protocol after converting each entry through
the loader function.
Parameters
----------
protocol : str
Name of the protocol data to load
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of objects (respecting
the ``key``, ``data`` interface).
"""
fileobj = self._protocols[protocol]
if isinstance(fileobj, (str, bytes, pathlib.Path)):
if str(fileobj).endswith(".bz2"):
import bz2
with bz2.open(self._protocols[protocol]) as f:
data = json.load(f)
else:
with open(self._protocols[protocol]) as f:
data = json.load(f)
else:
data = json.load(fileobj)
fileobj.seek(0)
retval = {}
for subset, samples in data.items():
logger.info(f"Loading subset {subset} samples.")
retval[subset] = [
dict(zip(self.fieldnames, k))
for n, k in enumerate(tqdm(samples))
]
return retval
class CSVDataset:
"""Generic multi-subset filelist dataset that yields samples.
To create a new dataset, you only need to provide a CSV formatted filelist
using any separator (e.g. comma, space, semi-colon) with the following
information:
.. code-block:: text
value1,value2,value3
value4,value5,value6
...
Notice that all rows must have the same number of entries.
Parameters
----------
subsets : list, dict
Paths to one or more CSV formatted files containing the various subsets
to be recognized by this dataset, or a dictionary, mapping subset names
to paths (or opened file objects) of CSV files. Internally, we save a
dictionary where keys default to the basename of paths (list input).
fieldnames : list, tuple
An iterable over the field names (strings) to assign to each column in
the CSV file. It should have as many items as fields in each row of
the CSV file(s).
loader : object
A function that receives as input, a context dictionary (with, at
least, a "subset" key indicating which subset is being served), and a
dictionary with ``{key: path}`` entries, and returns a dictionary with
the loaded data.
"""
def __init__(self, subsets, fieldnames, loader):
if isinstance(subsets, dict):
self._subsets = subsets
else:
self._subsets = {
os.path.basename(
str(k).replace("".join(pathlib.Path(k).suffixes), "")
): k
for k in subsets
}
self.fieldnames = fieldnames
self._loader = loader
def check(self, limit=0):
"""For each subset, check if all data can be correctly accessed.
This function assumes each sample has a ``data`` and a ``key``
attribute. The ``key`` attribute should be a string, or representable
as such.
Parameters
----------
limit : int
Maximum number of samples to check (in each protocol/subset
combination) in this dataset. If set to zero, then check
everything.
Returns
-------
errors : int
Number of errors found
"""
logger.info("Checking dataset...")
errors = 0
for name in self._subsets.keys():
logger.info(f"Checking subset '{name}'...")
samples = self.samples(name)
if limit:
logger.info(f"Checking at most first '{limit}' samples...")
samples = samples[:limit]
for pos, sample in enumerate(samples):
try:
sample.data # may trigger data loading
logger.info(f"{sample.key}: OK")
except Exception as e:
logger.error(
f"Found error loading entry {pos} in subset {name} "
f"from file '{self._subsets[name]}': {e}"
)
errors += 1
return errors
def subsets(self):
"""Returns all available subsets at once.
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of objects (respecting
the ``key``, ``data`` interface).
"""
return {k: self.samples(k) for k in self._subsets.keys()}
def samples(self, subset):
"""Returns all samples in a subset.
This method will load CSV information for a given subset and return
all samples of the given subset after passing each entry through the
loading function.
Parameters
----------
subset : str
Name of the subset data to load
Returns
-------
subset : list
A lists of objects (respecting the ``key``, ``data`` interface).
"""
fileobj = self._subsets[subset]
if isinstance(fileobj, (str, bytes, pathlib.Path)):
with open(self._subsets[subset], newline="") as f:
cf = csv.reader(f)
samples = [k for k in cf]
else:
cf = csv.reader(fileobj)
samples = [k for k in cf]
fileobj.seek(0)
return [
self._loader(
dict(subset=subset, order=n), dict(zip(self.fieldnames, k))
)
for n, k in enumerate(samples)
]
class CachedDataset(torch.utils.data.Dataset):
def __init__(
self, json_protocol, protocol, subset, raw_data_loader, transforms
):
self.json_protocol = json_protocol
self.subset = subset
self.raw_data_loader = raw_data_loader
self.transforms = transforms
self._samples = json_protocol.subsets(protocol)[self.subset]
# Dict entry with relative path to files, used during prediction
for s in self._samples:
s["name"] = s["data"]
logger.info(f"Caching {self.subset} samples")
for sample in tqdm(self._samples):
sample["data"] = self.raw_data_loader(sample["data"])
def __getitem__(self, idx):
sample = self._samples[idx].copy()
sample["data"] = self.transforms(sample["data"])
return sample
def __len__(self):
return len(self._samples)
class RuntimeDataset(torch.utils.data.Dataset):
def __init__(
self, json_protocol, protocol, subset, raw_data_loader, transforms
):
self.json_protocol = json_protocol
self.subset = subset
self.raw_data_loader = raw_data_loader
self.transforms = transforms
self._samples = json_protocol.subsets(protocol)[self.subset]
# Dict entry with relative path to files
for s in self._samples:
s["name"] = s["data"]
def __getitem__(self, idx):
sample = self._samples[idx].copy()
sample["data"] = self.transforms(self.raw_data_loader(sample["data"]))
return sample
def __len__(self):
return len(self._samples)
def get_samples_weights(dataset):
"""Compute the weights of all the samples of the dataset to balance it
using the sampler of the dataloader.
This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
and computes the weights to balance each class in the dataset and the
datasets themselves if we have a ConcatDataset.
Parameters
----------
dataset : torch.utils.data.dataset.Dataset
An instance of torch.utils.data.dataset.Dataset
ConcatDataset are supported
Returns
-------
samples_weights : :py:class:`torch.Tensor`
the weights for all the samples in the dataset given as input
"""
samples_weights = []
if isinstance(dataset, torch.utils.data.ConcatDataset):
for ds in dataset.datasets:
# Weighting only for binary labels
if isinstance(ds._samples[0]["label"], int):
# Groundtruth
targets = []
for s in ds._samples:
targets.append(s["label"])
targets = torch.tensor(targets)
# Count number of samples per class
class_sample_count = torch.tensor(
[
(targets == t).sum()
for t in torch.unique(targets, sorted=True)
]
)
weight = 1.0 / class_sample_count.float()
samples_weights.append(
torch.tensor([weight[t] for t in targets])
)
else:
# We only weight to sample equally from each dataset
samples_weights.append(torch.full((len(ds),), 1.0 / len(ds)))
# Concatenate sample weights from all the datasets
samples_weights = torch.cat(samples_weights)
else:
# Weighting only for binary labels
if isinstance(dataset._samples[0]["label"], int):
# Groundtruth
targets = []
for s in dataset._samples:
targets.append(s["label"])
targets = torch.tensor(targets)
# Count number of samples per class
class_sample_count = torch.tensor(
[
(targets == t).sum()
for t in torch.unique(targets, sorted=True)
]
)
weight = 1.0 / class_sample_count.float()
samples_weights = torch.tensor([weight[t] for t in targets])
else:
# Equal weights for non-binary labels
samples_weights = torch.ones(len(dataset._samples))
return samples_weights
def get_positive_weights(dataset):
"""Compute the positive weights of each class of the dataset to balance the """Compute the positive weights of each class of the dataset to balance the
BCEWithLogitsLoss criterion. BCEWithLogitsLoss criterion.
This function takes as input a :py:class:`torch.utils.data.dataset.Dataset` This function takes as input a :py:class:`torch.utils.data.DataLoader`
and computes the positive weights of each class to use them to have and computes the positive weights of each class to use them to have
a balanced loss. a balanced loss.
...@@ -463,9 +22,8 @@ def get_positive_weights(dataset): ...@@ -463,9 +22,8 @@ def get_positive_weights(dataset):
Parameters Parameters
---------- ----------
dataset : torch.utils.data.dataset.Dataset dataloader : :py:class:`torch.utils.data.DataLoader`
An instance of torch.utils.data.dataset.Dataset A DataLoader from which to compute the positive weights. Must contain a 'label' key in the metadata returned by __getitem__().
ConcatDataset are supported
Returns Returns
...@@ -476,14 +34,8 @@ def get_positive_weights(dataset): ...@@ -476,14 +34,8 @@ def get_positive_weights(dataset):
""" """
targets = [] targets = []
if isinstance(dataset, torch.utils.data.ConcatDataset): for batch in dataloader:
for ds in dataset.datasets: targets.extend(batch[1]["label"])
for s in ds._samples:
targets.append(s["label"])
else:
for s in dataset._samples:
targets.append(s["label"])
targets = torch.tensor(targets) targets = torch.tensor(targets)
...@@ -512,75 +64,3 @@ def get_positive_weights(dataset): ...@@ -512,75 +64,3 @@ def get_positive_weights(dataset):
) )
return positive_weights return positive_weights
def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
from torch.nn import BCEWithLogitsLoss
datamodule.prepare_data()
datamodule.setup(stage="fit")
train_dataset = datamodule.train_dataset
validation_dataset = datamodule.validation_dataset
# Redefine a weighted criterion if possible
if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
positive_weights = get_positive_weights(train_dataset)
model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
else:
logger.warning("Weighted criterion not supported")
if validation_dataset is not None:
# Redefine a weighted valid criterion if possible
if (
isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
or criterion_valid is None
):
positive_weights = get_positive_weights(validation_dataset)
model.hparams.criterion_valid = BCEWithLogitsLoss(
pos_weight=positive_weights
)
else:
logger.warning("Weighted valid criterion not supported")
def normalize_data(normalization, model, datamodule):
from torch.utils.data import DataLoader
datamodule.prepare_data()
datamodule.setup(stage="fit")
train_dataset = datamodule.train_dataset
# Create z-normalization model layer if needed
if normalization == "imagenet":
model.normalizer.set_mean_std(
[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
)
logger.info("Z-normalization with ImageNet mean and std")
elif normalization == "current":
# Compute mean/std of current train subset
temp_dl = DataLoader(
dataset=train_dataset, batch_size=len(train_dataset)
)
data = next(iter(temp_dl))
mean = data[1].mean(dim=[0, 2, 3])
std = data[1].std(dim=[0, 2, 3])
model.normalizer.set_mean_std(mean, std)
# Format mean and std for logging
mean = str(
[
round(x, 3)
for x in ((mean * 10**3).round() / (10**3)).tolist()
]
)
std = str(
[
round(x, 3)
for x in ((std * 10**3).round() / (10**3)).tolist()
]
)
logger.info(f"Z-normalization with mean {mean} and std {std}")
...@@ -264,7 +264,7 @@ json_dataset = JSONDataset( ...@@ -264,7 +264,7 @@ json_dataset = JSONDataset(
def _maker(protocol, resize_size=512, cc_size=512, RGB=True): def _maker(protocol, resize_size=512, cc_size=512, RGB=True):
import torchvision.transforms as transforms import torchvision.transforms as transforms
from ..transforms import SingleAutoLevel16to8 from ..loader import SingleAutoLevel16to8
post_transforms = [] post_transforms = []
if not RGB: if not RGB:
......
...@@ -5,9 +5,42 @@ ...@@ -5,9 +5,42 @@
"""Data loading code.""" """Data loading code."""
import numpy
import PIL.Image import PIL.Image
class SingleAutoLevel16to8:
"""Converts a 16-bit image to 8-bit representation using "auto-level".
This transform assumes that the input image is gray-scaled.
To auto-level, we calculate the maximum and the minimum of the image, and
consider such a range should be mapped to the [0,255] range of the
destination image.
"""
def __call__(self, img):
imin, imax = img.getextrema()
irange = imax - imin
return PIL.Image.fromarray(
numpy.round(
255.0 * (numpy.array(img).astype(float) - imin) / irange
).astype("uint8"),
).convert("L")
class RemoveBlackBorders:
"""Remove black borders of CXR."""
def __init__(self, threshold=0):
self.threshold = threshold
def __call__(self, img):
img = numpy.asarray(img)
mask = numpy.asarray(img) > self.threshold
return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
def load_pil(path): def load_pil(path):
"""Loads a sample data. """Loads a sample data.
......
...@@ -13,22 +13,11 @@ the daily routine using Philips DR Digital Diagnose systems. ...@@ -13,22 +13,11 @@ the daily routine using Philips DR Digital Diagnose systems.
* Reference: [MONTGOMERY-SHENZHEN-2014]_ * Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 3000 x 3000 or less * Original resolution (height x width or width x height): 3000 x 3000 or less
* Split reference: none * Split reference: none
* Protocol ``default``:
* Training samples: 64% of TB and healthy CXR (including labels) * Training samples: 64% of TB and healthy CXR (including labels)
* Validation samples: 16% of TB and healthy CXR (including labels) * Validation samples: 16% of TB and healthy CXR (including labels)
* Test samples: 20% of TB and healthy CXR (including labels) * Test samples: 20% of TB and healthy CXR (including labels)
""" """
import importlib.resources import importlib.resources
import os
from clapper.logging import setup
from ...utils.rc import load_rc
from ..loader import load_pil_baw
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
_protocols = [ _protocols = [
importlib.resources.files(__name__).joinpath("default.json.bz2"), importlib.resources.files(__name__).joinpath("default.json.bz2"),
...@@ -43,11 +32,3 @@ _protocols = [ ...@@ -43,11 +32,3 @@ _protocols = [
importlib.resources.files(__name__).joinpath("fold_8.json.bz2"), importlib.resources.files(__name__).joinpath("fold_8.json.bz2"),
importlib.resources.files(__name__).joinpath("fold_9.json.bz2"), importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
] ]
_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
def _raw_data_loader(img_path):
raw_data = load_pil_baw(os.path.join(_datadir, img_path))
return raw_data
...@@ -2,27 +2,34 @@ ...@@ -2,27 +2,34 @@
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for TB detection (default protocol) """Shenzhen datamodule for computer-aided diagnosis (default protocol)
* Split reference: first 64% of TB and healthy CXR for "train" 16% for See :py:mod:`ptbench.data.shenzhen` for dataset details.
* "validation", 20% for "test"
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.shenzhen` for dataset details
"""
from clapper.logging import setup
from ..transforms import ElasticDeformation
from .utils import ShenzhenDataModule
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") This configuration:
* raw data (default): :py:obj:`ptbench.data.shenzhen._tranforms`
protocol_name = "default" * augmentations: elastic deformation (probability = 80%)
* output image resolution: 512x512 pixels
augmentation_transforms = [ElasticDeformation(p=0.8)] """
datamodule = ShenzhenDataModule( import importlib.resources
protocol="default",
model_transforms=[], from ..datamodule import CachingDataModule
augmentation_transforms=augmentation_transforms, from ..split import JSONDatabaseSplit
from .raw_data_loader import raw_data_loader
datamodule = CachingDataModule(
database_split=JSONDatabaseSplit(
importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
"default.json.bz2"
)
),
raw_data_loader=raw_data_loader,
cache_samples=False,
# train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
# model_transforms = [],
# batch_size = 1,
# batch_chunk_count = 1,
# drop_incomplete_batch = False,
# parallel = -1,
) )
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for computer-aided diagnosis.
The standard digital image database for Tuberculosis is created by the National
Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s
Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from
out-patient clinics, and were captured as part of the daily routine using
Philips DR Digital Diagnose systems.
* Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 3000 x 3000 or less
* Split reference: none
* Protocol ``default``:
* Training samples: 64% of TB and healthy CXR (including labels)
* Validation samples: 16% of TB and healthy CXR (including labels)
* Test samples: 20% of TB and healthy CXR (including labels)
"""
import os
import typing
import torch.nn
import torchvision.transforms
from ...utils.rc import load_rc
from ..raw_data_loader import RemoveBlackBorders, load_pil_baw
_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
"""This variable contains the base directory where the database raw data is
stored."""
_transform = torchvision.transforms.Compose(
[
RemoveBlackBorders(),
torchvision.transforms.Resize(512),
torchvision.transforms.CenterCrop(512),
torchvision.transforms.ToTensor(),
]
)
"""Transforms that are always applied to the loaded raw images."""
def raw_data_loader(
sample: tuple[str, int]
) -> tuple[torch.Tensor, typing.Mapping]:
"""Loads a single image sample from the disk.
Parameters
----------
img_path
The path suffix, within the dataset root folder, where to find the
image to be loaded.
Returns
-------
image
A PIL image in grayscale mode
"""
tensor = _transform(load_pil_baw(os.path.join(_datadir, sample[0])))
return tensor, dict(label=sample[1]) # type: ignore[arg-type]
...@@ -2,81 +2,40 @@ ...@@ -2,81 +2,40 @@
# #
# SPDX-License-Identifier: GPL-3.0-or-later # SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for TB detection (cross validation fold 0, RGB) """Shenzhen datamodule for computer-aided diagnosis (default protocol)
* Split reference: first 80% of TB and healthy CXR for "train", rest for "test" See :py:mod:`ptbench.data.shenzhen` for dataset details.
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.shenzhen` for dataset details
"""
from clapper.logging import setup
from ....data import return_subsets
from ....data.base_datamodule import BaseDataModule
from ....data.dataset import JSONProtocol
from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") This configuration:
* raw data (default): :py:obj:`ptbench.data.shenzhen._tranforms`
* augmentations: elastic deformation (probability = 80%)
class DefaultModule(BaseDataModule): * output image resolution: 512x512 pixels
def __init__( """
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
cache_samples=False,
multiproc_kwargs=None,
data_transforms=[],
model_transforms=[],
train_transforms=[],
):
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
self.cache_samples = cache_samples
self.has_setup_fit = False
self.data_transforms = data_transforms import importlib.resources
self.model_transforms = model_transforms
self.train_transforms = train_transforms
"""[ from torchvision import transforms
transforms.ToPILImage(),
transforms.Lambda(lambda x: x.convert("RGB")),
transforms.ToTensor(),
]"""
def setup(self, stage: str): from ..datamodule import CachingDataModule
if self.cache_samples: from ..split import JSONDatabaseSplit
logger.info( from .raw_data_loader import raw_data_loader
"Argument cache_samples set to True. Samples will be loaded in memory."
)
samples_loader = _cached_loader
else:
logger.info(
"Argument cache_samples set to False. Samples will be loaded at runtime."
)
samples_loader = _delayed_loader
self.json_protocol = JSONProtocol( datamodule = CachingDataModule(
protocols=_protocols, database_split=JSONDatabaseSplit(
fieldnames=("data", "label"), importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
loader=samples_loader, "default.json.bz2"
post_transforms=self.post_transforms,
) )
),
if not self.has_setup_fit and stage == "fit": raw_data_loader=raw_data_loader,
( cache_samples=False,
self.train_dataset, # train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
self.validation_dataset, model_transforms=[
self.extra_validation_datasets, transforms.ToPILImage(),
) = return_subsets(self.json_protocol, "default", stage) transforms.Lambda(lambda x: x.convert("RGB")),
self.has_setup_fit = True transforms.ToTensor(),
],
# batch_size = 1,
datamodule = DefaultModule # batch_chunk_count = 1,
# drop_incomplete_batch = False,
# parallel = -1,
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for computer-aided diagnosis.
The standard digital image database for Tuberculosis is created by the
National Library of Medicine, Maryland, USA in collaboration with Shenzhen
No.3 People’s Hospital, Guangdong Medical College, Shenzhen, China.
The Chest X-rays are from out-patient clinics, and were captured as part of
the daily routine using Philips DR Digital Diagnose systems.
* Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 3000 x 3000 or less
* Split reference: none
* Protocol ``default``:
* Training samples: 64% of TB and healthy CXR (including labels)
* Validation samples: 16% of TB and healthy CXR (including labels)
* Test samples: 20% of TB and healthy CXR (including labels)
"""
from clapper.logging import setup
from torchvision import transforms
from ..base_datamodule import BaseDataModule
from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
from ..shenzhen import _protocols, _raw_data_loader
from ..transforms import RemoveBlackBorders
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class ShenzhenDataModule(BaseDataModule):
def __init__(
self,
protocol="default",
model_transforms=[],
augmentation_transforms=[],
batch_size=1,
batch_chunk_count=1,
drop_incomplete_batch=False,
cache_samples=False,
parallel=-1,
):
super().__init__(
batch_size=batch_size,
drop_incomplete_batch=drop_incomplete_batch,
batch_chunk_count=batch_chunk_count,
parallel=parallel,
)
self._cache_samples = cache_samples
self._has_setup_fit = False
self._has_setup_predict = False
self._protocol = protocol
self.raw_data_transforms = [
RemoveBlackBorders(),
transforms.Resize(512),
transforms.CenterCrop(512),
transforms.ToTensor(),
]
self.model_transforms = model_transforms
self.augmentation_transforms = augmentation_transforms
def setup(self, stage: str):
json_protocol = JSONProtocol(
protocols=_protocols,
fieldnames=("data", "label"),
)
if self._cache_samples:
dataset = CachedDataset
else:
dataset = RuntimeDataset
if not self._has_setup_fit and stage == "fit":
self.train_dataset = dataset(
json_protocol,
self._protocol,
"train",
_raw_data_loader,
self._build_transforms(is_train=True),
)
self.validation_dataset = dataset(
json_protocol,
self._protocol,
"validation",
_raw_data_loader,
self._build_transforms(is_train=False),
)
self._has_setup_fit = True
if not self._has_setup_predict and stage == "predict":
self.train_dataset = dataset(
json_protocol,
self._protocol,
"train",
_raw_data_loader,
self._build_transforms(is_train=False),
)
self.validation_dataset = dataset(
json_protocol,
self._protocol,
"validation",
_raw_data_loader,
self._build_transforms(is_train=False),
)
self._has_setup_predict = True
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import collections.abc
import csv
import importlib.abc
import json
import logging
import pathlib
import typing
import torch
logger = logging.getLogger(__name__)
class JSONDatabaseSplit(
dict,
typing.Mapping[str, typing.Sequence[typing.Any]],
):
"""Defines a loader that understands a database split (train, test, etc) in
JSON format.
To create a new database split, you need to provide a JSON formatted
dictionary in a file, with contents similar to the following:
.. code-block:: json
{
"subset1": [
[
"sample1-data1",
"sample1-data2",
"sample1-data3",
],
[
"sample2-data1",
"sample2-data2",
"sample2-data3",
]
],
"subset2": [
[
"sample42-data1",
"sample42-data2",
"sample42-data3",
],
]
}
Your database split many contain any number of subsets (dictionary keys).
For simplicity, we recommend all sample entries are formatted similarly so
that raw-data-loading is simplified. Use the function
:py:func:`check_database_split_loading` to test raw data loading and fine
tune the dataset split, or its loading.
Objects of this class behave like a dictionary in which keys are subset
names in the split, and values represent samples data and meta-data.
Parameters
----------
path
Absolute path to a JSON formatted file containing the database split to be
recognized by this object.
"""
def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable):
if isinstance(path, str):
path = pathlib.Path(path)
self.path = path
self.subsets = self._load_split_from_disk()
def _load_split_from_disk(self) -> dict[str, typing.Sequence[typing.Any]]:
"""Loads all subsets in a split from its file system representation.
This method will load JSON information for the current split and return
all subsets of the given split after converting each entry through the
loader function.
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of JSON objects
"""
if str(self.path).endswith(".bz2"):
logger.debug(f"Loading database split from {str(self.path)}...")
with __import__("bz2").open(self.path) as f:
return json.load(f)
else:
with self.path.open() as f:
return json.load(f)
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split."""
return self.subsets[key]
def __iter__(self):
"""Iterates over the subsets."""
return iter(self.subsets)
def __len__(self) -> int:
"""How many subsets we currently have."""
return len(self.subsets)
class CSVDatabaseSplit(collections.abc.Mapping):
"""Defines a loader that understands a database split (train, test, etc) in
CSV format.
To create a new database split, you need to provide one or more CSV
formatted files, each representing a subset of this split, containing the
sample data (one per row). Example:
Inside the directory ``my-split/``, one can file files ``train.csv``,
``validation.csv``, and ``test.csv``. Each file has a structure similar to
the following:
.. code-block:: text
sample1-value1,sample1-value2,sample1-value3
sample2-value1,sample2-value2,sample2-value3
...
Each file in the provided directory defines the subset name on the split.
So, the file ``train.csv`` will contain the data from the ``train`` subset,
and so on.
Objects of this class behave like a dictionary in which keys are subset
names in the split, and values represent samples data and meta-data.
Parameters
----------
directory
Absolute path to a directory containing the database split layed down
as a set of CSV files, one per subset.
"""
def __init__(
self, directory: pathlib.Path | str | importlib.abc.Traversable
):
if isinstance(directory, str):
directory = pathlib.Path(directory)
assert (
directory.is_dir()
), f"`{str(directory)}` is not a valid directory"
self.directory = directory
self.subsets = self._load_split_from_disk()
def _load_split_from_disk(
self,
) -> dict[str, list[typing.Any]]:
"""Loads all subsets in a split from its file system representation.
This method will load CSV information for the current split and return all
subsets of the given split after converting each entry through the
loader function.
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of JSON objects
"""
retval = {}
for subset in self.directory.iterdir():
if str(subset).endswith(".csv.bz2"):
logger.debug(f"Loading database split from {subset}...")
with __import__("bz2").open(subset) as f:
reader = csv.reader(f)
retval[subset.name[: -len(".csv.bz2")]] = [
k for k in reader
]
elif str(subset).endswith(".csv"):
with subset.open() as f:
reader = csv.reader(f)
retval[subset.name[: -len(".csv")]] = [k for k in reader]
else:
logger.debug(
f"Ignoring file {subset} in CSVDatabaseSplit readout"
)
return retval
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split."""
return self.subsets[key]
def __iter__(self):
"""Iterates over the subsets."""
return iter(self.subsets)
def __len__(self) -> int:
"""How many subsets we currently have."""
return len(self.subsets)
def check_database_split_loading(
database_split: typing.Mapping[str, typing.Sequence[typing.Any]],
loader: typing.Callable[[typing.Any], torch.Tensor],
limit: int = 0,
) -> int:
"""For each subset in the split, check if all data can be correctly loaded
using the provided loader function.
This function will return the number of errors loading samples, and will
log more detailed information to the logging stream.
Parameters
----------
database_split
A mapping that, contains the database split. Each key represents the
name of a subset in the split. Each value is a (potentially complex)
object that represents a single sample.
loader
A callable that transforms sample entries in the database split
into :py:class:`torch.Tensor` objects that can be used for training
or inference.
limit
Maximum number of samples to check (in each split/subset
combination) in this dataset. If set to zero, then check
everything.
Returns
-------
errors
Number of errors found
"""
logger.info(
"Checking if can load all samples in all subsets of this split..."
)
errors = 0
for subset in database_split.keys():
samples = subset if not limit else subset[:limit]
for pos, sample in enumerate(samples):
try:
data = loader(sample)
assert isinstance(data, torch.Tensor)
except Exception as e:
logger.info(
f"Found error loading entry {pos} in subset `{subset}`: {e}"
)
errors += 1
return errors
...@@ -22,39 +22,6 @@ from scipy.ndimage import gaussian_filter, map_coordinates ...@@ -22,39 +22,6 @@ from scipy.ndimage import gaussian_filter, map_coordinates
from torchvision import transforms from torchvision import transforms
class SingleAutoLevel16to8:
"""Converts a 16-bit image to 8-bit representation using "auto-level".
This transform assumes that the input image is gray-scaled.
To auto-level, we calculate the maximum and the minimum of the
image, and
consider such a range should be mapped to the [0,255] range of the
destination image.
"""
def __call__(self, img):
imin, imax = img.getextrema()
irange = imax - imin
return PIL.Image.fromarray(
numpy.round(
255.0 * (numpy.array(img).astype(float) - imin) / irange
).astype("uint8"),
).convert("L")
class RemoveBlackBorders:
"""Remove black borders of CXR."""
def __init__(self, threshold=0):
self.threshold = threshold
def __call__(self, img):
img = numpy.asarray(img)
mask = numpy.asarray(img) > self.threshold
return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
class ElasticDeformation: class ElasticDeformation:
"""Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_. """Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_.
...@@ -68,7 +35,7 @@ class ElasticDeformation: ...@@ -68,7 +35,7 @@ class ElasticDeformation:
spline_order=1, spline_order=1,
mode="nearest", mode="nearest",
random_state=numpy.random, random_state=numpy.random,
p=1, p=1.0,
): ):
self.alpha = alpha self.alpha = alpha
self.sigma = sigma self.sigma = sigma
...@@ -79,13 +46,15 @@ class ElasticDeformation: ...@@ -79,13 +46,15 @@ class ElasticDeformation:
def __call__(self, img): def __call__(self, img):
if random.random() < self.p: if random.random() < self.p:
img = transforms.ToPILImage()(img) assert img.ndim == 3
# Input tensor is of shape C x H x W
# If the tensor only contains one channel, this conversion results in H x W.
# With 3 channels, we get H x W x C
img = transforms.ToPILImage()(img)
img = numpy.asarray(img) img = numpy.asarray(img)
assert img.ndim == 2 shape = img.shape[:2]
shape = img.shape
dx = ( dx = (
gaussian_filter( gaussian_filter(
...@@ -114,9 +83,22 @@ class ElasticDeformation: ...@@ -114,9 +83,22 @@ class ElasticDeformation:
numpy.reshape(y + dy, (-1, 1)), numpy.reshape(y + dy, (-1, 1)),
] ]
result = numpy.empty_like(img) result = numpy.empty_like(img)
result[:, :] = map_coordinates(
img[:, :], indices, order=self.spline_order, mode=self.mode if img.ndim == 2:
).reshape(shape) result[:, :] = map_coordinates(
img[:, :], indices, order=self.spline_order, mode=self.mode
).reshape(shape)
else:
for i in range(img.shape[2]):
result[:, :, i] = map_coordinates(
img[:, :, i],
indices,
order=self.spline_order,
mode=self.mode,
).reshape(shape)
return transforms.ToTensor()(PIL.Image.fromarray(result)) return transforms.ToTensor()(PIL.Image.fromarray(result))
else: else:
return img return img
...@@ -94,10 +94,8 @@ class LoggingCallback(Callback): ...@@ -94,10 +94,8 @@ class LoggingCallback(Callback):
self.log("total_time", current_time) self.log("total_time", current_time)
self.log("eta", eta_seconds) self.log("eta", eta_seconds)
self.log("loss", numpy.average(self.training_loss)) self.log("loss", numpy.average(self.training_loss))
self.log( self.log("learning_rate", pl_module.optimizer_configs["lr"])
"learning_rate", pl_module.hparams["optimizer_configs"]["lr"] self.log("validation_loss", numpy.average(self.validation_loss))
)
self.log("validation_loss", numpy.sum(self.validation_loss))
if len(self.extra_validation_loss) > 0: if len(self.extra_validation_loss) > 0:
for ( for (
...@@ -105,7 +103,8 @@ class LoggingCallback(Callback): ...@@ -105,7 +103,8 @@ class LoggingCallback(Callback):
extra_valid_loss_values, extra_valid_loss_values,
) in self.extra_validation_loss.items: ) in self.extra_validation_loss.items:
self.log( self.log(
extra_valid_loss_key, numpy.sum(extra_valid_loss_values) extra_valid_loss_key,
numpy.average(extra_valid_loss_values),
) )
queue_retries = 0 queue_retries = 0
......
...@@ -7,19 +7,20 @@ import logging ...@@ -7,19 +7,20 @@ import logging
import os import os
import shutil import shutil
from lightning.pytorch import Trainer import lightning.pytorch
from lightning.pytorch.callbacks import ModelCheckpoint import lightning.pytorch.callbacks
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger import lightning.pytorch.loggers
from lightning.pytorch.utilities.model_summary import ModelSummary import torch.nn
from ..utils.accelerator import AcceleratorProcessor from ..utils.accelerator import AcceleratorProcessor
from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
from ..utils.save_sh_command import save_sh_command
from .callbacks import LoggingCallback from .callbacks import LoggingCallback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def check_gpu(device): def check_gpu(device: str) -> None:
"""Check the device type and the availability of GPU. """Check the device type and the availability of GPU.
Parameters Parameters
...@@ -35,32 +36,41 @@ def check_gpu(device): ...@@ -35,32 +36,41 @@ def check_gpu(device):
), f"Device set to '{device}', but nvidia-smi is not installed" ), f"Device set to '{device}', but nvidia-smi is not installed"
def save_model_summary(output_folder, model): def save_model_summary(
output_folder: str, model: torch.nn.Module
) -> tuple[lightning.pytorch.callbacks.ModelSummary, int]:
"""Save a little summary of the model in a txt file. """Save a little summary of the model in a txt file.
Parameters Parameters
---------- ----------
output_folder : str output_folder
output path output path
model : :py:class:`torch.nn.Module` model
Network (e.g. driu, hed, unet) Network (e.g. driu, hed, unet)
Returns Returns
------- -------
r : str summary:
The model summary in a text format. The model summary in a text format.
n : int total_parameters:
The number of parameters of the model. The number of parameters of the model.
""" """
summary_path = os.path.join(output_folder, "model_summary.txt") summary_path = os.path.join(output_folder, "model_summary.txt")
logger.info(f"Saving model summary at {summary_path}...") logger.info(f"Saving model summary at {summary_path}...")
with open(summary_path, "w") as f: with open(summary_path, "w") as f:
summary = ModelSummary(model, max_depth=-1) summary = lightning.pytorch.utilities.model_summary.ModelSummary(
model, max_depth=-1
)
f.write(str(summary)) f.write(str(summary))
return summary, ModelSummary(model).total_parameters return (
summary,
lightning.pytorch.utilities.model_summary.ModelSummary(
model
).total_parameters,
)
def static_information_to_csv(static_logfile_name, device, n): def static_information_to_csv(static_logfile_name, device, n):
...@@ -70,7 +80,8 @@ def static_information_to_csv(static_logfile_name, device, n): ...@@ -70,7 +80,8 @@ def static_information_to_csv(static_logfile_name, device, n):
---------- ----------
static_logfile_name : str static_logfile_name : str
The static file name which is a join between the output folder and "constant.csv" The static file name which is a join between the output folder and
"constant.csv"
""" """
if os.path.exists(static_logfile_name): if os.path.exists(static_logfile_name):
backup = static_logfile_name + "~" backup = static_logfile_name + "~"
...@@ -188,7 +199,8 @@ def run( ...@@ -188,7 +199,8 @@ def run(
not save intermediary checkpoints. not save intermediary checkpoints.
accelerator : str accelerator : str
A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0) A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The
device can also be specified (gpu:0)
arguments : dict arguments : dict
Start and end epochs: Start and end epochs:
...@@ -215,10 +227,16 @@ def run( ...@@ -215,10 +227,16 @@ def run(
os.makedirs(output_folder, exist_ok=True) os.makedirs(output_folder, exist_ok=True)
# Save model summary # Save model summary
r, n = save_model_summary(output_folder, model) _, n = save_model_summary(output_folder, model)
csv_logger = CSVLogger(output_folder, "logs_csv") save_sh_command(output_folder)
tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard")
# save_sh_command(os.path.join(output_folder, "cmd_line_config.txt"))
csv_logger = lightning.pytorch.loggers.CSVLogger(output_folder, "logs_csv")
tensorboard_logger = lightning.pytorch.loggers.TensorBoardLogger(
output_folder, "logs_tensorboard"
)
resource_monitor = ResourceMonitor( resource_monitor = ResourceMonitor(
interval=monitoring_interval, interval=monitoring_interval,
...@@ -227,7 +245,7 @@ def run( ...@@ -227,7 +245,7 @@ def run(
logging_level=logging.ERROR, logging_level=logging.ERROR,
) )
checkpoint_callback = ModelCheckpoint( checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint(
output_folder, output_folder,
"model_lowest_valid_loss", "model_lowest_valid_loss",
save_last=True, save_last=True,
...@@ -251,7 +269,7 @@ def run( ...@@ -251,7 +269,7 @@ def run(
devices = accelerator_processor.device devices = accelerator_processor.device
with resource_monitor: with resource_monitor:
trainer = Trainer( trainer = lightning.pytorch.Trainer(
accelerator=accelerator_processor.accelerator, accelerator=accelerator_processor.accelerator,
devices=devices, devices=devices,
max_epochs=max_epoch, max_epochs=max_epoch,
......