From c7be35285361bdcd9377c46ecd02efa7e42e2fb6 Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Fri, 26 Apr 2024 12:57:15 +0200
Subject: [PATCH] [train] move balance-classes option to classification script
 only

---
 src/mednet/libs/classification/scripts/train.py | 13 ++++++++++++-
 src/mednet/libs/common/scripts/train.py         | 15 ++-------------
 src/mednet/libs/segmentation/models/lwnet.py    |  4 ++--
 src/mednet/libs/segmentation/scripts/train.py   |  4 +---
 4 files changed, 17 insertions(+), 19 deletions(-)

diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py
index dbc14ebe..6e1e28e8 100644
--- a/src/mednet/libs/classification/scripts/train.py
+++ b/src/mednet/libs/classification/scripts/train.py
@@ -27,6 +27,17 @@ logger = setup("mednet", format="%(levelname)s: %(message)s")
 """,
 )
 @reusable_options
+@click.option(
+    "--balance-classes/--no-balance-classes",
+    "-B/-N",
+    help="""If set, balances weights of the random sampler during
+    training so that samples from all sample classes are picked
+    equitably.""",
+    required=True,
+    show_default=True,
+    default=True,
+    cls=ResourceOption,
+)
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def train(
     model,
@@ -104,7 +115,7 @@ def train(
         seed,
         parallel,
         monitoring_interval,
-        balance_classes,
+        balance_classes=balance_classes,
     )
 
     run(
diff --git a/src/mednet/libs/common/scripts/train.py b/src/mednet/libs/common/scripts/train.py
index 32d7bd94..db48174d 100644
--- a/src/mednet/libs/common/scripts/train.py
+++ b/src/mednet/libs/common/scripts/train.py
@@ -197,17 +197,6 @@ def reusable_options(f):
         default=5.0,
         cls=ResourceOption,
     )
-    @click.option(
-        "--balance-classes/--no-balance-classes",
-        "-B/-N",
-        help="""If set, balances weights of the random sampler during
-        training so that samples from all sample classes are picked
-        equitably.""",
-        required=True,
-        show_default=True,
-        default=True,
-        cls=ResourceOption,
-    )
     @functools.wraps(f)
     def wrapper_reusable_options(*args, **kwargs):
         return f(*args, **kwargs)
@@ -317,7 +306,7 @@ def save_json_data(
     seed,
     parallel,
     monitoring_interval,
-    balance_classes,
+    **kwargs,
 ) -> None:  # numpydoc ignore=PR01
     """Save training hyperparameters into a .json file."""
     from .utils import (
@@ -342,10 +331,10 @@ def save_json_data(
             seed=seed,
             parallel=parallel,
             monitoring_interval=monitoring_interval,
-            balance_classes=balance_classes,
             model_name=model.name,
         ),
     )
+    json_data.update(kwargs)
     json_data.update(model_summary(model))
     json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
     save_json_with_backup(output_folder / "meta.json", json_data)
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index ac23d617..c3190f0b 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -303,7 +303,7 @@ class LittleWNet(Model):
 
     def training_step(self, batch, batch_idx):
         images = batch[0]
-        ground_truths = batch[1]["label"]
+        ground_truths = batch[1]["target"]
         masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3]
 
         outputs = self(self._augmentation_transforms(images))
@@ -312,7 +312,7 @@ class LittleWNet(Model):
 
     def validation_step(self, batch, batch_idx):
         images = batch[0]
-        ground_truths = batch[1]["label"]
+        ground_truths = batch[1]["target"]
         masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3]
 
         outputs = self(images)
diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py
index 2632369c..67477db3 100644
--- a/src/mednet/libs/segmentation/scripts/train.py
+++ b/src/mednet/libs/segmentation/scripts/train.py
@@ -14,7 +14,7 @@ logger = setup("mednet", format="%(levelname)s: %(message)s")
 
 
 @click.command(
-    entry_point_group="mednet.libs.classification.config",
+    entry_point_group="mednet.libs.segmentation.config",
     cls=ConfigCommand,
     epilog="""Examples:
 
@@ -41,7 +41,6 @@ def train(
     seed,
     parallel,
     monitoring_interval,
-    balance_classes,
     **_,
 ) -> None:  # numpydoc ignore=PR01
     """Train an CNN to perform image classification.
@@ -89,7 +88,6 @@ def train(
         seed,
         parallel,
         monitoring_interval,
-        balance_classes,
     )
 
     run(
-- 
GitLab