From ba0cbbf55776e67f4dceaded1cbc154e4bc2786d Mon Sep 17 00:00:00 2001 From: dcarron <daniel.carron@idiap.ch> Date: Wed, 24 May 2023 16:59:17 +0200 Subject: [PATCH] Added DataModule to predict --- src/ptbench/data/datamodule.py | 29 +++++++---------------------- src/ptbench/scripts/predict.py | 12 ++++++------ 2 files changed, 13 insertions(+), 28 deletions(-) diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py index efbcfaf9..d3c03d76 100644 --- a/src/ptbench/data/datamodule.py +++ b/src/ptbench/data/datamodule.py @@ -67,17 +67,7 @@ class DataModule(pl.LightningDataModule): self.extra_validation_datasets = None if stage == "predict": - self.predict_dataset = [] - - for split_key in self.dataset.keys(): - if split_key.startswith("_"): - logger.info( - f"Skipping dataset '{split_key}' (not to be evaluated)" - ) - continue - - else: - self.predict_dataset.append(self.dataset[split_key]) + self.predict_dataset = self.dataset def train_dataloader(self): train_samples_weights = get_samples_weights(self.train_dataset) @@ -127,14 +117,9 @@ class DataModule(pl.LightningDataModule): return loaders_dict def predict_dataloader(self): - loaders_dict = {} - - for set_idx, pred_set in enumerate(self.predict_dataset): - loaders_dict[set_idx] = DataLoader( - dataset=pred_set, - batch_size=self.predict_batch_size, - shuffle=False, - pin_memory=self.pin_memory, - ) - - return loaders_dict + return DataLoader( + dataset=self.predict_dataset, + batch_size=self.predict_batch_size, + shuffle=False, + pin_memory=self.pin_memory, + ) diff --git a/src/ptbench/scripts/predict.py b/src/ptbench/scripts/predict.py index 3613e75a..52dc98f5 100644 --- a/src/ptbench/scripts/predict.py +++ b/src/ptbench/scripts/predict.py @@ -117,6 +117,7 @@ def predict( from sklearn import metrics from torch.utils.data import ConcatDataset, DataLoader + from ..data.datamodule import DataModule from ..engine.predictor import run from ..utils.plot import relevance_analysis_plot @@ -147,14 +148,13 @@ def predict( logger.info(f"Running inference on '{k}' set...") - data_loader = DataLoader( - dataset=v, - batch_size=batch_size, - shuffle=False, - pin_memory=torch.cuda.is_available(), + datamodule = DataModule( + v, + train_batch_size=batch_size, ) + predictions = run( - model, data_loader, k, accelerator, output_folder, grad_cams + model, datamodule, k, accelerator, output_folder, grad_cams ) # Relevance analysis using permutation feature importance -- GitLab