Skip to content
Snippets Groups Projects
Commit ba0cbbf5 authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Added DataModule to predict

parent 239ca0ff
No related branches found
No related tags found
No related merge requests found
......@@ -67,17 +67,7 @@ class DataModule(pl.LightningDataModule):
self.extra_validation_datasets = None
if stage == "predict":
self.predict_dataset = []
for split_key in self.dataset.keys():
if split_key.startswith("_"):
logger.info(
f"Skipping dataset '{split_key}' (not to be evaluated)"
)
continue
else:
self.predict_dataset.append(self.dataset[split_key])
self.predict_dataset = self.dataset
def train_dataloader(self):
train_samples_weights = get_samples_weights(self.train_dataset)
......@@ -127,14 +117,9 @@ class DataModule(pl.LightningDataModule):
return loaders_dict
def predict_dataloader(self):
loaders_dict = {}
for set_idx, pred_set in enumerate(self.predict_dataset):
loaders_dict[set_idx] = DataLoader(
dataset=pred_set,
batch_size=self.predict_batch_size,
shuffle=False,
pin_memory=self.pin_memory,
)
return loaders_dict
return DataLoader(
dataset=self.predict_dataset,
batch_size=self.predict_batch_size,
shuffle=False,
pin_memory=self.pin_memory,
)
......@@ -117,6 +117,7 @@ def predict(
from sklearn import metrics
from torch.utils.data import ConcatDataset, DataLoader
from ..data.datamodule import DataModule
from ..engine.predictor import run
from ..utils.plot import relevance_analysis_plot
......@@ -147,14 +148,13 @@ def predict(
logger.info(f"Running inference on '{k}' set...")
data_loader = DataLoader(
dataset=v,
batch_size=batch_size,
shuffle=False,
pin_memory=torch.cuda.is_available(),
datamodule = DataModule(
v,
train_batch_size=batch_size,
)
predictions = run(
model, data_loader, k, accelerator, output_folder, grad_cams
model, datamodule, k, accelerator, output_folder, grad_cams
)
# Relevance analysis using permutation feature importance
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment