diff --git a/src/ptbench/data/shenzhen/loader.py b/src/ptbench/data/shenzhen/loader.py
index e983fe7028d858118efb9ba41e560aee8e95845a..49ccf8bfb217e411004228b4acf7c924e3ffec66 100644
--- a/src/ptbench/data/shenzhen/loader.py
+++ b/src/ptbench/data/shenzhen/loader.py
@@ -82,7 +82,8 @@ class RawDataLoader(_BaseRawDataLoader):
         tensor = self.transform(
             load_pil_baw(os.path.join(self.datadir, sample[0]))
         )
-        return tensor, dict(label=sample[1])  # type: ignore[arg-type]
+
+        return tensor, dict(label=sample[1], name=sample[0])  # type: ignore[arg-type]
 
     def label(self, sample: tuple[str, int]) -> int:
         """Loads a single image sample label from the disk.
diff --git a/src/ptbench/data/shenzhen/rgb.py b/src/ptbench/data/shenzhen/rgb.py
index f45f601d6e84a04e4610319d1f27631ff32ee69c..211b49236cae22af68ad61d38849429dedb606d7 100644
--- a/src/ptbench/data/shenzhen/rgb.py
+++ b/src/ptbench/data/shenzhen/rgb.py
@@ -18,7 +18,7 @@ from torchvision import transforms
 
 from ..datamodule import CachingDataModule
 from ..split import JSONDatabaseSplit
-from .raw_data_loader import raw_data_loader
+from .loader import RawDataLoader
 
 datamodule = CachingDataModule(
     database_split=JSONDatabaseSplit(
@@ -26,16 +26,10 @@ datamodule = CachingDataModule(
             "default.json.bz2"
         )
     ),
-    raw_data_loader=raw_data_loader,
-    cache_samples=False,
-    # train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
+    raw_data_loader=RawDataLoader(),
     model_transforms=[
         transforms.ToPILImage(),
         transforms.Lambda(lambda x: x.convert("RGB")),
         transforms.ToTensor(),
     ],
-    # batch_size = 1,
-    # batch_chunk_count = 1,
-    # drop_incomplete_batch = False,
-    # parallel = -1,
 )
diff --git a/src/ptbench/data/typing.py b/src/ptbench/data/typing.py
index 344c1294df6777f88acdde23dec51d40fc51e31e..f0e54f1afa701f0b7f9c691bbc7dfd51aad86e8f 100644
--- a/src/ptbench/data/typing.py
+++ b/src/ptbench/data/typing.py
@@ -73,3 +73,6 @@ DataLoader = torch.utils.data.DataLoader[Sample]
 
 We iterate over Sample objects in this case.
""" + +Checkpoint = dict[str, typing.Any] +"""Definition of a lightning checkpoint.""" diff --git a/src/ptbench/engine/callbacks.py b/src/ptbench/engine/callbacks.py index 350140a8516ddef43e89323e65b746ca7c479182..8774f9c45c248414618a24592cab7ee687e74dbf 100644 --- a/src/ptbench/engine/callbacks.py +++ b/src/ptbench/engine/callbacks.py @@ -398,26 +398,27 @@ class PredictionsWriter(lightning.pytorch.callbacks.BasePredictionWriter): predictions: typing.Sequence[typing.Any], batch_indices: typing.Sequence[typing.Any] | None, ) -> None: - for dataloader_idx, dataloader_results in enumerate(predictions): - dataloader_name = list( - trainer.datamodule.predict_dataloader().keys() - )[dataloader_idx].replace("_loader", "") + dataloader_name = list(trainer.datamodule.predict_dataloader().keys())[ + 0 + ] - logfile = os.path.join( - self.output_dir, dataloader_name, "predictions.csv" - ) - os.makedirs(os.path.dirname(logfile), exist_ok=True) - - with open(logfile, "w") as l_f: - logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields) - logwriter.writeheader() - - for prediction in dataloader_results: - logwriter.writerow( - { - "filename": prediction[0], - "likelihood": prediction[1].numpy(), - "ground_truth": prediction[2].numpy(), - } - ) - l_f.flush() + logfile = os.path.join( + self.output_dir, f"predictions_{dataloader_name}_set.csv" + ) + os.makedirs(os.path.dirname(logfile), exist_ok=True) + + logger.info(f"Saving predictions in {logfile}.") + + with open(logfile, "w") as l_f: + logwriter = csv.DictWriter(l_f, fieldnames=self.logfile_fields) + logwriter.writeheader() + + for prediction in predictions: + logwriter.writerow( + { + "filename": prediction[0], + "likelihood": prediction[1].numpy(), + "ground_truth": prediction[2].numpy(), + } + ) + l_f.flush() diff --git a/src/ptbench/engine/predictor.py b/src/ptbench/engine/predictor.py index 5dcbb79c9fd0a8c32f9d269f8302b888da56be84..6bb6e275d304406182449c77a68a8a8e62406719 100644 --- a/src/ptbench/engine/predictor.py +++ b/src/ptbench/engine/predictor.py @@ -5,15 +5,22 @@ import logging import os +import lightning.pytorch + from lightning.pytorch import Trainer -from ..utils.accelerator import AcceleratorProcessor from .callbacks import PredictionsWriter +from .device import DeviceManager logger = logging.getLogger(__name__) -def run(model, datamodule, accelerator, output_folder, grad_cams=False): +def run( + model: lightning.pytorch.LightningModule, + datamodule: lightning.pytorch.LightningDataModule, + device_manager: DeviceManager, + output_folder: str, +): """Runs inference on input data, outputs csv files with predictions. Parameters @@ -21,11 +28,13 @@ def run(model, datamodule, accelerator, output_folder, grad_cams=False): model : :py:class:`torch.nn.Module` Neural network model (e.g. pasa). - data_loader : py:class:`torch.torch.utils.data.DataLoader` - The pytorch Dataloader used to iterate over batches. + datamodule + The lightning datamodule to use for training **and** validation - accelerator : str - A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0) + device_manager + An internal device representation, to be used for training and + validation. This representation can be converted into a pytorch device + or a torch lightning accelerator setup. output_folder : str Directory in which the results will be saved. 
@@ -44,19 +53,11 @@ def run(model, datamodule, accelerator, output_folder, grad_cams=False):
     logger.info(f"Output folder: {output_folder}")
     os.makedirs(output_folder, exist_ok=True)
 
-    accelerator_processor = AcceleratorProcessor(accelerator)
-
-    if accelerator_processor.device is None:
-        devices = "auto"
-    else:
-        devices = accelerator_processor.device
-
-    logger.info(f"Device: {devices}")
-
     logfile_fields = ("filename", "likelihood", "ground_truth")
+    accelerator, devices = device_manager.lightning_accelerator()
 
     trainer = Trainer(
-        accelerator=accelerator_processor.accelerator,
+        accelerator=accelerator,
         devices=devices,
         callbacks=[
             PredictionsWriter(
diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index a878a076037925879c072bffb87d23a3e1ce7b0d..c3aadc6d0466777fb0adcb62eb7d20887cf36b20 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -8,13 +8,12 @@ import typing
 
 import lightning.pytorch as pl
 import torch
 import torch.nn
-import torch.nn.functional as F
 import torch.optim.optimizer
 import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
-from ..data.typing import DataLoader, TransformSequence
+from ..data.typing import Checkpoint, DataLoader, TransformSequence
 
 logger = logging.getLogger(__name__)
@@ -61,10 +60,10 @@
     def __init__(
         self,
-        train_loss: torch.nn.Module,
-        validation_loss: torch.nn.Module | None,
-        optimizer_type: type[torch.optim.Optimizer],
-        optimizer_arguments: dict[str, typing.Any],
+        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
+        validation_loss: torch.nn.Module | None = None,
+        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
         pretrained: bool = False,
     ):
         super().__init__()
@@ -105,6 +104,32 @@
         return x
 
+    def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
+        """Called by Lightning when saving a checkpoint.
+
+        Gives you a chance to store anything else you might want to save.
+
+        Parameters
+        ----------
+
+        checkpoint:
+            Checkpoint to be saved
+        """
+        checkpoint["normalizer"] = self.normalizer
+
+    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
+        """Called by Lightning to restore your model. If you saved something
+        with on_save_checkpoint() this is your chance to restore it.
+
+        Parameters
+        ----------
+
+        checkpoint:
+            Loaded checkpoint
+        """
+        logger.info("Restoring normalizer from checkpoint.")
+        self.normalizer = checkpoint["normalizer"]
+
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
         """Initializes the normalizer for the current model.
@@ -214,7 +239,7 @@ class Alexnet(pl.LightningModule):
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         images = batch[0]
         labels = batch[1]["label"]
-        names = batch[1]["names"]
+        names = batch[1]["name"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index 8eba3b53410da4874d8b59c48c73e13bec1cb703..021f6ce2c6f5cfb3ad3819144a442744577d5eaa 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -8,13 +8,12 @@ import typing
 
 import lightning.pytorch as pl
 import torch
 import torch.nn
-import torch.nn.functional as F
 import torch.optim.optimizer
 import torch.utils.data
 import torchvision.models as models
 import torchvision.transforms
 
-from ..data.typing import DataLoader, TransformSequence
+from ..data.typing import Checkpoint, DataLoader, TransformSequence
 
 logger = logging.getLogger(__name__)
@@ -59,12 +58,12 @@
     def __init__(
         self,
-        train_loss: torch.nn.Module,
-        validation_loss: torch.nn.Module | None,
-        optimizer_type: type[torch.optim.Optimizer],
-        optimizer_arguments: dict[str, typing.Any],
+        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
+        validation_loss: torch.nn.Module | None = None,
+        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
-        pretrained: bool= False,
+        pretrained: bool = False,
     ):
         super().__init__()
@@ -98,13 +97,38 @@
         )
 
     def forward(self, x):
-        x = self.normalizer(x)  # type: ignore
         x = self.model_ft(x)
 
         return x
 
