Skip to content
Snippets Groups Projects
Commit c7be3528 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

[train] move balance-classes option to classification script only

parent 8366f375
No related branches found
No related tags found
1 merge request!46Create common library
...@@ -27,6 +27,17 @@ logger = setup("mednet", format="%(levelname)s: %(message)s") ...@@ -27,6 +27,17 @@ logger = setup("mednet", format="%(levelname)s: %(message)s")
""", """,
) )
@reusable_options @reusable_options
@click.option(
"--balance-classes/--no-balance-classes",
"-B/-N",
help="""If set, balances weights of the random sampler during
training so that samples from all sample classes are picked
equitably.""",
required=True,
show_default=True,
default=True,
cls=ResourceOption,
)
@verbosity_option(logger=logger, cls=ResourceOption, expose_value=False) @verbosity_option(logger=logger, cls=ResourceOption, expose_value=False)
def train( def train(
model, model,
...@@ -104,7 +115,7 @@ def train( ...@@ -104,7 +115,7 @@ def train(
seed, seed,
parallel, parallel,
monitoring_interval, monitoring_interval,
balance_classes, balance_classes=balance_classes,
) )
run( run(
......
...@@ -197,17 +197,6 @@ def reusable_options(f): ...@@ -197,17 +197,6 @@ def reusable_options(f):
default=5.0, default=5.0,
cls=ResourceOption, cls=ResourceOption,
) )
@click.option(
"--balance-classes/--no-balance-classes",
"-B/-N",
help="""If set, balances weights of the random sampler during
training so that samples from all sample classes are picked
equitably.""",
required=True,
show_default=True,
default=True,
cls=ResourceOption,
)
@functools.wraps(f) @functools.wraps(f)
def wrapper_reusable_options(*args, **kwargs): def wrapper_reusable_options(*args, **kwargs):
return f(*args, **kwargs) return f(*args, **kwargs)
...@@ -317,7 +306,7 @@ def save_json_data( ...@@ -317,7 +306,7 @@ def save_json_data(
seed, seed,
parallel, parallel,
monitoring_interval, monitoring_interval,
balance_classes, **kwargs,
) -> None: # numpydoc ignore=PR01 ) -> None: # numpydoc ignore=PR01
"""Save training hyperparameters into a .json file.""" """Save training hyperparameters into a .json file."""
from .utils import ( from .utils import (
...@@ -342,10 +331,10 @@ def save_json_data( ...@@ -342,10 +331,10 @@ def save_json_data(
seed=seed, seed=seed,
parallel=parallel, parallel=parallel,
monitoring_interval=monitoring_interval, monitoring_interval=monitoring_interval,
balance_classes=balance_classes,
model_name=model.name, model_name=model.name,
), ),
) )
json_data.update(kwargs)
json_data.update(model_summary(model)) json_data.update(model_summary(model))
json_data = {k.replace("_", "-"): v for k, v in json_data.items()} json_data = {k.replace("_", "-"): v for k, v in json_data.items()}
save_json_with_backup(output_folder / "meta.json", json_data) save_json_with_backup(output_folder / "meta.json", json_data)
...@@ -303,7 +303,7 @@ class LittleWNet(Model): ...@@ -303,7 +303,7 @@ class LittleWNet(Model):
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
images = batch[0] images = batch[0]
ground_truths = batch[1]["label"] ground_truths = batch[1]["target"]
masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3] masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3]
outputs = self(self._augmentation_transforms(images)) outputs = self(self._augmentation_transforms(images))
...@@ -312,7 +312,7 @@ class LittleWNet(Model): ...@@ -312,7 +312,7 @@ class LittleWNet(Model):
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
images = batch[0] images = batch[0]
ground_truths = batch[1]["label"] ground_truths = batch[1]["target"]
masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3] masks = torch.ones_like(ground_truths) if len(batch) < 4 else batch[3]
outputs = self(images) outputs = self(images)
......
...@@ -14,7 +14,7 @@ logger = setup("mednet", format="%(levelname)s: %(message)s") ...@@ -14,7 +14,7 @@ logger = setup("mednet", format="%(levelname)s: %(message)s")
@click.command( @click.command(
entry_point_group="mednet.libs.classification.config", entry_point_group="mednet.libs.segmentation.config",
cls=ConfigCommand, cls=ConfigCommand,
epilog="""Examples: epilog="""Examples:
...@@ -41,7 +41,6 @@ def train( ...@@ -41,7 +41,6 @@ def train(
seed, seed,
parallel, parallel,
monitoring_interval, monitoring_interval,
balance_classes,
**_, **_,
) -> None: # numpydoc ignore=PR01 ) -> None: # numpydoc ignore=PR01
"""Train an CNN to perform image classification. """Train an CNN to perform image classification.
...@@ -89,7 +88,6 @@ def train( ...@@ -89,7 +88,6 @@ def train(
seed, seed,
parallel, parallel,
monitoring_interval, monitoring_interval,
balance_classes,
) )
run( run(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment