Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • medai/software/mednet
1 result
Show changes
Commits on Source (20)
Showing
with 1192 additions and 1067 deletions
......@@ -15,33 +15,33 @@ dynamic = ["readme"]
license = { text = "GNU General Public License v3 (GPLv3)" }
authors = [{ name = "Geoffrey Raposo", email = "geoffrey@raposo.ch" }]
maintainers = [
{ name = "Andre Anjos", email = "andre.anjos@idiap.ch" },
{ name = "Daniel Carron", email = "daniel.carron@idiap.ch" },
{ name = "Andre Anjos", email = "andre.anjos@idiap.ch" },
{ name = "Daniel Carron", email = "daniel.carron@idiap.ch" },
]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Topic :: Software Development :: Libraries :: Python Modules",
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Topic :: Software Development :: Libraries :: Python Modules",
]
dependencies = [
"clapper",
"click",
"numpy",
"pandas",
"scipy",
"scikit-learn",
"tqdm",
"psutil",
"tabulate",
"matplotlib",
"pillow",
"torch>=1.8",
"torchvision>=0.10",
"lightning>=2.0.3",
"tensorboard",
"clapper",
"click",
"numpy",
"pandas",
"scipy",
"scikit-learn",
"tqdm",
"psutil",
"tabulate",
"matplotlib",
"pillow",
"torch>=1.8",
"torchvision>=0.10",
"lightning>=2.0.3",
"tensorboard",
]
[project.urls]
......@@ -53,13 +53,13 @@ changelog = "https://gitlab.idiap.ch/biosignal/software/ptbench/-/releases"
[project.optional-dependencies]
qa = ["pre-commit"]
doc = [
"sphinx",
"furo",
"sphinx-autodoc-typehints",
"auto-intersphinx",
"sphinx-copybutton",
"sphinx-inline-tabs",
"sphinx-click",
"sphinx",
"furo",
"sphinx-autodoc-typehints",
"auto-intersphinx",
"sphinx-copybutton",
"sphinx-inline-tabs",
"sphinx-click",
]
test = ["pytest", "pytest-cov", "coverage"]
......
......@@ -6,19 +6,30 @@
from torch import empty
from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD
from ...models.alexnet import Alexnet
# config
# optimizer
optimizer = SGD
optimizer_configs = {"lr": 0.01, "momentum": 0.1}
# optimizer
optimizer = "SGD"
# criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model
model = Alexnet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False
criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=False,
augmentation_transforms=augmentation_transforms,
)
......@@ -6,19 +6,30 @@
from torch import empty
from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD
from ...models.alexnet import Alexnet
# config
optimizer_configs = {"lr": 0.001, "momentum": 0.1}
# optimizer
optimizer = "SGD"
optimizer = SGD
optimizer_configs = {"lr": 0.01, "momentum": 0.1}
# criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model
model = Alexnet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True
criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=True,
augmentation_transforms=augmentation_transforms,
)
......@@ -6,20 +6,30 @@
from torch import empty
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from ...models.densenet import Densenet
# config
optimizer_configs = {"lr": 0.0001}
# optimizer
optimizer = "Adam"
optimizer = Adam
optimizer_configs = {"lr": 0.0001}
# criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model
model = Densenet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False
criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=False,
augmentation_transforms=augmentation_transforms,
)
......@@ -6,20 +6,30 @@
from torch import empty
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from ...models.densenet import Densenet
# config
optimizer_configs = {"lr": 0.01}
# optimizer
optimizer = "Adam"
optimizer = Adam
optimizer_configs = {"lr": 0.0001}
# criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model
model = Densenet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True
criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=True,
augmentation_transforms=augmentation_transforms,
)
......@@ -13,18 +13,30 @@ Reference: [PASA-2019]_
from torch import empty
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from ...models.pasa import PASA
# config
optimizer_configs = {"lr": 8e-5}
# optimizer
optimizer = "Adam"
optimizer = Adam
optimizer_configs = {"lr": 8e-5}
# criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [ElasticDeformation(p=0.8)]
# from torchvision.transforms.v2 import ElasticTransform, InterpolationMode
# augmentation_transforms = [ElasticTransform(alpha=1000.0, sigma=30.0, interpolation=InterpolationMode.NEAREST)]
# model
model = PASA(criterion, criterion_valid, optimizer, optimizer_configs)
model = PASA(
criterion,
criterion_valid,
optimizer,
optimizer_configs,
augmentation_transforms=augmentation_transforms,
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import multiprocessing
import sys
import lightning as pl
import torch
from clapper.logging import setup
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from .dataset import get_samples_weights
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class BaseDataModule(pl.LightningDataModule):
    """Common data-loader scaffolding shared by concrete datamodules.

    Subclasses implement :py:meth:`setup` to populate ``train_dataset``,
    ``validation_dataset`` and, optionally, ``extra_validation_datasets``.
    This class then wraps those datasets into ``DataLoader`` objects.

    Parameters
    ----------
    batch_size : int
        Number of samples in a full batch.
    batch_chunk_count : int
        Number of chunks a batch is split into.  Every data loader is built
        with ``batch_size // batch_chunk_count`` samples per step, and
        ``batch_size`` must be divisible by this value.
    drop_incomplete_batch : bool
        Forwarded as ``drop_last`` to the training data loader only.
    parallel : int
        Negative values disable worker processes (``num_workers=0``); zero
        uses ``multiprocessing.cpu_count()`` workers; a positive value uses
        that many workers.
    """

    def __init__(
        self,
        batch_size=1,
        batch_chunk_count=1,
        drop_incomplete_batch=False,
        parallel=-1,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.batch_chunk_count = batch_chunk_count

        # Populated by ``setup()`` in subclasses.
        self.train_dataset = None
        self.validation_dataset = None
        self.extra_validation_datasets = None

        self._raw_data_transforms = []
        self._model_transforms = []
        self._augmentation_transforms = []

        self.drop_incomplete_batch = drop_incomplete_batch

        # should only be true if GPU available and using it
        self.pin_memory = torch.cuda.is_available()

        self.parallel = parallel

    def setup(self, stage: str):
        """Prepare the datasets for ``stage``; must be overridden.

        Implementations must define ``self.train_dataset``,
        ``self.validation_dataset``, ``self.extra_validation_datasets`` and
        ``self.predict_dataset``.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError

    def train_dataloader(self):
        """Return the training loader.

        Samples are drawn with replacement by a ``WeightedRandomSampler``
        using the weights computed by ``get_samples_weights`` on the
        training dataset; one epoch draws as many samples as there are
        weights.
        """
        train_samples_weights = get_samples_weights(self.train_dataset)
        multiproc_kwargs = self._setup_multiproc(self.parallel)

        train_sampler = WeightedRandomSampler(
            train_samples_weights, len(train_samples_weights), replacement=True
        )

        return DataLoader(
            self.train_dataset,
            batch_size=self._compute_chunk_size(
                self.batch_size, self.batch_chunk_count
            ),
            drop_last=self.drop_incomplete_batch,
            pin_memory=self.pin_memory,
            sampler=train_sampler,
            **multiproc_kwargs,
        )

    def val_dataloader(self):
        """Return a dictionary of validation loaders.

        The main loader is stored under ``"validation_loader"``.  Each extra
        validation dataset, if any, is stored under
        ``"extra_validation_loader<idx>"``.  None of them shuffles or drops
        incomplete batches.
        """
        multiproc_kwargs = self._setup_multiproc(self.parallel)

        def _sequential_loader(dataset):
            # All validation loaders share the exact same configuration.
            return DataLoader(
                dataset=dataset,
                batch_size=self._compute_chunk_size(
                    self.batch_size, self.batch_chunk_count
                ),
                shuffle=False,
                drop_last=False,
                pin_memory=self.pin_memory,
                **multiproc_kwargs,
            )

        loaders_dict = {
            "validation_loader": _sequential_loader(self.validation_dataset)
        }

        if self.extra_validation_datasets is not None:
            for set_idx, extra_set in enumerate(self.extra_validation_datasets):
                loaders_dict[
                    f"extra_validation_loader{set_idx}"
                ] = _sequential_loader(extra_set)

        return loaders_dict

    def predict_dataloader(self):
        """Return the training loader plus all validation loaders.

        The training loader is stored under ``"train_loader"``; validation
        loaders keep the keys assigned by :py:meth:`val_dataloader`.
        """
        # NOTE(review): "train_loader" is the weighted, with-replacement
        # training loader, so predictions over it are randomly resampled --
        # confirm this is intended.
        loaders_dict = {"train_loader": self.train_dataloader()}
        loaders_dict.update(self.val_dataloader())
        return loaders_dict

    def update_module_properties(self, **kwargs):
        """Set every keyword argument as an attribute on this instance."""
        for k, v in kwargs.items():
            setattr(self, k, v)

    def _compute_chunk_size(self, batch_size, chunk_count):
        """Return the number of samples per chunk of a batch.

        Raises
        ------
        RuntimeError
            If ``batch_size`` is not divisible by ``chunk_count``.
        """
        if batch_size % chunk_count != 0:
            # The option is named after ``batch_chunk_count``; the message
            # previously referred to a non-matching "--batch-chunk-size".
            raise RuntimeError(
                f"--batch-size ({batch_size}) must be divisible by "
                f"--batch-chunk-count ({chunk_count})."
            )

        return batch_size // chunk_count

    def _setup_multiproc(self, parallel):
        """Translate ``parallel`` into ``DataLoader`` keyword arguments.

        Returns
        -------
        dict
            Always contains ``num_workers``.  On macOS (``darwin``), when
            workers are used, also contains a ``"spawn"``
            ``multiprocessing_context``.
        """
        multiproc_kwargs = dict()

        if parallel < 0:
            multiproc_kwargs["num_workers"] = 0
        else:
            # ``parallel == 0`` falls back to the number of CPU cores.
            multiproc_kwargs["num_workers"] = (
                parallel or multiprocessing.cpu_count()
            )

        if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
            multiproc_kwargs[
                "multiprocessing_context"
            ] = multiprocessing.get_context("spawn")

        return multiproc_kwargs

    def _build_transforms(self, is_train):
        """Compose raw-data, model and (for training) augmentation transforms.

        The lists are concatenated into a new list, so the stored transform
        lists are left untouched.
        """
        all_transforms = self.raw_data_transforms + self.model_transforms

        if is_train:
            all_transforms = all_transforms + self.augmentation_transforms

        return transforms.Compose(all_transforms)

    @property
    def raw_data_transforms(self):
        """List of transforms applied first, on the raw data."""
        return self._raw_data_transforms

    @raw_data_transforms.setter
    def raw_data_transforms(self, transforms):
        self._raw_data_transforms = transforms

    @property
    def model_transforms(self):
        """List of transforms applied after the raw-data transforms."""
        return self._model_transforms

    @model_transforms.setter
    def model_transforms(self, transforms):
        self._model_transforms = transforms

    @property
    def augmentation_transforms(self):
        """List of transforms appended only when building for training."""
        return self._augmentation_transforms

    @augmentation_transforms.setter
    def augmentation_transforms(self, transforms):
        self._augmentation_transforms = transforms
def get_dataset_from_module(module, stage, **module_args):
    """Build a datamodule, set it up and hand back its ``dataset``.

    The callable ``module`` is invoked with ``module_args``, then
    ``prepare_data()`` and ``setup(stage=stage)`` are called on the result
    before its ``dataset`` attribute is read.  Useful when combining
    multiple datasets.
    """
    instance = module(**module_args)
    instance.prepare_data()
    instance.setup(stage=stage)
    return instance.dataset
This diff is collapsed.
......@@ -2,460 +2,19 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import csv
import json
import logging
import os
import pathlib
import torch
from tqdm import tqdm
import torch.utils.data
logger = logging.getLogger(__name__)
class JSONProtocol:
    """Multi-protocol, multi-subset sample lists backed by JSON files.

    Each protocol is described by one JSON file that maps subset names to
    lists of entries, every entry being a list of field values:

    .. code-block:: json

       {
           "subset1": [
               ["value1", "value2", "value3"],
               ["value4", "value5", "value6"]
           ],
           "subset2": [
           ]
       }

    A protocol may contain any number of subsets, but all entries are
    expected to have the same number of fields.


    Parameters
    ----------

    protocols : list, dict
        Either an iterable of paths to JSON files (one per protocol), or a
        dictionary mapping protocol names to paths (or opened file objects)
        of JSON files.  With an iterable, each protocol is named after the
        basename of its path with all file suffixes removed.  Paths ending
        in ``.bz2`` are read through :py:mod:`bz2`.

    fieldnames : list, tuple
        Names assigned, in order, to the values of each entry in the JSON
        file.
    """

    def __init__(self, protocols, fieldnames):
        if isinstance(protocols, dict):
            self._protocols = protocols
        else:
            named = {}
            for path in protocols:
                # e.g. "dir/default.json.bz2" -> "default"
                all_suffixes = "".join(pathlib.Path(path).suffixes)
                stem = os.path.basename(str(path).replace(all_suffixes, ""))
                named[stem] = path
            self._protocols = named

        self.fieldnames = fieldnames

    def check(self, limit=0):
        """Try to access every sample of every protocol and count failures.

        This method assumes each sample exposes ``data`` and ``key``
        attributes.

        Parameters
        ----------

        limit : int
            Maximum number of samples to check in each protocol/subset
            combination.  Zero means "check everything".


        Returns
        -------

        errors : int
            Number of samples whose access raised an exception.
        """
        # NOTE(review): subsets() returns plain dictionaries, on which the
        # attribute access below raises; every sample would then be counted
        # as an error -- confirm whether this method is still in use.
        logger.info("Checking dataset...")
        errors = 0

        for proto in self._protocols:
            logger.info(f"Checking protocol '{proto}'...")

            for name, samples in self.subsets(proto).items():
                logger.info(f"Checking subset '{name}'...")

                if limit:
                    logger.info(f"Checking at most first '{limit}' samples...")
                    samples = samples[:limit]

                for pos, sample in enumerate(samples):
                    try:
                        sample.data  # may trigger data loading
                        logger.info(f"{sample.key}: OK")
                    except Exception as e:
                        logger.error(
                            f"Found error loading entry {pos} in subset {name} "
                            f"of protocol {proto} from file "
                            f"'{self._protocols[proto]}': {e}"
                        )
                        errors += 1

        return errors

    def subsets(self, protocol):
        """Load one protocol and return its subsets.

        Parameters
        ----------

        protocol : str
            Name of the protocol to load.


        Returns
        -------

        subsets : dict
            Maps each subset name to a list of dictionaries, one per entry,
            pairing ``fieldnames`` with the entry's values.
        """
        source = self._protocols[protocol]

        if isinstance(source, (str, bytes, pathlib.Path)):
            if str(source).endswith(".bz2"):
                import bz2

                opener = bz2.open
            else:
                opener = open

            with opener(source) as handle:
                data = json.load(handle)
        else:
            # Already-opened file object: rewind so it can be read again.
            data = json.load(source)
            source.seek(0)

        retval = {}
        for subset, samples in data.items():
            logger.info(f"Loading subset {subset} samples.")
            retval[subset] = [
                dict(zip(self.fieldnames, entry)) for entry in tqdm(samples)
            ]

        return retval
class CSVDataset:
    """Multi-subset sample lists backed by CSV files.

    Each subset is described by one CSV file with one sample per row:

    .. code-block:: text

       value1,value2,value3
       value4,value5,value6
       ...

    All rows are expected to have the same number of entries.


    Parameters
    ----------

    subsets : list, dict
        Either an iterable of paths to CSV files (one per subset), or a
        dictionary mapping subset names to paths (or opened file objects)
        of CSV files.  With an iterable, each subset is named after the
        basename of its path with all file suffixes removed.

    fieldnames : list, tuple
        Names assigned, in order, to the columns of each row.

    loader : object
        A callable receiving a context dictionary (with ``subset`` and
        ``order`` keys) and a ``{fieldname: value}`` dictionary for one
        row; whatever it returns becomes the sample.
    """

    def __init__(self, subsets, fieldnames, loader):
        if isinstance(subsets, dict):
            self._subsets = subsets
        else:
            named = {}
            for path in subsets:
                # e.g. "dir/train.csv" -> "train"
                all_suffixes = "".join(pathlib.Path(path).suffixes)
                stem = os.path.basename(str(path).replace(all_suffixes, ""))
                named[stem] = path
            self._subsets = named

        self.fieldnames = fieldnames
        self._loader = loader

    def check(self, limit=0):
        """Try to access every sample of every subset and count failures.

        This method assumes each sample exposes ``data`` and ``key``
        attributes.

        Parameters
        ----------

        limit : int
            Maximum number of samples to check in each subset.  Zero means
            "check everything".


        Returns
        -------

        errors : int
            Number of samples whose access raised an exception.
        """
        logger.info("Checking dataset...")
        errors = 0

        for name in self._subsets.keys():
            logger.info(f"Checking subset '{name}'...")
            samples = self.samples(name)

            if limit:
                logger.info(f"Checking at most first '{limit}' samples...")
                samples = samples[:limit]

            for pos, sample in enumerate(samples):
                try:
                    sample.data  # may trigger data loading
                    logger.info(f"{sample.key}: OK")
                except Exception as e:
                    logger.error(
                        f"Found error loading entry {pos} in subset {name} "
                        f"from file '{self._subsets[name]}': {e}"
                    )
                    errors += 1

        return errors

    def subsets(self):
        """Load and return every subset at once.

        Returns
        -------

        subsets : dict
            Maps each subset name to the list returned by
            :py:meth:`samples` for it.
        """
        loaded = {}
        for name in self._subsets.keys():
            loaded[name] = self.samples(name)
        return loaded

    def samples(self, subset):
        """Load one subset and pass each of its rows through the loader.

        Parameters
        ----------

        subset : str
            Name of the subset to load.


        Returns
        -------

        subset : list
            One loader result per CSV row, in file order.
        """
        source = self._subsets[subset]

        if isinstance(source, (str, bytes, pathlib.Path)):
            with open(source, newline="") as handle:
                rows = list(csv.reader(handle))
        else:
            # Already-opened file object: rewind so it can be read again.
            rows = list(csv.reader(source))
            source.seek(0)

        loaded = []
        for n, row in enumerate(rows):
            context = dict(subset=subset, order=n)
            loaded.append(self._loader(context, dict(zip(self.fieldnames, row))))
        return loaded
class CachedDataset(torch.utils.data.Dataset):
    """Dataset that loads every raw sample once, at construction time.

    The samples of ``subset`` in ``protocol`` are read from
    ``json_protocol``.  Each sample's original ``"data"`` entry is kept
    under ``"name"`` and ``"data"`` is replaced by the output of
    ``raw_data_loader``.  ``transforms`` is applied on every access.
    """

    def __init__(
        self, json_protocol, protocol, subset, raw_data_loader, transforms
    ):
        self.json_protocol = json_protocol
        self.subset = subset
        self.raw_data_loader = raw_data_loader
        self.transforms = transforms

        self._samples = json_protocol.subsets(protocol)[self.subset]

        # Dict entry with relative path to files, used during prediction
        for entry in self._samples:
            entry["name"] = entry["data"]

        logger.info(f"Caching {self.subset} samples")
        for entry in tqdm(self._samples):
            entry["data"] = self.raw_data_loader(entry["data"])

    def __getitem__(self, idx):
        # Build a fresh dictionary so the cached entry is never modified.
        cached = self._samples[idx]
        return {**cached, "data": self.transforms(cached["data"])}

    def __len__(self):
        return len(self._samples)
class RuntimeDataset(torch.utils.data.Dataset):
    """Dataset that loads the raw sample from scratch on every access.

    The samples of ``subset`` in ``protocol`` are read from
    ``json_protocol``.  Each sample's ``"data"`` entry is copied to
    ``"name"``.  On access, ``raw_data_loader`` and then ``transforms``
    are applied to the ``"data"`` entry.
    """

    def __init__(
        self, json_protocol, protocol, subset, raw_data_loader, transforms
    ):
        self.json_protocol = json_protocol
        self.subset = subset
        self.raw_data_loader = raw_data_loader
        self.transforms = transforms

        self._samples = json_protocol.subsets(protocol)[self.subset]

        # Dict entry with relative path to files
        for entry in self._samples:
            entry["name"] = entry["data"]

    def __getitem__(self, idx):
        # Build a fresh dictionary so the stored entry is never modified.
        stored = self._samples[idx]
        loaded = self.raw_data_loader(stored["data"])
        return {**stored, "data": self.transforms(loaded)}

    def __len__(self):
        return len(self._samples)
def _class_balancing_weights(samples):
    """Return one weight per sample, inversely proportional to class size.

    ``samples`` is a sequence of dictionaries with an integer ``"label"``.
    """
    targets = torch.tensor([s["label"] for s in samples])

    # ``inverse`` maps every sample to the position of its class in the
    # sorted list of unique labels.  Indexing through it (instead of through
    # the raw label value) keeps this correct when labels are not exactly
    # 0..K-1, e.g. when a subset only contains label 1.
    _, inverse, counts = torch.unique(
        targets, sorted=True, return_inverse=True, return_counts=True
    )

    return (1.0 / counts.float())[inverse]


def get_samples_weights(dataset):
    """Compute the weights of all the samples of the dataset to balance it
    using the sampler of the dataloader.

    Datasets whose first sample carries an ``int`` label get, for every
    sample, the inverse of the number of samples sharing its label.  For
    other label types, a stand-alone dataset gets a weight of one per sample,
    whereas each member of a ``ConcatDataset`` gets ``1 / len(member)`` per
    sample so that members are sampled equally.

    Parameters
    ----------

    dataset : torch.utils.data.dataset.Dataset
        A dataset exposing its samples as a ``_samples`` list of
        dictionaries with a ``"label"`` key.  ``ConcatDataset`` objects made
        of such datasets are supported.

    Returns
    -------

    samples_weights : :py:class:`torch.Tensor`
        One weight per sample, in dataset order (members of a
        ``ConcatDataset`` are concatenated in order).
    """
    if isinstance(dataset, torch.utils.data.ConcatDataset):
        per_dataset = []

        for ds in dataset.datasets:
            if isinstance(ds._samples[0]["label"], int):
                per_dataset.append(_class_balancing_weights(ds._samples))
            else:
                # We only weight to sample equally from each dataset
                per_dataset.append(torch.full((len(ds),), 1.0 / len(ds)))

        # Concatenate sample weights from all the datasets
        return torch.cat(per_dataset)

    if isinstance(dataset._samples[0]["label"], int):
        return _class_balancing_weights(dataset._samples)

    # Equal weights for non-integer labels
    return torch.ones(len(dataset._samples))
def get_positive_weights(dataset):
def _get_positive_weights(dataloader):
"""Compute the positive weights of each class of the dataset to balance the
BCEWithLogitsLoss criterion.
This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
This function takes as input a :py:class:`torch.utils.data.DataLoader`
and computes the positive weights of each class to use them to have
a balanced loss.
......@@ -463,9 +22,8 @@ def get_positive_weights(dataset):
Parameters
----------
dataset : torch.utils.data.dataset.Dataset
An instance of torch.utils.data.dataset.Dataset
ConcatDataset are supported
dataloader : :py:class:`torch.utils.data.DataLoader`
A DataLoader from which to compute the positive weights. Must contain a 'label' key in the metadata returned by __getitem__().
Returns
......@@ -476,14 +34,8 @@ def get_positive_weights(dataset):
"""
targets = []
if isinstance(dataset, torch.utils.data.ConcatDataset):
for ds in dataset.datasets:
for s in ds._samples:
targets.append(s["label"])
else:
for s in dataset._samples:
targets.append(s["label"])
for batch in dataloader:
targets.extend(batch[1]["label"])
targets = torch.tensor(targets)
......@@ -512,75 +64,3 @@ def get_positive_weights(dataset):
)
return positive_weights
def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
    """Replace the model criteria by positively-weighted BCE losses.

    The datamodule is prepared and set up for the ``"fit"`` stage.  If
    ``criterion`` is a ``BCEWithLogitsLoss``, ``model.hparams.criterion`` is
    replaced by a new one whose ``pos_weight`` is computed from the training
    dataset; otherwise a warning is logged.  When a validation dataset
    exists, ``model.hparams.criterion_valid`` is handled likewise (a
    ``None`` validation criterion is also replaced).
    """
    from torch.nn import BCEWithLogitsLoss

    datamodule.prepare_data()
    datamodule.setup(stage="fit")

    train_dataset = datamodule.train_dataset
    validation_dataset = datamodule.validation_dataset

    # Redefine a weighted criterion if possible
    if not isinstance(criterion, BCEWithLogitsLoss):
        logger.warning("Weighted criterion not supported")
    else:
        model.hparams.criterion = BCEWithLogitsLoss(
            pos_weight=get_positive_weights(train_dataset)
        )

    if validation_dataset is None:
        return

    # Redefine a weighted valid criterion if possible
    if criterion_valid is None or isinstance(
        criterion_valid, BCEWithLogitsLoss
    ):
        model.hparams.criterion_valid = BCEWithLogitsLoss(
            pos_weight=get_positive_weights(validation_dataset)
        )
    else:
        logger.warning("Weighted valid criterion not supported")
def normalize_data(normalization, model, datamodule):
    """Configure the model's z-normalization layer.

    The datamodule is always prepared and set up for the ``"fit"`` stage.
    Then, depending on ``normalization``:

    * ``"imagenet"``: fixed ImageNet mean and standard deviation are set on
      ``model.normalizer``;
    * ``"current"``: mean and standard deviation are computed from the whole
      training dataset, loaded as a single batch, and set on
      ``model.normalizer``;
    * any other value: nothing else happens.
    """
    from torch.utils.data import DataLoader

    datamodule.prepare_data()
    datamodule.setup(stage="fit")
    train_dataset = datamodule.train_dataset

    # Create z-normalization model layer if needed
    if normalization == "imagenet":
        model.normalizer.set_mean_std(
            [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        )
        logger.info("Z-normalization with ImageNet mean and std")
        return

    if normalization != "current":
        return

    # Compute mean/std of current train subset, loaded in one single batch
    whole_set_loader = DataLoader(
        dataset=train_dataset, batch_size=len(train_dataset)
    )
    data = next(iter(whole_set_loader))

    # Element 1 of the batch is reduced over dims 0, 2 and 3 -- presumably a
    # (batch, channel, height, width) image tensor; TODO confirm.
    mean = data[1].mean(dim=[0, 2, 3])
    std = data[1].std(dim=[0, 2, 3])
    model.normalizer.set_mean_std(mean, std)

    def _as_rounded_text(statistic):
        # Format a statistic for logging, with 3 decimal places
        return str(
            [
                round(x, 3)
                for x in ((statistic * 10**3).round() / (10**3)).tolist()
            ]
        )

    mean = _as_rounded_text(mean)
    std = _as_rounded_text(std)
    logger.info(f"Z-normalization with mean {mean} and std {std}")
......@@ -264,7 +264,7 @@ json_dataset = JSONDataset(
def _maker(protocol, resize_size=512, cc_size=512, RGB=True):
import torchvision.transforms as transforms
from ..transforms import SingleAutoLevel16to8
from ..loader import SingleAutoLevel16to8
post_transforms = []
if not RGB:
......
......@@ -5,9 +5,42 @@
"""Data loading code."""
import numpy
import PIL.Image
class SingleAutoLevel16to8:
    """Converts a 16-bit image to 8-bit representation using "auto-level".

    This transform assumes that the input image is gray-scaled.

    To auto-level, we calculate the maximum and the minimum of the image, and
    consider such a range should be mapped to the [0,255] range of the
    destination image.  An image whose pixels all share the same value has
    no range to stretch and is mapped to an all-zero image.
    """

    def __call__(self, img):
        """Return an 8-bit, mode ``"L"`` auto-levelled copy of ``img``.

        ``img`` must provide ``getextrema()`` returning a ``(min, max)``
        pair and be convertible with ``numpy.array``.
        """
        imin, imax = img.getextrema()
        irange = imax - imin
        pixels = numpy.array(img).astype(float)

        if irange == 0:
            # Constant image: dividing by the range would yield 0/0 = NaN,
            # whose conversion to uint8 is not well defined.
            levelled = numpy.zeros_like(pixels)
        else:
            levelled = numpy.round(255.0 * (pixels - imin) / irange)

        return PIL.Image.fromarray(levelled.astype("uint8")).convert("L")
class RemoveBlackBorders:
    """Remove black borders of CXR.

    Rows and columns in which no pixel is strictly greater than
    ``threshold`` are dropped from the image.
    """

    def __init__(self, threshold=0):
        self.threshold = threshold

    def __call__(self, img):
        """Return a PIL image holding only the rows/columns with content."""
        pixels = numpy.asarray(img)
        above = pixels > self.threshold

        # Keep a row (column) when at least one of its pixels is above the
        # threshold.  Assumes a 2-D, single-channel image -- TODO confirm.
        kept_rows = above.any(1)
        kept_columns = above.any(0)

        return PIL.Image.fromarray(pixels[numpy.ix_(kept_rows, kept_columns)])
def load_pil(path):
"""Loads a sample data.
......
......@@ -13,22 +13,11 @@ the daily routine using Philips DR Digital Diagnose systems.
* Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 3000 x 3000 or less
* Split reference: none
* Protocol ``default``:
* Training samples: 64% of TB and healthy CXR (including labels)
* Validation samples: 16% of TB and healthy CXR (including labels)
* Test samples: 20% of TB and healthy CXR (including labels)
"""
import importlib.resources
import os
from clapper.logging import setup
from ...utils.rc import load_rc
from ..loader import load_pil_baw
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
_protocols = [
importlib.resources.files(__name__).joinpath("default.json.bz2"),
......@@ -43,11 +32,3 @@ _protocols = [
importlib.resources.files(__name__).joinpath("fold_8.json.bz2"),
importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
]
_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
def _raw_data_loader(img_path):
    """Load the image stored at ``img_path``, relative to ``_datadir``.

    The joined path is handed to ``load_pil_baw`` and its result is returned
    unchanged.
    """
    return load_pil_baw(os.path.join(_datadir, img_path))
......@@ -2,27 +2,34 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for TB detection (default protocol)
"""Shenzhen datamodule for computer-aided diagnosis (default protocol)
* Split reference: first 64% of TB and healthy CXR for "train" 16% for
* "validation", 20% for "test"
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.shenzhen` for dataset details
"""
from clapper.logging import setup
from ..transforms import ElasticDeformation
from .utils import ShenzhenDataModule
See :py:mod:`ptbench.data.shenzhen` for dataset details.
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
protocol_name = "default"
augmentation_transforms = [ElasticDeformation(p=0.8)]
This configuration:
* raw data (default): :py:obj:`ptbench.data.shenzhen.raw_data_loader._transform`
* augmentations: elastic deformation (probability = 80%)
* output image resolution: 512x512 pixels
"""
datamodule = ShenzhenDataModule(
protocol="default",
model_transforms=[],
augmentation_transforms=augmentation_transforms,
import importlib.resources
from ..datamodule import CachingDataModule
from ..split import JSONDatabaseSplit
from .raw_data_loader import raw_data_loader
datamodule = CachingDataModule(
database_split=JSONDatabaseSplit(
importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
"default.json.bz2"
)
),
raw_data_loader=raw_data_loader,
cache_samples=False,
# train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
# model_transforms = [],
# batch_size = 1,
# batch_chunk_count = 1,
# drop_incomplete_batch = False,
# parallel = -1,
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for computer-aided diagnosis.
The standard digital image database for Tuberculosis is created by the National
Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s
Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from
out-patient clinics, and were captured as part of the daily routine using
Philips DR Digital Diagnose systems.
* Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 3000 x 3000 or less
* Split reference: none
* Protocol ``default``:
* Training samples: 64% of TB and healthy CXR (including labels)
* Validation samples: 16% of TB and healthy CXR (including labels)
* Test samples: 20% of TB and healthy CXR (including labels)
"""
import os
import typing
import torch.nn
import torchvision.transforms
from ...utils.rc import load_rc
from ..raw_data_loader import RemoveBlackBorders, load_pil_baw
_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
"""This variable contains the base directory where the database raw data is
stored."""
_transform = torchvision.transforms.Compose(
[
RemoveBlackBorders(),
torchvision.transforms.Resize(512),
torchvision.transforms.CenterCrop(512),
torchvision.transforms.ToTensor(),
]
)
"""Transforms that are always applied to the loaded raw images."""
def raw_data_loader(
    sample: tuple[str, int]
) -> tuple[torch.Tensor, typing.Mapping]:
    """Loads a single image sample from the disk.


    Parameters
    ----------

    sample
        A pair whose first element is the path suffix, within the dataset
        root folder (``_datadir``), where to find the image to be loaded,
        and whose second element is the label of that image.


    Returns
    -------

    tensor
        The loaded image after going through the module-level
        ``_transform`` pipeline.

    metadata
        A dictionary holding the sample label under the ``"label"`` key.
    """
    image_path = os.path.join(_datadir, sample[0])
    tensor = _transform(load_pil_baw(image_path))
    metadata = dict(label=sample[1])  # type: ignore[arg-type]
    return tensor, metadata
......@@ -2,81 +2,40 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for TB detection (cross validation fold 0, RGB)
"""Shenzhen datamodule for computer-aided diagnosis (default protocol)
* Split reference: first 80% of TB and healthy CXR for "train", rest for "test"
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.shenzhen` for dataset details
"""
from clapper.logging import setup
from ....data import return_subsets
from ....data.base_datamodule import BaseDataModule
from ....data.dataset import JSONProtocol
from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols
See :py:mod:`ptbench.data.shenzhen` for dataset details.
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
cache_samples=False,
multiproc_kwargs=None,
data_transforms=[],
model_transforms=[],
train_transforms=[],
):
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
self.cache_samples = cache_samples
self.has_setup_fit = False
This configuration:
* raw data (default): :py:obj:`ptbench.data.shenzhen.raw_data_loader._transform`
* augmentations: elastic deformation (probability = 80%)
* output image resolution: 512x512 pixels
"""
self.data_transforms = data_transforms
self.model_transforms = model_transforms
self.train_transforms = train_transforms
import importlib.resources
"""[
transforms.ToPILImage(),
transforms.Lambda(lambda x: x.convert("RGB")),
transforms.ToTensor(),
]"""
from torchvision import transforms
def setup(self, stage: str):
if self.cache_samples:
logger.info(
"Argument cache_samples set to True. Samples will be loaded in memory."
)
samples_loader = _cached_loader
else:
logger.info(
"Argument cache_samples set to False. Samples will be loaded at runtime."
)
samples_loader = _delayed_loader
from ..datamodule import CachingDataModule
from ..split import JSONDatabaseSplit
from .raw_data_loader import raw_data_loader
self.json_protocol = JSONProtocol(
protocols=_protocols,
fieldnames=("data", "label"),
loader=samples_loader,
post_transforms=self.post_transforms,
datamodule = CachingDataModule(
database_split=JSONDatabaseSplit(
importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
"default.json.bz2"
)
if not self.has_setup_fit and stage == "fit":
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
) = return_subsets(self.json_protocol, "default", stage)
self.has_setup_fit = True
datamodule = DefaultModule
),
raw_data_loader=raw_data_loader,
cache_samples=False,
# train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
model_transforms=[
transforms.ToPILImage(),
transforms.Lambda(lambda x: x.convert("RGB")),
transforms.ToTensor(),
],
# batch_size = 1,
# batch_chunk_count = 1,
# drop_incomplete_batch = False,
# parallel = -1,
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for computer-aided diagnosis.
The standard digital image database for Tuberculosis is created by the
National Library of Medicine, Maryland, USA in collaboration with Shenzhen
No.3 People’s Hospital, Guangdong Medical College, Shenzhen, China.
The Chest X-rays are from out-patient clinics, and were captured as part of
the daily routine using Philips DR Digital Diagnose systems.
* Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 3000 x 3000 or less
* Split reference: none
* Protocol ``default``:
* Training samples: 64% of TB and healthy CXR (including labels)
* Validation samples: 16% of TB and healthy CXR (including labels)
* Test samples: 20% of TB and healthy CXR (including labels)
"""
from clapper.logging import setup
from torchvision import transforms
from ..base_datamodule import BaseDataModule
from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
from ..shenzhen import _protocols, _raw_data_loader
from ..transforms import RemoveBlackBorders
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class ShenzhenDataModule(BaseDataModule):
    """Data module wiring the Shenzhen JSON protocols to train/validation
    datasets.

    Parameters
    ----------
    protocol
        Name of the protocol (split) to use from the Shenzhen protocols.
    model_transforms
        Transforms required by the model.  ``None`` (the default) means no
        extra transform.  The sequence is copied, so later changes to the
        caller's list do not affect this object.
    augmentation_transforms
        Augmentation transforms.  ``None`` (the default) means none.  The
        sequence is copied, as above.
    batch_size
        Forwarded to the base data module.
    batch_chunk_count
        Forwarded to the base data module.
    drop_incomplete_batch
        Forwarded to the base data module.
    cache_samples
        If set, datasets are built with ``CachedDataset``; otherwise with
        ``RuntimeDataset``.
    parallel
        Forwarded to the base data module.
    """

    def __init__(
        self,
        protocol="default",
        model_transforms=None,
        augmentation_transforms=None,
        batch_size=1,
        batch_chunk_count=1,
        drop_incomplete_batch=False,
        cache_samples=False,
        parallel=-1,
    ):
        super().__init__(
            batch_size=batch_size,
            drop_incomplete_batch=drop_incomplete_batch,
            batch_chunk_count=batch_chunk_count,
            parallel=parallel,
        )

        self._cache_samples = cache_samples
        self._has_setup_fit = False
        self._has_setup_predict = False
        self._protocol = protocol

        # Transforms applied to the raw data of every sample.
        self.raw_data_transforms = [
            RemoveBlackBorders(),
            transforms.Resize(512),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
        ]

        # Copy into fresh lists: a mutable default (``[]``) would otherwise be
        # shared by every instance created without explicit transforms.
        self.model_transforms = (
            [] if model_transforms is None else list(model_transforms)
        )
        self.augmentation_transforms = (
            []
            if augmentation_transforms is None
            else list(augmentation_transforms)
        )

    def setup(self, stage: str):
        """Build the datasets needed for ``stage`` (``"fit"`` or
        ``"predict"``).

        Each stage is set up at most once per instance; other stage names are
        ignored.

        Parameters
        ----------
        stage
            Name of the stage being set up.
        """
        json_protocol = JSONProtocol(
            protocols=_protocols,
            fieldnames=("data", "label"),
        )

        dataset = CachedDataset if self._cache_samples else RuntimeDataset

        def _make(subset, is_train):
            # ``_build_transforms`` is not defined in this class; presumably
            # provided by ``BaseDataModule`` -- TODO confirm.
            return dataset(
                json_protocol,
                self._protocol,
                subset,
                _raw_data_loader,
                self._build_transforms(is_train=is_train),
            )

        if not self._has_setup_fit and stage == "fit":
            self.train_dataset = _make("train", True)
            self.validation_dataset = _make("validation", False)
            self._has_setup_fit = True

        if not self._has_setup_predict and stage == "predict":
            # At prediction time the training subset is built with
            # ``is_train=False`` as well.
            self.train_dataset = _make("train", False)
            self.validation_dataset = _make("validation", False)
            self._has_setup_predict = True
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import collections.abc
import csv
import importlib.abc
import json
import logging
import pathlib
import typing
import torch
logger = logging.getLogger(__name__)
class JSONDatabaseSplit(
dict,
typing.Mapping[str, typing.Sequence[typing.Any]],
):
"""Defines a loader that understands a database split (train, test, etc) in
JSON format.
To create a new database split, you need to provide a JSON formatted
dictionary in a file, with contents similar to the following:
.. code-block:: json
{
"subset1": [
[
"sample1-data1",
"sample1-data2",
"sample1-data3",
],
[
"sample2-data1",
"sample2-data2",
"sample2-data3",
]
],
"subset2": [
[
"sample42-data1",
"sample42-data2",
"sample42-data3",
],
]
}
Your database split many contain any number of subsets (dictionary keys).
For simplicity, we recommend all sample entries are formatted similarly so
that raw-data-loading is simplified. Use the function
:py:func:`check_database_split_loading` to test raw data loading and fine
tune the dataset split, or its loading.
Objects of this class behave like a dictionary in which keys are subset
names in the split, and values represent samples data and meta-data.
Parameters
----------
path
Absolute path to a JSON formatted file containing the database split to be
recognized by this object.
"""
def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable):
if isinstance(path, str):
path = pathlib.Path(path)
self.path = path
self.subsets = self._load_split_from_disk()
def _load_split_from_disk(self) -> dict[str, typing.Sequence[typing.Any]]:
"""Loads all subsets in a split from its file system representation.
This method will load JSON information for the current split and return
all subsets of the given split after converting each entry through the
loader function.
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of JSON objects
"""
if str(self.path).endswith(".bz2"):
logger.debug(f"Loading database split from {str(self.path)}...")
with __import__("bz2").open(self.path) as f:
return json.load(f)
else:
with self.path.open() as f:
return json.load(f)
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split."""
return self.subsets[key]
def __iter__(self):
"""Iterates over the subsets."""
return iter(self.subsets)
def __len__(self) -> int:
"""How many subsets we currently have."""
return len(self.subsets)
class CSVDatabaseSplit(collections.abc.Mapping):
"""Defines a loader that understands a database split (train, test, etc) in
CSV format.
To create a new database split, you need to provide one or more CSV
formatted files, each representing a subset of this split, containing the
sample data (one per row). Example:
Inside the directory ``my-split/``, one can file files ``train.csv``,
``validation.csv``, and ``test.csv``. Each file has a structure similar to
the following:
.. code-block:: text
sample1-value1,sample1-value2,sample1-value3
sample2-value1,sample2-value2,sample2-value3
...
Each file in the provided directory defines the subset name on the split.
So, the file ``train.csv`` will contain the data from the ``train`` subset,
and so on.
Objects of this class behave like a dictionary in which keys are subset
names in the split, and values represent samples data and meta-data.
Parameters
----------
directory
Absolute path to a directory containing the database split layed down
as a set of CSV files, one per subset.
"""
def __init__(
self, directory: pathlib.Path | str | importlib.abc.Traversable
):
if isinstance(directory, str):
directory = pathlib.Path(directory)
assert (
directory.is_dir()
), f"`{str(directory)}` is not a valid directory"
self.directory = directory
self.subsets = self._load_split_from_disk()
def _load_split_from_disk(
self,
) -> dict[str, list[typing.Any]]:
"""Loads all subsets in a split from its file system representation.
This method will load CSV information for the current split and return all
subsets of the given split after converting each entry through the
loader function.
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of JSON objects
"""
retval = {}
for subset in self.directory.iterdir():
if str(subset).endswith(".csv.bz2"):
logger.debug(f"Loading database split from {subset}...")
with __import__("bz2").open(subset) as f:
reader = csv.reader(f)
retval[subset.name[: -len(".csv.bz2")]] = [
k for k in reader
]
elif str(subset).endswith(".csv"):
with subset.open() as f:
reader = csv.reader(f)
retval[subset.name[: -len(".csv")]] = [k for k in reader]
else:
logger.debug(
f"Ignoring file {subset} in CSVDatabaseSplit readout"
)
return retval
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split."""
return self.subsets[key]
def __iter__(self):
"""Iterates over the subsets."""
return iter(self.subsets)
def __len__(self) -> int:
"""How many subsets we currently have."""
return len(self.subsets)
def check_database_split_loading(
    database_split: typing.Mapping[str, typing.Sequence[typing.Any]],
    loader: typing.Callable[[typing.Any], torch.Tensor],
    limit: int = 0,
) -> int:
    """For each subset in the split, check if all data can be correctly loaded
    using the provided loader function.

    This function will return the number of errors loading samples, and will
    log more detailed information to the logging stream.


    Parameters
    ----------

    database_split
        A mapping that contains the database split.  Each key represents the
        name of a subset in the split.  Each value is a sequence of
        (potentially complex) objects, each representing a single sample.

    loader
        A callable that transforms sample entries in the database split
        into :py:class:`torch.Tensor` objects that can be used for training
        or inference.

    limit
        Maximum number of samples to check (in each split/subset
        combination) in this dataset.  If set to zero, then check
        everything.


    Returns
    -------

    errors
        Number of errors found
    """
    logger.info(
        "Checking if can load all samples in all subsets of this split..."
    )
    errors = 0
    # Iterate over the samples of each subset, not over the subset name.
    for subset, samples in database_split.items():
        if limit:
            samples = samples[:limit]
        for pos, sample in enumerate(samples):
            # Deliberately broad: any failure while loading a sample is
            # reported and counted rather than aborting the whole check.
            try:
                data = loader(sample)
                # Explicit check instead of ``assert`` so it still runs
                # under ``python -O``.
                if not isinstance(data, torch.Tensor):
                    raise TypeError(
                        f"loader returned {type(data).__name__}, "
                        f"expected torch.Tensor"
                    )
            except Exception as e:
                logger.info(
                    f"Found error loading entry {pos} in subset `{subset}`: {e}"
                )
                errors += 1
    return errors
......@@ -22,39 +22,6 @@ from scipy.ndimage import gaussian_filter, map_coordinates
from torchvision import transforms
class SingleAutoLevel16to8:
    """Converts a 16-bit image to 8-bit representation using "auto-level".

    This transform assumes that the input image is gray-scaled.

    To auto-level, we calculate the maximum and the minimum of the image, and
    consider such a range should be mapped to the [0,255] range of the
    destination image.  An image with a single, constant value has no range
    to stretch and is mapped to an all-zero image.
    """

    def __call__(self, img):
        """Return an 8-bit ("L" mode) auto-levelled copy of ``img``.

        Parameters
        ----------
        img
            Input image; must provide ``getextrema()`` and be convertible
            with :py:func:`numpy.array` (assumed to be a single-band PIL
            image -- TODO confirm against callers).
        """
        imin, imax = img.getextrema()
        irange = imax - imin
        pixels = numpy.array(img).astype(float)
        if irange == 0:
            # Flat image: avoid a division by zero (NaN -> undefined uint8).
            levelled = numpy.zeros_like(pixels)
        else:
            levelled = 255.0 * (pixels - imin) / irange
        return PIL.Image.fromarray(
            numpy.round(levelled).astype("uint8"),
        ).convert("L")
class RemoveBlackBorders:
    """Remove black borders of CXR.

    Rows and columns in which no pixel is strictly above ``threshold`` are
    dropped from the image.
    """

    def __init__(self, threshold=0):
        self.threshold = threshold

    def __call__(self, img):
        """Return ``img`` cropped to rows/columns holding a pixel above the
        threshold."""
        pixels = numpy.asarray(img)
        foreground = pixels > self.threshold
        keep_rows = foreground.any(1)
        keep_cols = foreground.any(0)
        cropped = pixels[numpy.ix_(keep_rows, keep_cols)]
        return PIL.Image.fromarray(cropped)
class ElasticDeformation:
"""Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_.
......@@ -68,7 +35,7 @@ class ElasticDeformation:
spline_order=1,
mode="nearest",
random_state=numpy.random,
p=1,
p=1.0,
):
self.alpha = alpha
self.sigma = sigma
......@@ -79,13 +46,15 @@ class ElasticDeformation:
def __call__(self, img):
if random.random() < self.p:
img = transforms.ToPILImage()(img)
assert img.ndim == 3
# Input tensor is of shape C x H x W
# If the tensor only contains one channel, this conversion results in H x W.
# With 3 channels, we get H x W x C
img = transforms.ToPILImage()(img)
img = numpy.asarray(img)
assert img.ndim == 2
shape = img.shape
shape = img.shape[:2]
dx = (
gaussian_filter(
......@@ -114,9 +83,22 @@ class ElasticDeformation:
numpy.reshape(y + dy, (-1, 1)),
]
result = numpy.empty_like(img)
result[:, :] = map_coordinates(
img[:, :], indices, order=self.spline_order, mode=self.mode
).reshape(shape)
if img.ndim == 2:
result[:, :] = map_coordinates(
img[:, :], indices, order=self.spline_order, mode=self.mode
).reshape(shape)
else:
for i in range(img.shape[2]):
result[:, :, i] = map_coordinates(
img[:, :, i],
indices,
order=self.spline_order,
mode=self.mode,
).reshape(shape)
return transforms.ToTensor()(PIL.Image.fromarray(result))
else:
return img
......@@ -94,10 +94,8 @@ class LoggingCallback(Callback):
self.log("total_time", current_time)
self.log("eta", eta_seconds)
self.log("loss", numpy.average(self.training_loss))
self.log(
"learning_rate", pl_module.hparams["optimizer_configs"]["lr"]
)
self.log("validation_loss", numpy.sum(self.validation_loss))
self.log("learning_rate", pl_module.optimizer_configs["lr"])
self.log("validation_loss", numpy.average(self.validation_loss))
if len(self.extra_validation_loss) > 0:
for (
......@@ -105,7 +103,8 @@ class LoggingCallback(Callback):
extra_valid_loss_values,
) in self.extra_validation_loss.items:
self.log(
extra_valid_loss_key, numpy.sum(extra_valid_loss_values)
extra_valid_loss_key,
numpy.average(extra_valid_loss_values),
)
queue_retries = 0
......
......@@ -7,19 +7,20 @@ import logging
import os
import shutil
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.utilities.model_summary import ModelSummary
import lightning.pytorch
import lightning.pytorch.callbacks
import lightning.pytorch.loggers
import torch.nn
from ..utils.accelerator import AcceleratorProcessor
from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
from ..utils.save_sh_command import save_sh_command
from .callbacks import LoggingCallback
logger = logging.getLogger(__name__)
def check_gpu(device):
def check_gpu(device: str) -> None:
"""Check the device type and the availability of GPU.
Parameters
......@@ -35,32 +36,41 @@ def check_gpu(device):
), f"Device set to '{device}', but nvidia-smi is not installed"
def save_model_summary(output_folder, model):
def save_model_summary(
output_folder: str, model: torch.nn.Module
) -> tuple[lightning.pytorch.callbacks.ModelSummary, int]:
"""Save a little summary of the model in a txt file.
Parameters
----------
output_folder : str
output_folder
output path
model : :py:class:`torch.nn.Module`
model
Network (e.g. driu, hed, unet)
Returns
-------
r : str
summary:
The model summary in a text format.
n : int
total_parameters:
The number of parameters of the model.
"""
summary_path = os.path.join(output_folder, "model_summary.txt")
logger.info(f"Saving model summary at {summary_path}...")
with open(summary_path, "w") as f:
summary = ModelSummary(model, max_depth=-1)
summary = lightning.pytorch.utilities.model_summary.ModelSummary(
model, max_depth=-1
)
f.write(str(summary))
return summary, ModelSummary(model).total_parameters
return (
summary,
lightning.pytorch.utilities.model_summary.ModelSummary(
model
).total_parameters,
)
def static_information_to_csv(static_logfile_name, device, n):
......@@ -70,7 +80,8 @@ def static_information_to_csv(static_logfile_name, device, n):
----------
static_logfile_name : str
The static file name which is a join between the output folder and "constant.csv"
The static file name which is a join between the output folder and
"constant.csv"
"""
if os.path.exists(static_logfile_name):
backup = static_logfile_name + "~"
......@@ -188,7 +199,8 @@ def run(
not save intermediary checkpoints.
accelerator : str
A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)
A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The
device can also be specified (gpu:0)
arguments : dict
Start and end epochs:
......@@ -215,10 +227,16 @@ def run(
os.makedirs(output_folder, exist_ok=True)
# Save model summary
r, n = save_model_summary(output_folder, model)
_, n = save_model_summary(output_folder, model)
csv_logger = CSVLogger(output_folder, "logs_csv")
tensorboard_logger = TensorBoardLogger(output_folder, "logs_tensorboard")
save_sh_command(output_folder)
# save_sh_command(os.path.join(output_folder, "cmd_line_config.txt"))
csv_logger = lightning.pytorch.loggers.CSVLogger(output_folder, "logs_csv")
tensorboard_logger = lightning.pytorch.loggers.TensorBoardLogger(
output_folder, "logs_tensorboard"
)
resource_monitor = ResourceMonitor(
interval=monitoring_interval,
......@@ -227,7 +245,7 @@ def run(
logging_level=logging.ERROR,
)
checkpoint_callback = ModelCheckpoint(
checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint(
output_folder,
"model_lowest_valid_loss",
save_last=True,
......@@ -251,7 +269,7 @@ def run(
devices = accelerator_processor.device
with resource_monitor:
trainer = Trainer(
trainer = lightning.pytorch.Trainer(
accelerator=accelerator_processor.accelerator,
devices=devices,
max_epochs=max_epoch,
......