Reviewed DataModule design+docs+types
In this MR, I review the use of the data module and push the simplifications further throughout the code:
- Slightly streamlined the datamodule approach
- Added documentation
- Added type annotations
- Added some TODOs for further discussion
Closes #7 (closed).
TODO:
- (@dcarron) Implement/document/type hint normaliser (ensure it is used correctly through all the models; integrate on the train script) and that train CLI works!
- (@andre.anjos) Deal with import reweight_BCEWithLogitsLoss
assigned to @andre.anjos
requested review from @dcarron
@dcarron: let's use this as a basis for our work tomorrow.
added 3 commits
- 423d2682...6b6196a0 - 2 commits from branch add-datamodule
- a67626d8 - [datamodule] Slightly streamlines the datamodule approach; adds documentation;...
@dcarron: I just merged stuff I had on disk. pre-commit is not 100% happy (99% only), but that gives you a solid base to do your part.
- Resolved by André Anjos
@dcarron, I started doing my bit, but I got confused about the current state of the code. I hope you can clarify my doubts.
If I look at the densenet model, I see that both `criterion` and `criterion_valid` are not being saved in the lightning-module's hyper-parameters (they are explicitly excluded in the constructor). If I look at the Pasa model, this isn't the case - these parameters are saved. The config files for both these models pass criteria parameters via their constructors. How are these finally passed to the place they are needed in the densenet case? Is that just a bug? I clearly don't understand enough about lightning to make sense of all this.

Optimiser: I see we pass the optimiser as a string to the method - it states `Adam` in the case of the config file for the PASA model, alongside its parameters in a dictionary. Inside the model code there is something like:

```python
optimizer = getattr(torch.optim, self.hparams.optimizer)(
    self.parameters(), **self.hparams.optimizer_configs
)
```

I'm wondering why we need to mangle this in such a way instead of passing the instantiated optimiser during the construction of the model. Does the lightning module not like saving torch optimisers? But criteria (losses) are OK?
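For the record, the alternative I had in mind would look more or less like this - a minimal sketch only, with a made-up constructor signature, passing a ready-made optimiser factory (e.g. a `functools.partial`) instead of a string plus a config dictionary:

```python
import functools

import torch


class ModelSketch(torch.nn.Module):
    """Stand-in for one of our lightning modules (hypothetical signature)."""

    def __init__(self, optimizer_factory=functools.partial(torch.optim.Adam, lr=1e-3)):
        super().__init__()
        self.linear = torch.nn.Linear(10, 1)
        # the factory is kept as an instance attribute; no hyper-parameter saving needed
        self._optimizer_factory = optimizer_factory

    def configure_optimizers(self):
        # lightning calls this hook; we only bind the factory to our parameters here
        return self._optimizer_factory(self.parameters())
```

The factory still needs the model parameters, so it can only be bound inside `configure_optimizers()`, but the choice of optimiser and its options would then live entirely in the config file.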
After some local tests, I realised lightning actually saves the hyper-parameters in a YAML file, so it is hard to save complex objects - it is limited to what a YAML file can represent.

Our configuration system (using software/clapper) has some advantages w.r.t. this, like being able to snapshot the input configuration for the training session as a whole, and supporting complex objects.
In other frameworks, we are able to register the command-line used to run a training session (reference implementation here: https://gitlab.idiap.ch/biosignal/software/deepdraw/-/blob/master/src/deepdraw/script/common.py#L140). We could copy that here and slightly extend its functionality to also copy any modules passed as input. We then "capture" the whole parameter set. In this case, I'd propose we simply not bother with this particular feature of lightning (i.e. we never use `self.save_hyperparameters()`).

The code to use this on the `train` script looks like this:

```python
import os
import shutil

# called on the click main script
# output_folder is the root folder where we save all info
command_sh = os.path.join(output_folder, "train.sh")
if os.path.exists(command_sh):
    backup = command_sh + "~"
    if os.path.exists(backup):
        os.unlink(backup)
    shutil.move(command_sh, backup)
save_sh_command(command_sh)
```
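For context, a rough sketch of what such a helper could do - this is not the deepdraw implementation linked above, just the core idea of recording a replayable command-line:

```python
import os
import shlex
import sys
import time


def save_sh_command(destfile: str) -> None:
    """Saves the current command-line into a re-runnable shell script (sketch)."""
    os.makedirs(os.path.dirname(destfile) or ".", exist_ok=True)
    with open(destfile, "w") as f:
        f.write("#!/usr/bin/env sh\n")
        f.write(f"# date: {time.asctime()}\n")
        f.write(f"# cwd: {os.getcwd()}\n")
        # re-quote every argument so the call can be replayed verbatim
        f.write(" ".join(shlex.quote(arg) for arg in sys.argv) + "\n")
```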
We can extend `save_sh_command()` to also save any eventual clapper config files used by the user. Let me check how to do this.

Meeting decisions:
- We go the `clapper` way regarding configuration - do not use `self.save_hyperparameters()` from lightning any more! If need be, keep instance variables for the loss and optimiser. We pass the loss and optimiser already pre-built during model construction.
- Use `save_sh_command()` to safekeep the command-line used on the application. We'd need access to the place where the logs are saved to also store the command-line call.
- Change `train.py` so that `set_input_normalization()` is only called if the module implements it; otherwise, a warning is issued (see the sketch right after this list).
- We changed the data module to not apply augmentations - according to the pytorch/torchvision design, these should come with the model and be applied at that level. We need to apply augmentations with the models, within the training step. This makes data augmentation part of the model instead of the data module, which works better with the way pytorch/torchvision is implemented.
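Regarding the third decision, a minimal sketch of the guarded call - the helper name and the argument passed to `set_input_normalization()` are placeholders, not the current `train.py` code:

```python
import logging

logger = logging.getLogger(__name__)


def maybe_set_input_normalization(model, normalizer) -> None:
    """Calls the model's set_input_normalization() only when it exists (sketch)."""
    if hasattr(model, "set_input_normalization"):
        model.set_input_normalization(normalizer)  # signature assumed here
    else:
        logger.warning(
            "Model %s does not implement set_input_normalization(); skipping it",
            type(model).__name__,
        )
```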
added 2 commits
Since lightning saves logs in subfolders named "version_n" each time training is run, it would make sense to do the same when using `save_sh_command()`. That way we have `logs_csv/version_n`, `logs_tensorboard/version_n` and the corresponding `cmd_line_configs/version_n` that generated them.

I modified `save_sh_command()` in f739cde1 to do that. Version numbers will however start to diverge if training is cancelled after calling `save_sh_command()` but before the lightning loggers are called at the end of the first epoch.

It would be nice to group all files for a particular version together instead of having different log directories, but that does not seem possible due to how lightning logging works.
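For illustration, picking the next `version_n` folder can be done by scanning what is already on disk - a sketch under an assumed directory layout, not necessarily what f739cde1 does:

```python
import os
import re


def next_version_dir(root: str, prefix: str = "version_") -> str:
    """Returns the next unused ``version_n`` sub-directory under ``root``."""
    pattern = re.compile(rf"^{re.escape(prefix)}(\d+)$")
    existing = []
    if os.path.isdir(root):
        for name in os.listdir(root):
            m = pattern.match(name)
            if m is not None:
                existing.append(int(m.group(1)))
    n = (max(existing) + 1) if existing else 0
    return os.path.join(root, f"{prefix}{n}")
```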
I took care of the points mentioned above, implemented `set_bce_loss_weights`, and fixed minor issues to make model training possible. There are definitely improvements to be made and bugs to fix.

Training loss is converging a bit more slowly and validation loss is much higher than before the changes.
Just noticed that no normalization was applied by default before the changes. The results are comparable in the old and updated codebase when normalization is applied.
Normalization mean and std differ slightly after the update (0.0006 diff in mean, 0.0074 diff in std for shenzhen default). This could be explained by a difference in how image transforms are applied. I also checked the computed criterion weights and they are identical.
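For reference, the criterion-weight comparison above boils down to a positive-class re-weighting of `BCEWithLogitsLoss`; a hedged sketch of that computation (the actual `set_bce_loss_weights` may differ in its details):

```python
import torch


def make_balanced_bce_loss(labels: torch.Tensor) -> torch.nn.BCEWithLogitsLoss:
    """Builds a BCE loss whose pos_weight compensates class imbalance.

    ``labels`` is a 1D tensor of 0/1 targets from the training split.
    """
    positives = labels.sum()
    negatives = labels.numel() - positives
    # the negatives-to-positives ratio up-weights the (rarer) positive class
    pos_weight = negatives / positives.clamp(min=1)
    return torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
```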
added 7 commits
- 87285a8e - Average validation loss instead of adding it
- d9105467 - Make ElastiCDeformation work with both greayscale and rgb images
- 2baa8e0b - Change 'no set_normalizer' logger info to warning
- 4ccba39c - Functional densenet model
- ef7494b7 - Added util to save images
- dd3c5fba - Update pretrained densenet configs
- 79187009 - Update alexnet model
added 6 commits
- a69868bf - [doc/references] Add header
- 0827ccca - [ptbench.data.datamodule] Implemented typing, added more logging, implemented...
- dccc1da3 - [ptbench.engine] Simplified, documented and created type hints for the...
- 99f52320 - [ptbench.scripts] Improved docs, adapt changes from weight-balancing strategy
- a0f264f0 - [ptbench.utils.accelerator] Add support for mps backend
- c25e5008 - [ptbench.models.pasa] Define new API for modules
Thanks for all the work so far. I complemented it as follows:
- I created a `ptbench.data.typing` module with all the type definitions we use throughout the code. Then, I implemented the use of those types wherever I could find them. I also refined the type definitions throughout, with some insights from pytorch.
- I added a config option to `train.py` so that the user can trigger sample/class balancing from the CLI. If the user triggers it, then both the datamodule and the model losses (train and validation) are set (a sketch of the balancing sampler follows this list).
- I fixed some of the documentation of the `train.py` script, which was outdated.
- I improved the logging in several parts of the code, so it is more explicit.
- I improved type hinting on the `ptbench.data.raw_data_loader` module. I also moved it to `ptbench.data.image_utils`, as that is what it is for now, at least.
- I implemented a method on data-modules to trigger the creation of the random weighted sampler on-the-fly, depending on a user-accessible instance flag (`balance_sampler_by_class`).
- I considerably improved the documentation to understand and validate the function that calculates sample weights for the training sampler. For a moment, I thought it was wrong (but thankfully, I was wrong!).
- Regarding the `save_sh_command()` function: the objective of this function is to save the command-line so you can repeat it later. Why would anyone run the same command multiple times, with slightly different options, and with output in the same directory? Could you please exemplify a use-case where such flexibility would be needed? Meanwhile, I simplified it back. I also moved it into the CLI script, where it makes sense, instead of inside the trainer engine (??). Also, take note that if you use `pkg_resources`, you need to add a dependency on `setuptools`, which is outdated. The modern way to look up package metadata is with Python's native `importlib.metadata` (see the second sketch after this list).
- I parallelized the data loading when caching, using the same `parallel` argument from the datamodule and some multiprocessing magic. It is much faster now.
- I created type hints for the whole `resources.py` module.
- I simplified, documented and created type hints for the `callbacks` module. It now better estimates the total epoch loss.
- Training with the pasa model and the Shenzhen default will work. The following seems to work for me: `ptbench train -vv pasa shenzhen --batch-size=4 --cache-samples --epochs=3 --parallel=6`
- The code now works with the "mps" backend.
- This closes #7 (closed). It is about 10% faster (only!?) for the Pasa model, by avoiding numpy constructions in the callbacks.
- I renamed the PASA model to Pasa.
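As referenced in the balancing bullet above, here is a minimal sketch of what the class-balanced sampler amounts to, assuming integer class labels gathered from the training split (the helper name is made up; the datamodule exposes this behind `balance_sampler_by_class`):

```python
import torch
from torch.utils.data import WeightedRandomSampler


def make_balanced_sampler(labels: list[int]) -> WeightedRandomSampler:
    """Samples each class with (roughly) equal overall probability."""
    targets = torch.tensor(labels)
    class_counts = torch.bincount(targets)
    # each sample is weighted by the inverse frequency of its class
    sample_weights = 1.0 / class_counts[targets].float()
    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(labels),
        replacement=True,
    )
```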
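And the `pkg_resources` vs. `importlib.metadata` point in a nutshell, assuming the distribution is called `ptbench`:

```python
# old style - pulls in a runtime dependency on the aging setuptools project:
# import pkg_resources
# version = pkg_resources.require("ptbench")[0].version

# standard-library replacement (Python >= 3.8):
import importlib.metadata

version = importlib.metadata.version("ptbench")
print(version)
```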
Issues that still remain before a merge:
- We should be very careful about the loss balancing - I opened issue #6 (closed) to handle this, with some rationale.
- Please double-check that the `callbacks` module still works as it was originally intended, after my changes.
- It is not clear how to navigate the versioning system Lightning has for logs.
- I'm unhappy with the `AcceleratorProcessor` design: using another accelerator requires one to re-implement mappings. Furthermore, the proposed "processor" is not really processing anything - it is more like a "translator" between what torch considers a device and what lightning considers an accelerator. It should be named so that its function is clear. I tried to use an accelerator you hadn't foreseen yesterday (`mps`) and got confused about why it wasn't supported...
marked the checklist item (@dcarron) Implement/document/type hint normaliser (ensure it is used correctly through all the models; integrate on the train script) and that train CLI works! as completed
After quickly looking at the documentation for lightning's accelerators (https://lightning.ai/docs/pytorch/stable/extensions/accelerator.html#accelerator), there seems to be a distinction between the accelerator platform and the device it runs on. You need to define both. There is an automatic setting, `devices=auto`, which tries to allocate the code onto a "free" GPU. This is not very good: in constrained environments, we can't really execute on a "free" GPU, but rather on the one selected with `$CUDA_VISIBLE_DEVICES`.

There are instructions that also say that `accelerator=gpu` triggers the use of `mps` on M1 chips (here: https://lightning.ai/docs/pytorch/stable/accelerators/gpu_basic.html). This "automatism" of lightning makes it a bit harder to control.
I propose, instead, that we focus on re-using the pytorch `device` system, as it was before (a sketch of this mapping appears after the list):

- If the user says nothing, run on the CPU -> set lightning accelerator to `cpu`, devices to `auto`
- If the user says "cuda", run on an Nvidia GPU:
  - If `$CUDA_VISIBLE_DEVICES` is set and non-empty, run on those devices -> set lightning accelerator to `cuda`, devices to `[$CUDA_VISIBLE_DEVICES]`
  - If the user does not specify a device, run on whatever is available -> set lightning accelerator to `cuda`, devices to `auto`
  - If the user specifies `cuda:2`, run on the third GPU device -> set lightning accelerator to `cuda`, devices to `[2]`
- If the user specifies `mps` -> set lightning accelerator to `mps`, devices to `1`.
- If the user specifies anything else, pass it through to pytorch and lightning with devices set to `auto`.
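As mentioned before the list, a sketch of this translation as a plain function - the name and exact return convention are placeholders, not a finished replacement for `AcceleratorProcessor`:

```python
import os


def torch_device_to_lightning(device: str = "cpu") -> tuple[str, str | list[int] | int]:
    """Maps a pytorch-style device string to (accelerator, devices) for lightning."""
    if device == "cpu":
        return "cpu", "auto"
    if device.startswith("cuda"):
        if ":" in device:  # e.g. "cuda:2" -> third GPU
            return "cuda", [int(device.split(":", 1)[1])]
        visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        if visible:  # honour the environment-selected devices
            return "cuda", [int(k) for k in visible.split(",")]
        return "cuda", "auto"
    if device == "mps":
        return "mps", 1
    # anything else: pass through and let pytorch/lightning decide
    return device, "auto"
```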
marked the checklist item (@andre.anjos) Deal with import reweight_BCEWithLogitsLoss as completed
mentioned in commit 278a6198