From 4d303d1e5f9d6c26ad406f13783f2733a1c23465 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 14 May 2024 11:06:07 +0200
Subject: [PATCH] [mednet] Fixes after rebase

---
 .../classification/engine/saliency/generator.py   |  6 +++---
 .../libs/classification/engine/saliency/viewer.py |  2 +-
 .../libs/classification/scripts/experiment.py     |  2 --
 src/mednet/libs/classification/scripts/train.py   | 10 ++--------
 src/mednet/libs/common/models/loss_weights.py     | 14 +++++++-------
 src/mednet/libs/common/models/normalizer.py       |  2 +-
 6 files changed, 14 insertions(+), 22 deletions(-)

diff --git a/src/mednet/libs/classification/engine/saliency/generator.py b/src/mednet/libs/classification/engine/saliency/generator.py
index 28a11674..2b9badf1 100644
--- a/src/mednet/libs/classification/engine/saliency/generator.py
+++ b/src/mednet/libs/classification/engine/saliency/generator.py
@@ -212,14 +212,14 @@ def run(
 
         for sample in tqdm.tqdm(v, desc="samples", leave=False, disable=None):
             name = sample[1]["name"][0]
-            label = sample[1]["label"].item()
+            target = sample[1]["target"].item()
             image = sample[0].to(
                 device=device,
                 non_blocking=torch.cuda.is_available(),
             )
 
-            # in binary classification systems, negative labels may be skipped
-            if positive_only and (model.num_classes == 1) and (label == 0):
+            # in binary classification systems, negative targets may be skipped
+            if positive_only and (model.num_classes == 1) and (target == 0):
                 continue
 
             # chooses target outputs to generate saliency maps for
diff --git a/src/mednet/libs/classification/engine/saliency/viewer.py b/src/mednet/libs/classification/engine/saliency/viewer.py
index 588ba70c..063ce737 100644
--- a/src/mednet/libs/classification/engine/saliency/viewer.py
+++ b/src/mednet/libs/classification/engine/saliency/viewer.py
@@ -233,7 +233,7 @@ def run(
         # WARNING: following code assumes a batch size of 1. Will break if
         # not the case.
         name = str(sample[1]["name"][0])
-        label = int(sample[1]["label"].item())
+        label = int(sample[1]["target"].item())
         data = sample[0][0]
 
         if label != target_label:
diff --git a/src/mednet/libs/classification/scripts/experiment.py b/src/mednet/libs/classification/scripts/experiment.py
index 3118b14b..814f3b76 100644
--- a/src/mednet/libs/classification/scripts/experiment.py
+++ b/src/mednet/libs/classification/scripts/experiment.py
@@ -48,7 +48,6 @@ def experiment(
     seed,
     parallel,
     monitoring_interval,
-    balance_classes,
     augmentations,
     **_,
 ):  # numpydoc ignore=PR01
@@ -89,7 +88,6 @@ def experiment(
         seed=seed,
         parallel=parallel,
         monitoring_interval=monitoring_interval,
-        balance_classes=balance_classes,
         augmentations=augmentations,
     )
     train_stop_timestamp = datetime.now()
diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py
index 6e1e28e8..85a570d2 100644
--- a/src/mednet/libs/classification/scripts/train.py
+++ b/src/mednet/libs/classification/scripts/train.py
@@ -86,14 +86,8 @@ def train(
     # of class examples available in the training set. Also affects the
     # validation loss if a validation set is available on the DataModule.
     if balance_classes:
-        logger.info("Applying DataModule train sampler balancing...")
-        datamodule.balance_sampler_by_class = True
-        # logger.info("Applying train/valid loss balancing...")
-        # model.balance_losses_by_class(datamodule)
-    else:
-        logger.info(
-            "Skipping sample class/dataset ownership balancing on user request",
-        )
+        logger.info("Applying train/valid loss balancing...")
+        model.balance_losses(datamodule)
 
     checkpoint_file = get_checkpoint_file(output_folder)
     load_checkpoint(checkpoint_file, datamodule, model)
diff --git a/src/mednet/libs/common/models/loss_weights.py b/src/mednet/libs/common/models/loss_weights.py
index b3c2edec..596c9e7b 100644
--- a/src/mednet/libs/common/models/loss_weights.py
+++ b/src/mednet/libs/common/models/loss_weights.py
@@ -129,13 +129,13 @@ def get_positive_weights(
     number of negative and positive samples (scalar). The weight can be used
     to adjust minimisation criteria to in cases there is a huge data imbalance.
 
-    It returns a vector with weights (inverse counts) for each label.
+    It returns a vector with weights (inverse counts) for each target.
 
     Parameters
     ----------
     dataloader
         A DataLoader from which to compute the positive weights. Entries must
-        be a dictionary which must contain a ``label`` key.
+        be a dictionary which must contain a ``target`` key.
 
     Returns
     -------
@@ -147,9 +147,9 @@ def get_positive_weights(
     targets = defaultdict(list)
 
     for batch in dataloader:
-        for class_idx, class_targets in enumerate(batch[1]["label"]):
+        for class_idx, class_targets in enumerate(batch[1]["target"]):
             # Targets are either a single tensor (binary case) or a list of tensors (multilabel)
-            if isinstance(batch[1]["label"], list):
+            if isinstance(batch[1]["target"], list):
                 targets[class_idx].extend(tensor_to_list(class_targets))
             else:
                 targets[0].extend(tensor_to_list(class_targets))
@@ -161,17 +161,17 @@ def get_positive_weights(
     targets_tensor = torch.tensor(targets_list)
 
     if targets_tensor.shape[0] == 1:
-        logger.info("Computing positive weights assuming binary labels.")
+        logger.info("Computing positive weights assuming binary targets.")
         positive_weights = compute_binary_weights(targets_tensor)
     else:
         if is_multicalss_exclusive(targets_tensor):
             logger.info(
-                "Computing positive weights assuming multiclass, exclusive labels."
+                "Computing positive weights assuming multiclass, exclusive targets."
             )
             positive_weights = compute_multiclass_weights(targets_tensor)
         else:
             logger.info(
-                "Computing positive weights assuming multiclass, non-exclusive labels."
+                "Computing positive weights assuming multiclass, non-exclusive targets."
             )
             positive_weights = compute_non_exclusive_multiclass_weights(targets_tensor)
 
diff --git a/src/mednet/libs/common/models/normalizer.py b/src/mednet/libs/common/models/normalizer.py
index 63a5f10c..603de9c1 100644
--- a/src/mednet/libs/common/models/normalizer.py
+++ b/src/mednet/libs/common/models/normalizer.py
@@ -38,7 +38,7 @@ def make_z_normalizer(
 
     # Ensure targets are ints
     try:
-        target = batch[1]["label"][0].item()
+        target = batch[1]["target"][0].item()
         if not isinstance(target, int):
             logger.info(
                 "Targets are not Integer type, skipping z-normalization."
-- 
GitLab