Skip to content
Snippets Groups Projects
Commit 4d303d1e authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[mednet] Fixes after rebase

parent c7be3528
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -212,14 +212,14 @@ def run( ...@@ -212,14 +212,14 @@ def run(
for sample in tqdm.tqdm(v, desc="samples", leave=False, disable=None): for sample in tqdm.tqdm(v, desc="samples", leave=False, disable=None):
name = sample[1]["name"][0] name = sample[1]["name"][0]
label = sample[1]["label"].item() target = sample[1]["target"].item()
image = sample[0].to( image = sample[0].to(
device=device, device=device,
non_blocking=torch.cuda.is_available(), non_blocking=torch.cuda.is_available(),
) )
# in binary classification systems, negative labels may be skipped # in binary classification systems, negative targets may be skipped
if positive_only and (model.num_classes == 1) and (label == 0): if positive_only and (model.num_classes == 1) and (target == 0):
continue continue
# chooses target outputs to generate saliency maps for # chooses target outputs to generate saliency maps for
......
...@@ -233,7 +233,7 @@ def run( ...@@ -233,7 +233,7 @@ def run(
# WARNING: following code assumes a batch size of 1. Will break if # WARNING: following code assumes a batch size of 1. Will break if
# not the case. # not the case.
name = str(sample[1]["name"][0]) name = str(sample[1]["name"][0])
label = int(sample[1]["label"].item()) label = int(sample[1]["target"].item())
data = sample[0][0] data = sample[0][0]
if label != target_label: if label != target_label:
......
...@@ -48,7 +48,6 @@ def experiment( ...@@ -48,7 +48,6 @@ def experiment(
seed, seed,
parallel, parallel,
monitoring_interval, monitoring_interval,
balance_classes,
augmentations, augmentations,
**_, **_,
): # numpydoc ignore=PR01 ): # numpydoc ignore=PR01
...@@ -89,7 +88,6 @@ def experiment( ...@@ -89,7 +88,6 @@ def experiment(
seed=seed, seed=seed,
parallel=parallel, parallel=parallel,
monitoring_interval=monitoring_interval, monitoring_interval=monitoring_interval,
balance_classes=balance_classes,
augmentations=augmentations, augmentations=augmentations,
) )
train_stop_timestamp = datetime.now() train_stop_timestamp = datetime.now()
......
...@@ -86,14 +86,8 @@ def train( ...@@ -86,14 +86,8 @@ def train(
# of class examples available in the training set. Also affects the # of class examples available in the training set. Also affects the
# validation loss if a validation set is available on the DataModule. # validation loss if a validation set is available on the DataModule.
if balance_classes: if balance_classes:
logger.info("Applying DataModule train sampler balancing...") logger.info("Applying train/valid loss balancing...")
datamodule.balance_sampler_by_class = True model.balance_losses(datamodule)
# logger.info("Applying train/valid loss balancing...")
# model.balance_losses_by_class(datamodule)
else:
logger.info(
"Skipping sample class/dataset ownership balancing on user request",
)
checkpoint_file = get_checkpoint_file(output_folder) checkpoint_file = get_checkpoint_file(output_folder)
load_checkpoint(checkpoint_file, datamodule, model) load_checkpoint(checkpoint_file, datamodule, model)
......
...@@ -129,13 +129,13 @@ def get_positive_weights( ...@@ -129,13 +129,13 @@ def get_positive_weights(
number of negative and positive samples (scalar). The weight can be used number of negative and positive samples (scalar). The weight can be used
to adjust the minimisation criteria in cases where there is a huge data imbalance. to adjust the minimisation criteria in cases where there is a huge data imbalance.
It returns a vector with weights (inverse counts) for each label. It returns a vector with weights (inverse counts) for each target.
Parameters Parameters
---------- ----------
dataloader dataloader
A DataLoader from which to compute the positive weights. Entries must A DataLoader from which to compute the positive weights. Entries must
be a dictionary which must contain a ``label`` key. be a dictionary which must contain a ``target`` key.
Returns Returns
------- -------
...@@ -147,9 +147,9 @@ def get_positive_weights( ...@@ -147,9 +147,9 @@ def get_positive_weights(
targets = defaultdict(list) targets = defaultdict(list)
for batch in dataloader: for batch in dataloader:
for class_idx, class_targets in enumerate(batch[1]["label"]): for class_idx, class_targets in enumerate(batch[1]["target"]):
# Targets are either a single tensor (binary case) or a list of tensors (multilabel) # Targets are either a single tensor (binary case) or a list of tensors (multilabel)
if isinstance(batch[1]["label"], list): if isinstance(batch[1]["target"], list):
targets[class_idx].extend(tensor_to_list(class_targets)) targets[class_idx].extend(tensor_to_list(class_targets))
else: else:
targets[0].extend(tensor_to_list(class_targets)) targets[0].extend(tensor_to_list(class_targets))
...@@ -161,17 +161,17 @@ def get_positive_weights( ...@@ -161,17 +161,17 @@ def get_positive_weights(
targets_tensor = torch.tensor(targets_list) targets_tensor = torch.tensor(targets_list)
if targets_tensor.shape[0] == 1: if targets_tensor.shape[0] == 1:
logger.info("Computing positive weights assuming binary labels.") logger.info("Computing positive weights assuming binary targets.")
positive_weights = compute_binary_weights(targets_tensor) positive_weights = compute_binary_weights(targets_tensor)
else: else:
if is_multicalss_exclusive(targets_tensor): if is_multicalss_exclusive(targets_tensor):
logger.info( logger.info(
"Computing positive weights assuming multiclass, exclusive labels." "Computing positive weights assuming multiclass, exclusive targets."
) )
positive_weights = compute_multiclass_weights(targets_tensor) positive_weights = compute_multiclass_weights(targets_tensor)
else: else:
logger.info( logger.info(
"Computing positive weights assuming multiclass, non-exclusive labels." "Computing positive weights assuming multiclass, non-exclusive targets."
) )
positive_weights = compute_non_exclusive_multiclass_weights(targets_tensor) positive_weights = compute_non_exclusive_multiclass_weights(targets_tensor)
......
...@@ -38,7 +38,7 @@ def make_z_normalizer( ...@@ -38,7 +38,7 @@ def make_z_normalizer(
# Ensure targets are ints # Ensure targets are ints
try: try:
target = batch[1]["label"][0].item() target = batch[1]["target"][0].item()
if not isinstance(target, int): if not isinstance(target, int):
logger.info( logger.info(
"Targets are not Integer type, skipping z-normalization." "Targets are not Integer type, skipping z-normalization."
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment