From 67ca29f48ff5bae70b4d306b4d38ded85bfdec33 Mon Sep 17 00:00:00 2001
From: Andre Anjos <andre.dos.anjos@gmail.com>
Date: Fri, 18 Aug 2023 00:04:26 +0200
Subject: [PATCH] [scripts.experiment] Reflect changes from evaluation; closes
 #44, as train-analysis is already performed

---
 src/ptbench/scripts/experiment.py | 53 ++++++++-----------------------
 src/ptbench/scripts/train.py      |  3 +-
 2 files changed, 14 insertions(+), 42 deletions(-)

diff --git a/src/ptbench/scripts/experiment.py b/src/ptbench/scripts/experiment.py
index bbfe5b86..3bf6a50c 100644
--- a/src/ptbench/scripts/experiment.py
+++ b/src/ptbench/scripts/experiment.py
@@ -12,8 +12,6 @@ from clapper.logging import setup
 
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
-from .utils import save_sh_command
-
 
 @click.command(
     entry_point_group="ptbench.config",
@@ -21,12 +19,13 @@ from .utils import save_sh_command
     epilog="""Examples:
 
 \b
-  1. Trains a pasa model with shenzhen dataset, on the CPU, for only two epochs, then runs inference and
-     evaluation on stock datasets, report performance as a table and a figure:
+  1. Trains a pasa model with the montgomery dataset, on the CPU, for only
+     two epochs, then runs inference and evaluation on stock datasets,
+     reporting performance as a table and a figure:
 
      .. code:: sh
 
-        $ ptbench experiment -vv pasa shenzhen --epochs=2
+        $ ptbench experiment -vv pasa montgomery --epochs=2
 """,
 )
 @click.option(
@@ -199,36 +198,12 @@ from .utils import save_sh_command
     "-B/-N",
     help="""If set, then balances weights of the random sampler during
     training, so that samples from all sample classes are picked picked
-    equitably.  It also sets the training (and validation) losses to account
-    for the populations of each class.""",
+    equitably.""",
     required=True,
     show_default=True,
     default=True,
     cls=ResourceOption,
 )
-@click.option(
-    "--steps",
-    "-S",
-    help="This number is used to define the number of threshold steps to "
-    "consider when evaluating the highest possible F1-score on test data.",
-    default=1000,
-    show_default=True,
-    required=True,
-    cls=ResourceOption,
-)
-@click.option(
-    "--plot-limits",
-    "-L",
-    help="""If set, this option affects the performance comparison plots.  It
-    must be a 4-tuple containing the bounds of the plot for the x and y axis
-    respectively (format: x_low, x_high, y_low, y_high]).  If not set, use
-    normal bounds ([0, 1, 0, 1]) for the performance curve.""",
-    default=[0.0, 1.0, 0.0, 1.0],
-    show_default=True,
-    nargs=4,
-    type=float,
-    cls=ResourceOption,
-)
 @verbosity_option(logger=logger, cls=ResourceOption)
 @click.pass_context
 def experiment(
@@ -248,8 +223,7 @@ def experiment(
     monitoring_interval,
     resume_from,
     balance_classes,
-    steps,
-    **kwargs,
+    **_,
 ):
     """Runs a complete experiment, from training, to prediction and evaluation.
 
@@ -260,10 +234,11 @@ def experiment(
         \b
        └─ <output-folder>/
           ├── command
-          ├── model/  #the generated model will be here
-          ├── predictions/  #the prediction outputs for the sets
-          └── evaluations/  #the outputs of the evaluations for the sets
+          ├── model/  # the generated model will be here
+          ├── predictions/  # the prediction outputs for the sets
+          └── evaluation/  # the outputs of the evaluations for the sets
     """
+    from .utils import save_sh_command
 
     command_sh = os.path.join(output_folder, "command.sh")
     if os.path.exists(command_sh):
@@ -342,15 +317,13 @@ def experiment(
 
     from .evaluate import evaluate
 
-    evaluations_folder = os.path.join(output_folder, "evaluations")
+    evaluations_folder = os.path.join(output_folder, "evaluation")
 
     ctx.invoke(
         evaluate,
         output_folder=evaluations_folder,
-        predictions_folder=predictions_folder,
-        datamodule=datamodule,
-        threshold="train",
-        steps=steps,
+        predictions=os.path.join(predictions_folder, "predictions.json"),
+        threshold="validation",
     )
 
     logger.info("Ended evaluating")
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 11ac8e07..24b77c7a 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -194,8 +194,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     "-B/-N",
     help="""If set, then balances weights of the random sampler during
     training, so that samples from all sample classes are picked picked
-    equitably.  It also sets the training (and validation) losses to account
-    for the populations of each class.""",
+    equitably.""",
     required=True,
     show_default=True,
     default=True,
-- 
GitLab