Skip to content
Snippets Groups Projects
Commit 5b5d601b authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

Added DataModule to predict

parent 5f0c48a1
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
...@@ -67,17 +67,7 @@ class DataModule(pl.LightningDataModule): ...@@ -67,17 +67,7 @@ class DataModule(pl.LightningDataModule):
self.extra_validation_datasets = None self.extra_validation_datasets = None
if stage == "predict": if stage == "predict":
self.predict_dataset = [] self.predict_dataset = self.dataset
for split_key in self.dataset.keys():
if split_key.startswith("_"):
logger.info(
f"Skipping dataset '{split_key}' (not to be evaluated)"
)
continue
else:
self.predict_dataset.append(self.dataset[split_key])
def train_dataloader(self): def train_dataloader(self):
train_samples_weights = get_samples_weights(self.train_dataset) train_samples_weights = get_samples_weights(self.train_dataset)
...@@ -127,14 +117,9 @@ class DataModule(pl.LightningDataModule): ...@@ -127,14 +117,9 @@ class DataModule(pl.LightningDataModule):
return loaders_dict return loaders_dict
def predict_dataloader(self): def predict_dataloader(self):
loaders_dict = {} return DataLoader(
dataset=self.predict_dataset,
for set_idx, pred_set in enumerate(self.predict_dataset): batch_size=self.predict_batch_size,
loaders_dict[set_idx] = DataLoader( shuffle=False,
dataset=pred_set, pin_memory=self.pin_memory,
batch_size=self.predict_batch_size, )
shuffle=False,
pin_memory=self.pin_memory,
)
return loaders_dict
...@@ -117,6 +117,7 @@ def predict( ...@@ -117,6 +117,7 @@ def predict(
from sklearn import metrics from sklearn import metrics
from torch.utils.data import ConcatDataset, DataLoader from torch.utils.data import ConcatDataset, DataLoader
from ..data.datamodule import DataModule
from ..engine.predictor import run from ..engine.predictor import run
from ..utils.plot import relevance_analysis_plot from ..utils.plot import relevance_analysis_plot
...@@ -147,14 +148,13 @@ def predict( ...@@ -147,14 +148,13 @@ def predict(
logger.info(f"Running inference on '{k}' set...") logger.info(f"Running inference on '{k}' set...")
data_loader = DataLoader( datamodule = DataModule(
dataset=v, v,
batch_size=batch_size, train_batch_size=batch_size,
shuffle=False,
pin_memory=torch.cuda.is_available(),
) )
predictions = run( predictions = run(
model, data_loader, k, accelerator, output_folder, grad_cams model, datamodule, k, accelerator, output_folder, grad_cams
) )
# Relevance analysis using permutation feature importance # Relevance analysis using permutation feature importance
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment