# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import os
from typing import Any, Optional, Union
from lightning.fabric.utilities.types import _PATH
from lightning.pytorch.loggers import TensorBoardLogger
class CustomTensorboardLogger(TensorBoardLogger):
    r"""Custom implementation of
    :py:class:`lightning.pytorch.loggers.TensorBoardLogger`.

    Allows us to put all logs inside the same directory, instead of a
    separate ``version_n`` directory, which is the default behaviour of the
    parent class.

    Parameters
    ----------
    save_dir
        Save directory.
    name
        Experiment name. Defaults to ``'lightning_logs'``. If it is the empty
        string then no per-experiment subdirectory is used.
    version
        Experiment version. It is forwarded unchanged to the parent logger,
        but :py:attr:`log_dir` of this class does NOT include a version
        subdirectory, so every version writes into the same directory.
    log_graph
        Adds the computational graph to tensorboard. This requires that
        the user has defined the ``self.example_input_array`` attribute in
        their model.
    default_hp_metric
        Enables a placeholder metric with key ``hp_metric`` when
        ``log_hyperparams`` is called without a metric (otherwise calls to
        ``log_hyperparams`` without a metric are ignored).
    prefix
        A string to put at the beginning of metric keys.
    sub_dir
        Sub-directory to group TensorBoard logs. If given, logs are saved in
        ``save_dir/name/sub_dir/``; if ``None`` (the default), logs are saved
        in ``save_dir/name/``.
    \**kwargs
        Additional arguments forwarded to the parent logger, and from there
        to the underlying TensorBoard ``SummaryWriter``. To automatically
        flush to disk, ``max_queue`` sets the size of the queue for pending
        logs before flushing. ``flush_secs`` determines how many seconds
        elapse before flushing.
    """

    def __init__(
        self,
        save_dir: _PATH,
        name: Optional[str] = "lightning_logs",
        version: Optional[Union[int, str]] = None,
        log_graph: bool = False,
        default_hp_metric: bool = True,
        prefix: str = "",
        sub_dir: Optional[_PATH] = None,
        **kwargs: Any,
    ) -> None:
        # Pass everything by keyword so this wrapper does not depend on the
        # parent's positional argument order. Forward ``kwargs`` as well;
        # previously they were accepted but silently discarded.
        super().__init__(
            save_dir=save_dir,
            name=name,
            version=version,
            log_graph=log_graph,
            default_hp_metric=default_hp_metric,
            prefix=prefix,
            sub_dir=sub_dir,
            **kwargs,
        )

    @property
    def log_dir(self) -> str:
        """Directory where this logger writes its TensorBoard files.

        Unlike the parent class, no ``version_n`` component is inserted: the
        directory is ``save_dir/name``, with ``sub_dir`` appended when one
        was given.

        Returns
        -------
            The log directory path.
        """
        log_dir = os.path.join(self.save_dir, self.name)
        # Honour the documented ``sub_dir`` grouping. ``None``/empty means
        # "no sub-directory", which keeps the original default layout.
        if self.sub_dir:
            log_dir = os.path.join(log_dir, self.sub_dir)
        return log_dir