+    def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
+        """Called by Lightning when saving a checkpoint.
+
+        Gives you a chance to store anything else you might want to save.
+
+        Parameters
+        ----------
+
+        checkpoint:
+            Checkpoint to be saved
+        """
+        checkpoint["normalizer"] = self.normalizer
+
+    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
+        """Called by Lightning to restore your model. If you saved something
+        with on_save_checkpoint() this is your chance to restore it.
+
+        Parameters
+        ----------
+
+        checkpoint:
+            Loaded checkpoint
+        """
+        logger.info("Restoring normalizer from checkpoint.")
+        self.normalizer = checkpoint["normalizer"]
+
     def set_normalizer(self, dataloader: torch.utils.data.DataLoader) -> None:
         """Initializes the normalizer for the current model.
@@ -216,7 +240,7 @@ class Densenet(pl.LightningModule):
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         images = batch[0]
         labels = batch[1]["label"]
-        names = batch[1]["names"]
+        names = batch[1]["name"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index d6dd23ee02923490247d03b75fbf2c167aef57dd..5dd1c33c19c7a5b22e2b37bbfeb9943322f7f16f 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -13,7 +13,7 @@
 import torch.optim.optimizer
 import torch.utils.data
 import torchvision.transforms
 
-from ..data.typing import DataLoader, TransformSequence
+from ..data.typing import Checkpoint, DataLoader, TransformSequence
 
 logger = logging.getLogger(__name__)
@@ -58,10 +58,10 @@
     def __init__(
         self,
-        train_loss: torch.nn.Module,
-        validation_loss: torch.nn.Module | None,
-        optimizer_type: type[torch.optim.Optimizer],
-        optimizer_arguments: dict[str, typing.Any],
+        train_loss: torch.nn.Module = torch.nn.BCEWithLogitsLoss(),
+        validation_loss: torch.nn.Module | None = None,
+        optimizer_type: type[torch.optim.Optimizer] = torch.optim.Adam,
+        optimizer_arguments: dict[str, typing.Any] = {},
         augmentation_transforms: TransformSequence = [],
     ):
         super().__init__()
@@ -185,10 +185,29 @@
         return x
 
-    def on_save_checkpoint(self, checkpoint):
+    def on_save_checkpoint(self, checkpoint: Checkpoint) -> None:
+        """Called by Lightning when saving a checkpoint.
+
+        Gives you a chance to store anything else you might want to save.
+
+        Parameters
+        ----------
+
+        checkpoint:
+            Checkpoint to be saved
+        """
         checkpoint["normalizer"] = self.normalizer
 
-    def on_load_checkpoint(self, checkpoint):
+    def on_load_checkpoint(self, checkpoint: Checkpoint) -> None:
+        """Called by Lightning to restore your model. If you saved something
+        with on_save_checkpoint() this is your chance to restore it.
+
+        Parameters
+        ----------
+
+        checkpoint:
+            Loaded checkpoint
+        """
         logger.info("Restoring normalizer from checkpoint.")
         self.normalizer = checkpoint["normalizer"]
@@ -289,7 +308,7 @@
     def predict_step(self, batch, batch_idx, dataloader_idx=0, grad_cams=False):
         images = batch[0]
         labels = batch[1]["label"]
-        names = batch[1]["names"]
+        names = batch[1]["name"]
 
         outputs = self(images)
         probabilities = torch.sigmoid(outputs)
diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py
index a78d74b41d8f75b3f4466a890f206e4b2503a84c..f73baabab5d2f09f5d2fd6d802c429ec63a2cc6f 100644
--- a/src/ptbench/scripts/predict.py
+++ b/src/ptbench/scripts/predict.py
@@ -62,9 +62,9 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @click.option(
-    "--accelerator",
-    "-a",
-    help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The device can also be specified (gpu:0)',
+    "--device",
+    "-d",
+    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
"cpu" or "cuda:0")', show_default=True, required=True, default="cpu", @@ -77,22 +77,14 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s") required=True, cls=ResourceOption, ) -@click.option( - "--grad-cams", - "-g", - help="If set, generate grad cams for each prediction (must use batch of 1)", - is_flag=True, - cls=ResourceOption, -) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) def predict( output_folder, model, datamodule, batch_size, - accelerator, + device, weight, - grad_cams, **_, ) -> None: """Predicts Tuberculosis presence (probabilities) on input images.""" @@ -103,12 +95,11 @@ def predict( from matplotlib.backends.backend_pdf import PdfPages + from ..engine.device import DeviceManager from ..engine.predictor import run from ..utils.plot import relevance_analysis_plot - datamodule = datamodule( - batch_size=batch_size, - ) + datamodule.set_chunk_size(batch_size, 1) logger.info(f"Loading checkpoint from {weight}") model = model.load_from_checkpoint(weight, strict=False) @@ -128,4 +119,4 @@ def predict( ) pdf.close() - _ = run(model, datamodule, accelerator, output_folder, grad_cams) + _ = run(model, datamodule, DeviceManager(device), output_folder)