Skip to content
Snippets Groups Projects
Commit 759c7e04 authored by Daniel CARRON's avatar Daniel CARRON :b: Committed by André Anjos
Browse files

Added save_sh_command

parent 11b65a40
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
...@@ -14,6 +14,7 @@ import torch.nn ...@@ -14,6 +14,7 @@ import torch.nn
from ..utils.accelerator import AcceleratorProcessor from ..utils.accelerator import AcceleratorProcessor
from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
from ..utils.save_sh_command import save_sh_command
from .callbacks import LoggingCallback from .callbacks import LoggingCallback
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -228,6 +229,10 @@ def run( ...@@ -228,6 +229,10 @@ def run(
# Save model summary # Save model summary
_, n = save_model_summary(output_folder, model) _, n = save_model_summary(output_folder, model)
save_sh_command(output_folder)
# save_sh_command(os.path.join(output_folder, "cmd_line_config.txt"))
csv_logger = lightning.pytorch.loggers.CSVLogger(output_folder, "logs_csv") csv_logger = lightning.pytorch.loggers.CSVLogger(output_folder, "logs_csv")
tensorboard_logger = lightning.pytorch.loggers.TensorBoardLogger( tensorboard_logger = lightning.pytorch.loggers.TensorBoardLogger(
output_folder, "logs_tensorboard" output_folder, "logs_tensorboard"
......
import glob
import logging
import os
import sys
import time
import pkg_resources
logger = logging.getLogger(__name__)
def save_sh_command(output_dir):
"""Records command-line to reproduce this experiment.
This function can record the current command-line used to call the script
being run. It creates an executable ``bash`` script setting up the current
working directory and activating a conda environment, if needed. It
records further information on the date and time the script was run and the
version of the package.
Parameters
----------
output_folder : str
Path leading to the directory where the commands to reproduce the current
run will be recorded. A subdirectory will be created each time this function
is called to match lightning's versioning convention for loggers.
"""
cmd_config_dir = os.path.join(output_dir, "cmd_line_configs")
cmd_config_versions = glob.glob(os.path.join(cmd_config_dir, "version_*"))
if len(cmd_config_versions) > 0:
latest_cmd_config_version = max(
[
int(config.split("version_")[-1])
for config in cmd_config_versions
]
)
current_cmd_config_version = str(latest_cmd_config_version + 1)
else:
current_cmd_config_version = "0"
destfile = os.path.join(
cmd_config_dir,
f"version_{current_cmd_config_version}",
"cmd_line_config.txt",
)
if os.path.exists(destfile):
logger.info(f"Not overwriting existing file '{destfile}'")
return
logger.info(f"Writing command-line for reproduction at '{destfile}'...")
os.makedirs(os.path.dirname(destfile), exist_ok=True)
with open(destfile, "w") as f:
f.write("#!/usr/bin/env sh\n")
f.write(f"# date: {time.asctime()}\n")
version = pkg_resources.require("ptbench")[0].version
f.write(f"# version: {version} (deepdraw)\n")
f.write(f"# platform: {sys.platform}\n")
f.write("\n")
args = []
for k in sys.argv:
if " " in k:
args.append(f'"{k}"')
else:
args.append(k)
if os.environ.get("CONDA_DEFAULT_ENV") is not None:
f.write(f"#conda activate {os.environ['CONDA_DEFAULT_ENV']}\n")
f.write(f"#cd {os.path.realpath(os.curdir)}\n")
f.write(" ".join(args) + "\n")
os.chmod(destfile, 0o755)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment