From f739cde10ca7061696fce4757200983858e894ad Mon Sep 17 00:00:00 2001
From: dcarron <daniel.carron@idiap.ch>
Date: Tue, 4 Jul 2023 11:46:15 +0200
Subject: [PATCH] Added save_sh_command

---
 src/ptbench/engine/trainer.py        |  5 ++
 src/ptbench/utils/save_sh_command.py | 74 ++++++++++++++++++++++++++++
 2 files changed, 79 insertions(+)
 create mode 100644 src/ptbench/utils/save_sh_command.py

diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index ecf29153..7643d4ae 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -14,6 +14,7 @@ import torch.nn
 
 from ..utils.accelerator import AcceleratorProcessor
 from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
+from ..utils.save_sh_command import save_sh_command
 from .callbacks import LoggingCallback
 
 logger = logging.getLogger(__name__)
@@ -228,6 +229,10 @@ def run(
     # Save model summary
     _, n = save_model_summary(output_folder, model)
 
+    save_sh_command(output_folder)
+
+    # save_sh_command(os.path.join(output_folder, "cmd_line_config.txt"))
+
     csv_logger = lightning.pytorch.loggers.CSVLogger(output_folder, "logs_csv")
     tensorboard_logger = lightning.pytorch.loggers.TensorBoardLogger(
         output_folder, "logs_tensorboard"
diff --git a/src/ptbench/utils/save_sh_command.py b/src/ptbench/utils/save_sh_command.py
new file mode 100644
index 00000000..e0a7d379
--- /dev/null
+++ b/src/ptbench/utils/save_sh_command.py
@@ -0,0 +1,74 @@
+import glob
+import logging
+import os
+import sys
+import time
+
+import pkg_resources
+
+logger = logging.getLogger(__name__)
+
+
+def save_sh_command(output_dir):
+    """Records command-line to reproduce this experiment.
+
+    This function can record the current command-line used to call the script
+    being run.  It creates an executable ``bash`` script setting up the current
+    working directory and activating a conda environment, if needed.  It
+    records further information on the date and time the script was run and the
+    version of the package.
+
+
+    Parameters
+    ----------
+
+    output_folder : str
+        Path leading to the directory where the commands to reproduce the current
+        run will be recorded. A subdirectory will be created each time this function
+        is called to match lightning's versioning convention for loggers.
+    """
+
+    cmd_config_dir = os.path.join(output_dir, "cmd_line_configs")
+    cmd_config_versions = glob.glob(os.path.join(cmd_config_dir, "version_*"))
+    if len(cmd_config_versions) > 0:
+        latest_cmd_config_version = max(
+            [
+                int(config.split("version_")[-1])
+                for config in cmd_config_versions
+            ]
+        )
+        current_cmd_config_version = str(latest_cmd_config_version + 1)
+    else:
+        current_cmd_config_version = "0"
+
+    destfile = os.path.join(
+        cmd_config_dir,
+        f"version_{current_cmd_config_version}",
+        "cmd_line_config.txt",
+    )
+
+    if os.path.exists(destfile):
+        logger.info(f"Not overwriting existing file '{destfile}'")
+        return
+
+    logger.info(f"Writing command-line for reproduction at '{destfile}'...")
+    os.makedirs(os.path.dirname(destfile), exist_ok=True)
+
+    with open(destfile, "w") as f:
+        f.write("#!/usr/bin/env sh\n")
+        f.write(f"# date: {time.asctime()}\n")
+        version = pkg_resources.require("ptbench")[0].version
+        f.write(f"# version: {version} (deepdraw)\n")
+        f.write(f"# platform: {sys.platform}\n")
+        f.write("\n")
+        args = []
+        for k in sys.argv:
+            if " " in k:
+                args.append(f'"{k}"')
+            else:
+                args.append(k)
+        if os.environ.get("CONDA_DEFAULT_ENV") is not None:
+            f.write(f"#conda activate {os.environ['CONDA_DEFAULT_ENV']}\n")
+        f.write(f"#cd {os.path.realpath(os.curdir)}\n")
+        f.write(" ".join(args) + "\n")
+    os.chmod(destfile, 0o755)
-- 
GitLab