Commit 8d11b566 authored by André Anjos

Pair programming with @dcarron

parent a67626d8
1 merge request: !7 Reviewed DataModule design+docs+types
Pipeline #75378 failed
......@@ -10,11 +10,7 @@ import typing
import lightning
import torch
import torch.utils.data
from clapper.logging import setup
# TODO: No logging on this module...
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
import torchvision.transforms
def _setup_dataloader_multiproc_parameters(
......@@ -93,11 +89,13 @@ class _DelayedLoadingDataset(torch.utils.data.Dataset):
raw_data_loader: typing.Callable[
[typing.Any], tuple[torch.Tensor, typing.Mapping]
],
transforms: typing.Optional[typing.Sequence[torch.nn.Module]] = None,
transforms: typing.Sequence[
typing.Callable[[torch.Tensor], torch.Tensor]
] = [],
):
self.split = split
self.raw_data_loader = raw_data_loader
self.transform = torch.nn.Sequential(*transforms)
self.transform = torchvision.transforms.Compose(transforms)
def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
tensor, metadata = self.raw_data_loader(self.split[key])
......@@ -137,10 +135,12 @@ class _CachedDataset(torch.utils.data.Dataset):
raw_data_loader: typing.Callable[
[typing.Any], tuple[torch.Tensor, typing.Mapping]
],
transforms: typing.Optional[typing.Sequence[torch.nn.Module]] = None,
transforms: typing.Sequence[
typing.Callable[[torch.Tensor], torch.Tensor]
] = [],
):
self.data = [raw_data_loader(k) for k in split]
self.transform = torch.nn.Sequential(*transforms)
self.transform = torchvision.transforms.Compose(transforms)
def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
tensor, metadata = self.data[key]
......@@ -344,22 +344,21 @@ class CachingDataModule(lightning.LightningDataModule):
],
cache_samples: bool = False,
train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
data_augmentations: list[torch.nn.Module] = [],
model_transforms: list[torch.nn.Module] = [],
data_augmentations: list[
typing.Callable[[torch.Tensor], torch.Tensor]
] = [],
model_transforms: list[
typing.Callable[[torch.Tensor], torch.Tensor]
] = [],
batch_size: int = 1,
batch_chunk_count: int = 1,
drop_incomplete_batch: bool = False,
parallel: int = -1,
):
# validation
if batch_size % batch_chunk_count != 0:
raise RuntimeError(
f"batch_size ({batch_size}) must be divisible by "
f"batch_chunk_size ({batch_chunk_count})."
)
super().__init__()
self.set_chunk_size(batch_size, batch_chunk_count)
self.database_split = database_split
self.raw_data_loader = raw_data_loader
self.cache_samples = cache_samples
......@@ -367,16 +366,8 @@ class CachingDataModule(lightning.LightningDataModule):
self.data_augmentations = data_augmentations
self.model_transforms = model_transforms
self._batch_size = batch_size
self._batch_chunk_count = batch_chunk_count
self._chunk_size = self._batch_size // self._batch_chunk_count
self.drop_incomplete_batch = drop_incomplete_batch
self._parallel = parallel # immutable, otherwise would need to call
# the next function again
self._dataloader_multiproc = _setup_dataloader_multiproc_parameters(
parallel
)
self.parallel = parallel # the setter also configures dataloader multiprocessing
self.pin_memory = (
torch.cuda.is_available()
......@@ -385,6 +376,63 @@ class CachingDataModule(lightning.LightningDataModule):
# datasets that have been setup() for the current stage
self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}
@property
def parallel(self) -> int:
"""The parallel property."""
return self._parallel
@parallel.setter
def parallel(self, value: int) -> None:
self._parallel = value
self._dataloader_multiproc = _setup_dataloader_multiproc_parameters(
value
)
# datasets that have been setup() for the current stage
self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}
def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
"""Coherently sets the batch-chunk-size after validation.
Parameters
----------
batch_size
Number of samples in every **training** batch (this parameter affects
memory requirements for the network). If the number of samples in the
batch is larger than the total number of samples available for
training, this value is truncated. If this number is smaller, then
batches of the specified size are created and fed to the network until
there are no more new samples to feed (epoch is finished). If the
total number of training samples is not a multiple of the batch-size,
the last batch will be smaller than the first, unless
``drop_incomplete_batch`` is set to ``true``, in which case this batch
is not used.
batch_chunk_count
Number of chunks in every batch (this parameter affects memory
requirements for the network). The number of samples loaded for every
iteration will be ``batch_size/batch_chunk_count``. ``batch_size``
needs to be divisible by ``batch_chunk_count``, otherwise an error will
be raised. This parameter is used to reduce the number of samples loaded in
each iteration, in order to reduce memory usage in exchange for
processing time (more iterations). This is especially interesting when
one is running on GPUs with limited RAM. The default of 1 forces the
whole batch to be processed at once. Otherwise the batch is broken into
batch-chunk-count pieces, and gradients are accumulated to complete
each batch.
"""
# validation
if batch_size % batch_chunk_count != 0:
raise RuntimeError(
f"batch_size ({batch_size}) must be divisible by "
f"batch_chunk_size ({batch_chunk_count})."
)
self._batch_size = batch_size
self._batch_chunk_count = batch_chunk_count
self._chunk_size = self._batch_size // self._batch_chunk_count
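# Worked example (numbers chosen for illustration only): calling
# set_chunk_size(batch_size=16, batch_chunk_count=4) stores a chunk size of
# 4, so the dataloaders below yield 4 samples per iteration and gradients
# are accumulated over 4 iterations to complete each 16-sample batch.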
def setup(self, stage: str) -> None:
"""Sets up datasets for different tasks on the pipeline.
......@@ -440,11 +488,40 @@ class CachingDataModule(lightning.LightningDataModule):
elif stage == "predict":
_setup("test", self.model_transforms)
def train_dataloader(self):
def unaugmented_train_dataloader(self) -> torch.utils.data.DataLoader:
"""Returns a version of the train dataloader without augmentations.
Use this loader to compute input normalisation factors (e.g. mean and
standard deviation, or min-max parameterisations) on unaugmented training
data.
Returns
-------
dataloader
The unaugmented train dataloader
"""
dataset = _DelayedLoadingDataset(
self.database_split["train"],
self.raw_data_loader,
self.model_transforms,
)
return torch.utils.data.DataLoader(
dataset,
shuffle=False,
batch_size=self._chunk_size,
drop_last=self.drop_incomplete_batch,
pin_memory=self.pin_memory,
**self._dataloader_multiproc,
)
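# Possible use (matching the ``set_normalizer()`` hooks added to the models
# in this change), so that normalisation statistics are computed on
# unaugmented training data:
#
#   model.set_normalizer(datamodule.unaugmented_train_dataloader())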
def train_dataloader(self) -> torch.utils.data.DataLoader:
"""Returns the train data loader."""
return torch.utils.data.DataLoader(
self._datasets["train"],
shuffle=True,
batch_size=self._chunk_size,
drop_last=self.drop_incomplete_batch,
pin_memory=self.pin_memory,
......@@ -452,14 +529,19 @@ class CachingDataModule(lightning.LightningDataModule):
**self._dataloader_multiproc,
)
def val_dataloader(self):
"""Returns the validation data loader(s)"""
def available_dataset_keys(self) -> typing.KeysView[str]:
"""Returns all names for datasets that are setup."""
return self._datasets.keys()
extra_valid = [
def val_database_split_keys(self) -> list[str]:
"""Returns list of validation dataset names."""
return ["validation"] + [
k for k in self.database_split.keys() if k.startswith("monitor-")
]
# TODO: do we really need the train sampler here?
def val_dataloader(self) -> dict[str, torch.utils.data.DataLoader]:
"""Returns the validation data loader(s)"""
validation_loader_opts = {
"batch_size": self._chunk_size,
"shuffle": False,
......@@ -468,22 +550,13 @@ class CachingDataModule(lightning.LightningDataModule):
}
validation_loader_opts.update(self._dataloader_multiproc)
# TODO: not sure this is the right way to handle multiple validation
# loaders, please check and fix
if not extra_valid:
return torch.utils.data.DataLoader(
self._datasets["validation"],
**validation_loader_opts,
# select all keys of interest
return {
k: torch.utils.data.DataLoader(
self._datasets[k], **validation_loader_opts
)
else:
return [
torch.utils.data.DataLoader(
self._datasets[k],
**validation_loader_opts,
)
for k in ["validation"] + extra_valid
]
for k in self.val_database_split_keys()
}
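# Note: for a split containing, e.g., the subsets "train", "validation",
# "monitor-train" and "test", the dictionary above holds the keys
# "validation" and "monitor-train" (see val_database_split_keys()).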
def test_dataloader(self):
"""Returns the test data loader(s)"""
......
......@@ -2,13 +2,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import collections.abc
import csv
import importlib.abc
import json
import logging
import pathlib
import typing
import torch
import torch.utils.data
......@@ -16,249 +10,6 @@ import torch.utils.data
logger = logging.getLogger(__name__)
class JSONDatabaseSplit(
dict,
typing.Mapping[str, typing.Sequence[typing.Any]],
):
"""Defines a loader that understands a database split (train, test, etc) in
JSON format.
To create a new database split, you need to provide a JSON formatted
dictionary in a file, with contents similar to the following:
.. code-block:: json
{
"subset1": [
[
"sample1-data1",
"sample1-data2",
"sample1-data3",
],
[
"sample2-data1",
"sample2-data2",
"sample2-data3",
]
],
"subset2": [
[
"sample42-data1",
"sample42-data2",
"sample42-data3",
],
]
}
Your database split may contain any number of subsets (dictionary keys).
For simplicity, we recommend all sample entries are formatted similarly so
that raw-data loading is simplified. Use the function
:py:func:`check_database_split_loading` to test raw data loading and
fine-tune the dataset split, or its loading.
Objects of this class behave like a dictionary in which keys are subset
names in the split, and values represent samples data and meta-data.
Parameters
----------
path
Absolute path to a JSON formatted file containing the database split to be
recognized by this object.
"""
def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable):
if isinstance(path, str):
path = pathlib.Path(path)
self.path = path
self.subsets = self._load_split_from_disk()
def _load_split_from_disk(self) -> dict[str, typing.Sequence[typing.Any]]:
"""Loads all subsets in a split from its file system representation.
This method will load JSON information for the current split and return
all subsets of the given split after converting each entry through the
loader function.
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of JSON objects
"""
if str(self.path).endswith(".bz2"):
logger.debug(f"Loading database split from {str(self.path)}...")
with __import__("bz2").open(self.path) as f:
return json.load(f)
else:
with self.path.open() as f:
return json.load(f)
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split."""
return self.subsets[key]
def __iter__(self):
"""Iterates over the subsets."""
return iter(self.subsets)
def __len__(self) -> int:
"""How many subsets we currently have."""
return len(self.subsets)
class CSVDatabaseSplit(collections.abc.Mapping):
"""Defines a loader that understands a database split (train, test, etc) in
CSV format.
To create a new database split, you need to provide one or more CSV
formatted files, each representing a subset of this split, containing the
sample data (one per row). Example:
Inside the directory ``my-split/``, one can find the files ``train.csv``,
``validation.csv``, and ``test.csv``. Each file has a structure similar to
the following:
.. code-block:: text
sample1-value1,sample1-value2,sample1-value3
sample2-value1,sample2-value2,sample2-value3
...
Each file in the provided directory defines the subset name on the split.
So, the file ``train.csv`` will contain the data from the ``train`` subset,
and so on.
Objects of this class behave like a dictionary in which keys are subset
names in the split, and values represent samples data and meta-data.
Parameters
----------
directory
Absolute path to a directory containing the database split laid out
as a set of CSV files, one per subset.
"""
def __init__(
self, directory: pathlib.Path | str | importlib.abc.Traversable
):
if isinstance(directory, str):
directory = pathlib.Path(directory)
assert (
directory.is_dir()
), f"`{str(directory)}` is not a valid directory"
self.directory = directory
self.subsets = self._load_split_from_disk()
def _load_split_from_disk(
self,
) -> dict[str, list[typing.Any]]:
"""Loads all subsets in a split from its file system representation.
This method will load CSV information for the current split and return all
subsets of the given split after converting each entry through the
loader function.
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of CSV rows
"""
retval = {}
for subset in self.directory.iterdir():
if str(subset).endswith(".csv.bz2"):
logger.debug(f"Loading database split from {subset}...")
with __import__("bz2").open(subset) as f:
reader = csv.reader(f)
retval[subset.name[: -len(".csv.bz2")]] = [
k for k in reader
]
elif str(subset).endswith(".csv"):
with subset.open() as f:
reader = csv.reader(f)
retval[subset.name[: -len(".csv")]] = [k for k in reader]
else:
logger.debug(
f"Ignoring file {subset} in CSVDatabaseSplit readout"
)
return retval
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split."""
return self.subsets[key]
def __iter__(self):
"""Iterates over the subsets."""
return iter(self.subsets)
def __len__(self) -> int:
"""How many subsets we currently have."""
return len(self.subsets)
def check_database_split_loading(
database_split: typing.Mapping[str, typing.Sequence[typing.Any]],
loader: typing.Callable[[typing.Any], torch.Tensor],
limit: int = 0,
) -> int:
"""For each subset in the split, check if all data can be correctly loaded
using the provided loader function.
This function will return the number of errors loading samples, and will
log more detailed information to the logging stream.
Parameters
----------
database_split
A mapping that contains the database split. Each key represents the
name of a subset in the split. Each value is a sequence of (potentially
complex) objects, each representing a single sample.
loader
A callable that transforms sample entries in the database split
into :py:class:`torch.Tensor` objects that can be used for training
or inference.
limit
Maximum number of samples to check (in each split/subset
combination) in this dataset. If set to zero, then check
everything.
Returns
-------
errors
Number of errors found
"""
logger.info(
"Checking if can load all samples in all subsets of this split..."
)
errors = 0
for subset in database_split.keys():
samples = database_split[subset] if not limit else database_split[subset][:limit]
for pos, sample in enumerate(samples):
try:
data = loader(sample)
assert isinstance(data, torch.Tensor)
except Exception as e:
logger.info(
f"Found error loading entry {pos} in subset `{subset}`: {e}"
)
errors += 1
return errors
def get_positive_weights(dataset):
"""Compute the positive weights of each class of the dataset to balance the
BCEWithLogitsLoss criterion.
......@@ -350,45 +101,3 @@ def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
)
else:
logger.warning("Weighted valid criterion not supported")
def normalize_data(normalization, model, datamodule):
from torch.utils.data import DataLoader
datamodule.prepare_data()
datamodule.setup(stage="fit")
train_dataset = datamodule.train_dataset
# Create z-normalization model layer if needed
if normalization == "imagenet":
model.normalizer.set_mean_std(
[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
)
logger.info("Z-normalization with ImageNet mean and std")
elif normalization == "current":
# Compute mean/std of current train subset
temp_dl = DataLoader(
dataset=train_dataset, batch_size=len(train_dataset)
)
data = next(iter(temp_dl))
mean = data[1].mean(dim=[0, 2, 3])
std = data[1].std(dim=[0, 2, 3])
model.normalizer.set_mean_std(mean, std)
# Format mean and std for logging
mean = str(
[
round(x, 3)
for x in ((mean * 10**3).round() / (10**3)).tolist()
]
)
std = str(
[
round(x, 3)
for x in ((std * 10**3).round() / (10**3)).tolist()
]
)
logger.info(f"Z-normalization with mean {mean} and std {std}")
......@@ -264,7 +264,7 @@ json_dataset = JSONDataset(
def _maker(protocol, resize_size=512, cc_size=512, RGB=True):
import torchvision.transforms as transforms
from ..transforms import SingleAutoLevel16to8
from ..loader import SingleAutoLevel16to8
post_transforms = []
if not RGB:
......
......@@ -5,9 +5,42 @@
"""Data loading code."""
import numpy
import PIL.Image
class SingleAutoLevel16to8:
"""Converts a 16-bit image to 8-bit representation using "auto-level".
This transform assumes that the input image is gray-scaled.
To auto-level, we calculate the maximum and the minimum of the image, and
map that range onto the [0, 255] range of the destination image.
"""
def __call__(self, img):
imin, imax = img.getextrema()
irange = imax - imin
return PIL.Image.fromarray(
numpy.round(
255.0 * (numpy.array(img).astype(float) - imin) / irange
).astype("uint8"),
).convert("L")
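# Usage sketch (``img16`` is a hypothetical 16-bit gray-scale PIL image,
# e.g. mode "I;16"):
#
#   img8 = SingleAutoLevel16to8()(img16)  # 8-bit "L" image, values in [0, 255]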
class RemoveBlackBorders:
"""Remove black borders of CXR."""
def __init__(self, threshold=0):
self.threshold = threshold
def __call__(self, img):
img = numpy.asarray(img)
mask = numpy.asarray(img) > self.threshold
return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
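# Usage sketch (a hypothetical PIL-level pipeline; the 512-pixel resize
# mirrors the ``resize_size`` default used elsewhere in this change):
#
#   import torchvision.transforms
#   pipeline = torchvision.transforms.Compose(
#       [RemoveBlackBorders(), torchvision.transforms.Resize(512)]
#   )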
def load_pil(path):
"""Loads a sample data.
......
......@@ -15,13 +15,15 @@ This configuration:
import importlib.resources
from ..datamodule import CachingDataModule
from ..dataset import JSONDatabaseSplit
from ..split import JSONDatabaseSplit
from ..transforms import ElasticDeformation
from .loader import raw_data_loader
from .raw_data_loader import raw_data_loader
datamodule = CachingDataModule(
database_split=JSONDatabaseSplit(
importlib.resources.files(__name__).joinpath("default.json.bz2")
importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
"default.json.bz2"
)
),
raw_data_loader=raw_data_loader,
cache_samples=False,
......
......@@ -27,8 +27,7 @@ import torch.nn
import torchvision.transforms
from ...utils.rc import load_rc
from ..loader import load_pil_baw
from ..transforms import RemoveBlackBorders
from ..raw_data_loader import RemoveBlackBorders, load_pil_baw
_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
"""This variable contains the base directory where the database raw data is
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import collections.abc
import csv
import importlib.abc
import json
import logging
import pathlib
import typing
import torch
logger = logging.getLogger(__name__)
class JSONDatabaseSplit(
dict,
typing.Mapping[str, typing.Sequence[typing.Any]],
):
"""Defines a loader that understands a database split (train, test, etc) in
JSON format.
To create a new database split, you need to provide a JSON formatted
dictionary in a file, with contents similar to the following:
.. code-block:: json
{
"subset1": [
[
"sample1-data1",
"sample1-data2",
"sample1-data3",
],
[
"sample2-data1",
"sample2-data2",
"sample2-data3",
]
],
"subset2": [
[
"sample42-data1",
"sample42-data2",
"sample42-data3",
],
]
}
Your database split may contain any number of subsets (dictionary keys).
For simplicity, we recommend all sample entries are formatted similarly so
that raw-data loading is simplified. Use the function
:py:func:`check_database_split_loading` to test raw data loading and
fine-tune the dataset split, or its loading.
Objects of this class behave like a dictionary in which keys are subset
names in the split, and values represent samples data and meta-data.
Parameters
----------
path
Absolute path to a JSON formatted file containing the database split to be
recognized by this object.
"""
def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable):
if isinstance(path, str):
path = pathlib.Path(path)
self.path = path
self.subsets = self._load_split_from_disk()
def _load_split_from_disk(self) -> dict[str, typing.Sequence[typing.Any]]:
"""Loads all subsets in a split from its file system representation.
This method will load JSON information for the current split and return
all subsets of the given split after converting each entry through the
loader function.
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of JSON objects
"""
if str(self.path).endswith(".bz2"):
logger.debug(f"Loading database split from {str(self.path)}...")
with __import__("bz2").open(self.path) as f:
return json.load(f)
else:
with self.path.open() as f:
return json.load(f)
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split."""
return self.subsets[key]
def __iter__(self):
"""Iterates over the subsets."""
return iter(self.subsets)
def __len__(self) -> int:
"""How many subsets we currently have."""
return len(self.subsets)
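# Usage sketch (assuming a JSON file laid out like the example in the
# docstring above):
#
#   split = JSONDatabaseSplit("my-split.json")
#   for sample in split["subset1"]:
#       ...  # each ``sample`` is one entry of the subset, e.g. a list of data items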
class CSVDatabaseSplit(collections.abc.Mapping):
"""Defines a loader that understands a database split (train, test, etc) in
CSV format.
To create a new database split, you need to provide one or more CSV
formatted files, each representing a subset of this split, containing the
sample data (one per row). Example:
Inside the directory ``my-split/``, one can find the files ``train.csv``,
``validation.csv``, and ``test.csv``. Each file has a structure similar to
the following:
.. code-block:: text
sample1-value1,sample1-value2,sample1-value3
sample2-value1,sample2-value2,sample2-value3
...
Each file in the provided directory defines the subset name on the split.
So, the file ``train.csv`` will contain the data from the ``train`` subset,
and so on.
Objects of this class behave like a dictionary in which keys are subset
names in the split, and values represent samples data and meta-data.
Parameters
----------
directory
Absolute path to a directory containing the database split laid out
as a set of CSV files, one per subset.
"""
def __init__(
self, directory: pathlib.Path | str | importlib.abc.Traversable
):
if isinstance(directory, str):
directory = pathlib.Path(directory)
assert (
directory.is_dir()
), f"`{str(directory)}` is not a valid directory"
self.directory = directory
self.subsets = self._load_split_from_disk()
def _load_split_from_disk(
self,
) -> dict[str, list[typing.Any]]:
"""Loads all subsets in a split from its file system representation.
This method will load CSV information for the current split and return all
subsets of the given split after converting each entry through the
loader function.
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of CSV rows
"""
retval = {}
for subset in self.directory.iterdir():
if str(subset).endswith(".csv.bz2"):
logger.debug(f"Loading database split from {subset}...")
with __import__("bz2").open(subset) as f:
reader = csv.reader(f)
retval[subset.name[: -len(".csv.bz2")]] = [
k for k in reader
]
elif str(subset).endswith(".csv"):
with subset.open() as f:
reader = csv.reader(f)
retval[subset.name[: -len(".csv")]] = [k for k in reader]
else:
logger.debug(
f"Ignoring file {subset} in CSVDatabaseSplit readout"
)
return retval
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split."""
return self.subsets[key]
def __iter__(self):
"""Iterates over the subsets."""
return iter(self.subsets)
def __len__(self) -> int:
"""How many subsets we currently have."""
return len(self.subsets)
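# Usage sketch (assuming the ``my-split/`` directory layout described in the
# docstring above):
#
#   split = CSVDatabaseSplit("my-split")
#   train_rows = split["train"]  # list of rows, each a list of string values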
def check_database_split_loading(
database_split: typing.Mapping[str, typing.Sequence[typing.Any]],
loader: typing.Callable[[typing.Any], torch.Tensor],
limit: int = 0,
) -> int:
"""For each subset in the split, check if all data can be correctly loaded
using the provided loader function.
This function will return the number of errors loading samples, and will
log more detailed information to the logging stream.
Parameters
----------
database_split
A mapping that contains the database split. Each key represents the
name of a subset in the split. Each value is a sequence of (potentially
complex) objects, each representing a single sample.
loader
A callable that transforms sample entries in the database split
into :py:class:`torch.Tensor` objects that can be used for training
or inference.
limit
Maximum number of samples to check (in each split/subset
combination) in this dataset. If set to zero, then check
everything.
Returns
-------
errors
Number of errors found
"""
logger.info(
"Checking if can load all samples in all subsets of this split..."
)
errors = 0
for subset in database_split.keys():
samples = database_split[subset] if not limit else database_split[subset][:limit]
for pos, sample in enumerate(samples):
try:
data = loader(sample)
assert isinstance(data, torch.Tensor)
except Exception as e:
logger.info(
f"Found error loading entry {pos} in subset `{subset}`: {e}"
)
errors += 1
return errors
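# Usage sketch (``load_sample`` is a hypothetical callable converting a
# single split entry into a torch.Tensor):
#
#   errors = check_database_split_loading(split, load_sample, limit=10)
#   assert errors == 0, f"{errors} samples failed to load"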
......@@ -22,44 +22,9 @@ from scipy.ndimage import gaussian_filter, map_coordinates
from torchvision import transforms
class SingleAutoLevel16to8:
"""Converts a 16-bit image to 8-bit representation using "auto-level".
This transform assumes that the input image is gray-scaled.
To auto-level, we calculate the maximum and the minimum of the image, and
map that range onto the [0, 255] range of the destination image.
"""
def __call__(self, img):
imin, imax = img.getextrema()
irange = imax - imin
return PIL.Image.fromarray(
numpy.round(
255.0 * (numpy.array(img).astype(float) - imin) / irange
).astype("uint8"),
).convert("L")
class RemoveBlackBorders:
"""Remove black borders of CXR."""
def __init__(self, threshold=0):
self.threshold = threshold
def __call__(self, img):
img = numpy.asarray(img)
mask = numpy.asarray(img) > self.threshold
return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
class ElasticDeformation:
"""Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_.
TODO: needs to be converted into a torch.nn.Module to become scriptable!
Source: https://gist.github.com/oeway/2e3b989e0343f0884388ed7ed82eb3b0
"""
......@@ -70,7 +35,7 @@ class ElasticDeformation:
spline_order=1,
mode="nearest",
random_state=numpy.random,
p=1,
p=1.0,
):
self.alpha = alpha
self.sigma = sigma
......
......@@ -48,6 +48,21 @@ class Densenet(pl.LightningModule):
return x
def set_normalizer(self, dataloader):
"""TODO: Write this function to set the Normalizer
This function is a no-op if ``pretrained = True`` (the normalizer is set
to ImageNet statistics during construction).
"""
if self.pretrained:
from .normalizer import TorchVisionNormalizer
self.normalizer = TorchVisionNormalizer(..., ...)
else:
from .normalizer import get_znorm_normalizer
self.normalizer = get_znorm_normalizer(dataloader)
def training_step(self, batch, batch_idx):
images = batch[1]
labels = batch[2]
......
......@@ -6,6 +6,7 @@
import torch
import torch.nn
import torch.utils.data
class TorchVisionNormalizer(torch.nn.Module):
......@@ -20,19 +21,33 @@ class TorchVisionNormalizer(torch.nn.Module):
Number of images channels fed to the model
"""
def __init__(self, nb_channels=3):
def __init__(self, subtract: torch.Tensor, divide: torch.Tensor):
super().__init__()
mean = torch.zeros(nb_channels)[None, :, None, None]
std = torch.ones(nb_channels)[None, :, None, None]
self.register_buffer("mean", mean)
self.register_buffer("std", std)
assert len(subtract) == len(divide), "subtract and divide must have the same length"
assert len(subtract) in (1, 3), "only 1 or 3 channels are supported"
subtract = torch.as_tensor(subtract)[None, :, None, None]
divide = torch.as_tensor(divide)[None, :, None, None]
self.register_buffer("subtract", subtract)
self.register_buffer("divide", divide)
self.name = "torchvision-normalizer"
def set_mean_std(self, mean, std):
mean = torch.as_tensor(mean)[None, :, None, None]
std = torch.as_tensor(std)[None, :, None, None]
self.register_buffer("mean", mean)
self.register_buffer("std", std)
def forward(self, inputs: torch.Tensor):
"""inputs shape [batches, planes, height, width]"""
return inputs.sub(self.subtract).div(self.divide)
def forward(self, inputs):
return inputs.sub(self.mean).div(self.std)
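# Usage sketch: the ImageNet statistics quoted elsewhere in this change can
# be wrapped like this:
#
#   normalizer = TorchVisionNormalizer(
#       subtract=torch.tensor([0.485, 0.456, 0.406]),
#       divide=torch.tensor([0.229, 0.224, 0.225]),
#   )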
def get_znorm_normalizer(
dataloader: torch.utils.data.DataLoader,
) -> TorchVisionNormalizer:
# TODO: Fix this function to use unaugmented training set
# TODO: This function is only applicable IFF we are not fine-tuning (ie.
# model does not re-use weights from imagenet training!)
# TODO: Add type hints
# TODO: Add documentation
# 1 extract mean/std from dataloader
# 2 return TorchVisionNormalizer(mean, std)
pass
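# A minimal sketch of what the TODOs above describe (assuming batches are
# mappings carrying a "data" tensor shaped [batch, channels, height, width],
# as in the training steps below):
#
#   def get_znorm_normalizer_sketch(dataloader):
#       sums = sq_sums = 0.0
#       count = 0
#       for batch in dataloader:
#           data = batch["data"]
#           sums = sums + data.sum(dim=[0, 2, 3])
#           sq_sums = sq_sums + data.pow(2).sum(dim=[0, 2, 3])
#           count += data.numel() // data.shape[1]
#       mean = sums / count
#       std = (sq_sums / count - mean**2).sqrt()
#       return TorchVisionNormalizer(mean, std)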
......@@ -6,6 +6,7 @@ import lightning.pytorch as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from .normalizer import TorchVisionNormalizer
......@@ -127,6 +128,12 @@ class PASA(pl.LightningModule):
return x
def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
"""TODO: Write this function documentation"""
from .normalizer import get_znorm_normalizer
self.normalizer = get_znorm_normalizer(dataloader)
def training_step(self, batch, batch_idx):
images = batch["data"]
labels = batch["label"]
......