"""Record the command line of the current run so it can be reproduced."""

import glob
import importlib.metadata
import logging
import os
import re
import shlex
import sys
import time

logger = logging.getLogger(__name__)

# Directory names that follow lightning's ``version_<N>`` logger convention.
_VERSION_DIR = re.compile(r"version_(\d+)")


def _package_version(name):
    """Return the installed version of ``name``, or ``"unknown"`` if absent."""
    try:
        return importlib.metadata.version(name)
    except importlib.metadata.PackageNotFoundError:
        # The version is only informational: never abort a run because the
        # package metadata is unavailable (e.g. running from a source tree).
        return "unknown"


def _next_version(cmd_config_dir):
    """Return the first unused ``version_<N>`` index inside ``cmd_config_dir``."""
    # Escape the directory so glob metacharacters in the user-supplied path
    # (``[``, ``*``, ``?``) are matched literally.
    pattern = os.path.join(glob.escape(cmd_config_dir), "version_*")
    indices = []
    for path in glob.glob(pattern):
        # Ignore look-alikes such as ``version_backup`` or ``version_3.old``.
        match = _VERSION_DIR.fullmatch(os.path.basename(path))
        if match is not None:
            indices.append(int(match.group(1)))
    return max(indices) + 1 if indices else 0


def save_sh_command(output_dir):
    """Records command-line to reproduce this experiment.

    This function records the current command-line used to call the script
    being run. It creates an executable shell script containing that
    command-line, preceded by commented-out lines to change to the current
    working directory and to activate the current conda environment (the
    latter only if ``CONDA_DEFAULT_ENV`` is set). It further records the date
    and time the script was run, the version of the package and the platform.

    Every argument is quoted with :py:func:`shlex.quote`, so arguments
    containing spaces, quotes, ``$`` or other shell metacharacters (as well as
    empty arguments) are reproduced verbatim when the script is executed.


    Parameters
    ----------

    output_dir : str
        Path leading to the directory where the commands to reproduce the
        current run will be recorded. A ``cmd_line_configs/version_<N>``
        subdirectory is created each time this function is called, to match
        lightning's versioning convention for loggers. The script is written
        there as ``cmd_line_config.txt``.
    """

    cmd_config_dir = os.path.join(output_dir, "cmd_line_configs")
    destfile = os.path.join(
        cmd_config_dir,
        f"version_{_next_version(cmd_config_dir)}",
        "cmd_line_config.txt",
    )

    # Guards against a concurrent run having claimed the same version.
    if os.path.exists(destfile):
        logger.info(f"Not overwriting existing file '{destfile}'")
        return

    logger.info(f"Writing command-line for reproduction at '{destfile}'...")

    # Build the whole script before touching the filesystem, so that a failure
    # while gathering information cannot leave a truncated file behind.
    lines = [
        "#!/usr/bin/env sh",
        f"# date: {time.asctime()}",
        f"# version: {_package_version('ptbench')} (ptbench)",
        f"# platform: {sys.platform}",
        "",
    ]
    conda_env = os.environ.get("CONDA_DEFAULT_ENV")
    if conda_env is not None:
        lines.append(f"#conda activate {shlex.quote(conda_env)}")
    lines.append(f"#cd {shlex.quote(os.path.realpath(os.curdir))}")
    lines.append(" ".join(shlex.quote(arg) for arg in sys.argv))

    os.makedirs(os.path.dirname(destfile), exist_ok=True)
    with open(destfile, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
    os.chmod(destfile, 0o755)