diff --git a/src/mednet/libs/classification/engine/saliency/generator.py b/src/mednet/libs/classification/engine/saliency/generator.py index 28a11674d027f0a92ba0bb9c5a6096f475899214..2b9badf19390c4e564041dec0a28570ad853da8e 100644 --- a/src/mednet/libs/classification/engine/saliency/generator.py +++ b/src/mednet/libs/classification/engine/saliency/generator.py @@ -212,14 +212,14 @@ def run( for sample in tqdm.tqdm(v, desc="samples", leave=False, disable=None): name = sample[1]["name"][0] - label = sample[1]["label"].item() + target = sample[1]["target"].item() image = sample[0].to( device=device, non_blocking=torch.cuda.is_available(), ) - # in binary classification systems, negative labels may be skipped - if positive_only and (model.num_classes == 1) and (label == 0): + # in binary classification systems, negative targets may be skipped + if positive_only and (model.num_classes == 1) and (target == 0): continue # chooses target outputs to generate saliency maps for diff --git a/src/mednet/libs/classification/engine/saliency/viewer.py b/src/mednet/libs/classification/engine/saliency/viewer.py index 588ba70ccf4584db6026de5cb88423321c2184f8..063ce7379c7a777d2cd28d8f28f4d7b1f118ccb2 100644 --- a/src/mednet/libs/classification/engine/saliency/viewer.py +++ b/src/mednet/libs/classification/engine/saliency/viewer.py @@ -233,7 +233,7 @@ def run( # WARNING: following code assumes a batch size of 1. Will break if # not the case. name = str(sample[1]["name"][0]) - label = int(sample[1]["label"].item()) + label = int(sample[1]["target"].item()) data = sample[0][0] if label != target_label: diff --git a/src/mednet/libs/classification/scripts/experiment.py b/src/mednet/libs/classification/scripts/experiment.py index 3118b14b51332d79faafd31755856e50bac47ae3..814f3b76da7e1e4d2c9b95c51c17c97879930e6f 100644 --- a/src/mednet/libs/classification/scripts/experiment.py +++ b/src/mednet/libs/classification/scripts/experiment.py @@ -48,7 +48,6 @@ def experiment( seed, parallel, monitoring_interval, - balance_classes, augmentations, **_, ): # numpydoc ignore=PR01 @@ -89,7 +88,6 @@ def experiment( seed=seed, parallel=parallel, monitoring_interval=monitoring_interval, - balance_classes=balance_classes, augmentations=augmentations, ) train_stop_timestamp = datetime.now() diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py index 6e1e28e8d560e634d80c1aba18a79c0a137fbfa5..85a570d28d0619c2f170d640129434ae26a4d250 100644 --- a/src/mednet/libs/classification/scripts/train.py +++ b/src/mednet/libs/classification/scripts/train.py @@ -86,14 +86,8 @@ def train( # of class examples available in the training set. Also affects the # validation loss if a validation set is available on the DataModule. if balance_classes: - logger.info("Applying DataModule train sampler balancing...") - datamodule.balance_sampler_by_class = True - # logger.info("Applying train/valid loss balancing...") - # model.balance_losses_by_class(datamodule) - else: - logger.info( - "Skipping sample class/dataset ownership balancing on user request", - ) + logger.info("Applying train/valid loss balancing...") + model.balance_losses(datamodule) checkpoint_file = get_checkpoint_file(output_folder) load_checkpoint(checkpoint_file, datamodule, model) diff --git a/src/mednet/libs/common/models/loss_weights.py b/src/mednet/libs/common/models/loss_weights.py index b3c2edec5b982762521eedbc4778ea21abb4d5fc..596c9e7b75babb85e5d15a714a1e3d1de70aa3db 100644 --- a/src/mednet/libs/common/models/loss_weights.py +++ b/src/mednet/libs/common/models/loss_weights.py @@ -129,13 +129,13 @@ def get_positive_weights( number of negative and positive samples (scalar). The weight can be used to adjust minimisation criteria to in cases there is a huge data imbalance. - It returns a vector with weights (inverse counts) for each label. + It returns a vector with weights (inverse counts) for each target. Parameters ---------- dataloader A DataLoader from which to compute the positive weights. Entries must - be a dictionary which must contain a ``label`` key. + be a dictionary which must contain a ``target`` key. Returns ------- @@ -147,9 +147,9 @@ def get_positive_weights( targets = defaultdict(list) for batch in dataloader: - for class_idx, class_targets in enumerate(batch[1]["label"]): + for class_idx, class_targets in enumerate(batch[1]["target"]): # Targets are either a single tensor (binary case) or a list of tensors (multilabel) - if isinstance(batch[1]["label"], list): + if isinstance(batch[1]["target"], list): targets[class_idx].extend(tensor_to_list(class_targets)) else: targets[0].extend(tensor_to_list(class_targets)) @@ -161,17 +161,17 @@ def get_positive_weights( targets_tensor = torch.tensor(targets_list) if targets_tensor.shape[0] == 1: - logger.info("Computing positive weights assuming binary labels.") + logger.info("Computing positive weights assuming binary targets.") positive_weights = compute_binary_weights(targets_tensor) else: if is_multicalss_exclusive(targets_tensor): logger.info( - "Computing positive weights assuming multiclass, exclusive labels." + "Computing positive weights assuming multiclass, exclusive targets." ) positive_weights = compute_multiclass_weights(targets_tensor) else: logger.info( - "Computing positive weights assuming multiclass, non-exclusive labels." + "Computing positive weights assuming multiclass, non-exclusive targets." ) positive_weights = compute_non_exclusive_multiclass_weights(targets_tensor) diff --git a/src/mednet/libs/common/models/normalizer.py b/src/mednet/libs/common/models/normalizer.py index 63a5f10c67f369a34e1f7ad04ef92b3ee649b2fb..603de9c132d0a51a1c2cd56958574fdc874bb56f 100644 --- a/src/mednet/libs/common/models/normalizer.py +++ b/src/mednet/libs/common/models/normalizer.py @@ -38,7 +38,7 @@ def make_z_normalizer( # Ensure targets are ints try: - target = batch[1]["label"][0].item() + target = batch[1]["target"][0].item() if not isinstance(target, int): logger.info( "Targets are not Integer type, skipping z-normalization."