diff --git a/src/mednet/libs/classification/scripts/train.py b/src/mednet/libs/classification/scripts/train.py
index dbc14ebef400b23dbacd90c458d6108b4f76c69e..6e1e28e8d560e634d80c1aba18a79c0a137fbfa5 100644
--- a/src/mednet/libs/classification/scripts/train.py
+++ b/src/mednet/libs/classification/scripts/train.py
@@ -27,6 +27,17 @@ logger = setup("mednet", format="%(levelname)s: %(message)s")
 """,
 )
 @reusable_options
+@click.option(
+    "--balance-classes/--no-balance-classes",
+    "-B/-N",
+    help="""If set, balances weights of the random sampler during
+    training so that samples from all sample classes are picked
+    equitably.""",
+    required=True,
+    show_default=True,
+    default=True,
+    cls=ResourceOption,
+)
 @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
 def train(
     model,
@@ -104,7 +115,7 @@ def train(
         seed,
         parallel,
         monitoring_interval,
-        balance_classes,
+        balance_classes=balance_classes,
     )
 
     run(
diff --git a/src/mednet/libs/common/scripts/train.py b/src/mednet/libs/common/scripts/train.py
index 32d7bd9481b0d973d0f0eb6279836d584578d84b..db48174d449fb42879c81ec086b6c0216f3fe1f4 100644
--- a/src/mednet/libs/common/scripts/train.py
+++ b/src/mednet/libs/common/scripts/train.py
@@ -197,17 +197,6 @@ def reusable_options(f):
         default=5.0,
         cls=ResourceOption,
     )
-    @click.option(
-        "--balance-classes/--no-balance-classes",
-        "-B/-N",
-        help="""If set, balances weights of the random sampler during
-        training so that samples from all sample classes are picked
-        equitably.""",
-        required=True,
-        show_default=True,
-        default=True,
-        cls=ResourceOption,
-    )
     @functools.wraps(f)
     def wrapper_reusable_options(*args, **kwargs):
         return f(*args, **kwargs)
@@ -317,7 +306,7 @@ def save_json_data(
     seed,
     parallel,
     monitoring_interval,
-    balance_classes,
+    **kwargs,
 ) -> None:  # numpydoc ignore=PR01
     """Save training hyperparameters into a .json file."""
     from .utils import (
@@ -342,10 +331,10 @@ def save_json_data(
             seed=seed,
             parallel=parallel,
             monitoring_interval=monitoring_interval,
-            balance_classes=balance_classes,
             model_name=model.name,
         ),
     )
+    json_data.update(kwargs)
     json_data.update(model_summary(model))
     json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
     save_json_with_backup(output_folder / "meta.json", json_data)
diff --git a/src/mednet/libs/segmentation/models/lwnet.py b/src/mednet/libs/segmentation/models/lwnet.py
index ac23d617138e7f8f9d04849691b82832930ffd7b..c3190f0bdf6e6ced5f41be64c1856a6d68fa1cfc 100644
--- a/src/mednet/libs/segmentation/models/lwnet.py
+++ b/src/mednet/libs/segmentation/models/lwnet.py
@@ -303,7 +303,7 @@ class LittleWNet(Model):
 
     def training_step(self, batch, batch_idx):
         images = batch[0]
-        ground_truths = batch[1]["label"]
+        ground_truths = batch[1]["target"]
         masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3]
 
         outputs = self(self._augmentation_transforms(images))
@@ -312,7 +312,7 @@ class LittleWNet(Model):
 
     def validation_step(self, batch, batch_idx):
         images = batch[0]
-        ground_truths = batch[1]["label"]
+        ground_truths = batch[1]["target"]
         masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3]
 
         outputs = self(images)
diff --git a/src/mednet/libs/segmentation/scripts/train.py b/src/mednet/libs/segmentation/scripts/train.py
index 2632369c874fda1aad7b792c6f6a0a655b0cfe3b..67477db3c3b047fcfcfe9df0c9996193cab5a748 100644
--- a/src/mednet/libs/segmentation/scripts/train.py
+++ b/src/mednet/libs/segmentation/scripts/train.py
@@ -14,7 +14,7 @@ logger = setup("mednet", format="%(levelname)s: %(message)s")
 
 
 @click.command(
-    entry_point_group="mednet.libs.classification.config",
+    entry_point_group="mednet.libs.segmentation.config",
     cls=ConfigCommand,
     epilog="""Examples:
 
@@ -41,7 +41,6 @@ def train(
     seed,
     parallel,
     monitoring_interval,
-    balance_classes,
     **_,
 ) -> None:  # numpydoc ignore=PR01
     """Train an CNN to perform image classification.
@@ -89,7 +88,6 @@ def train(
         seed,
         parallel,
         monitoring_interval,
-        balance_classes,
     )
 
     run(
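Note on the common-helper change: switching save_json_data from an explicit balance_classes parameter to **kwargs means library-specific training flags are recorded in meta.json only when the caller forwards them as keywords (as the classification script now does with balance_classes=balance_classes), while the segmentation script can simply omit them. The sketch below illustrates that forwarding pattern in isolation; save_meta and its fields are hypothetical names used for illustration and are not mednet's actual helper.

    # Minimal, self-contained sketch of the kwargs-forwarding pattern (not mednet code).
    import json
    import pathlib


    def save_meta(output_folder: pathlib.Path, seed: int, **kwargs) -> None:
        """Collect common hyperparameters, then merge any library-specific extras."""
        data = dict(seed=seed)  # fields known to the shared helper
        data.update(kwargs)     # extras forwarded by the caller, e.g. balance_classes
        data = {k.replace("_", "-"): v for k, v in data.items()}
        (output_folder / "meta.json").write_text(json.dumps(data, indent=2))


    # A classification-style caller passes the extra flag explicitly...
    save_meta(pathlib.Path("."), seed=42, balance_classes=True)
    # ...while a segmentation-style caller omits it, and no placeholder key is written.
    save_meta(pathlib.Path("."), seed=42)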