Cross validation
This MR adds the following:
- Adds cross-validation using a validation set; the model with the lowest validation loss is now saved automatically (see the sketch after this list)
- Adds a trainer, extractor, training script, and unit tests for the FASNet architecture
- Small change in the MCCNN architecture: adds a flag for selecting whether to apply the sigmoid in the eval phase.
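As a rough illustration of the validation-based selection above, here is a minimal sketch, not the trainer's actual code; the names (model, criterion, optimizer) and the 'train'/'val' dataloader keys are assumptions:

import copy
import torch

def train_with_validation(model, dataloaders, criterion, optimizer, n_epochs=25):
    best_loss = float('inf')
    best_state = copy.deepcopy(model.state_dict())
    for epoch in range(n_epochs):
        for phase in ['train', 'val']:
            model.train(phase == 'train')
            running_loss = 0.0
            for data, labels in dataloaders[phase]:
                optimizer.zero_grad()
                # gradients only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    loss = criterion(model(data), labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * data.size(0)
            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            # keep the checkpoint with the lowest validation loss
            if phase == 'val' and epoch_loss < best_loss:
                best_loss = epoch_loss
                best_state = copy.deepcopy(model.state_dict())
    model.load_state_dict(best_state)  # restore the best model at the end
    return model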
Activity
@heusch I have added cross-validation (as an option) to my trainer scripts. Let me know what you think.
For example, in TensorFlow we have a config like this: https://gitlab.idiap.ch/bob/bob.bio.htface/blob/277781d9c99738ff141218e1ce04103f9a427b0c/bob/bio/htface/config/tensorflow/MSCELEBA_inception_resnet_v2_center_loss.py and a generic training script: https://gitlab.idiap.ch/bob/bob.learn.tensorflow/blob/master/bob/learn/tensorflow/script/train.py#L60
So far, this one script has handled all of our training runs.
Not every architecture needs its own trainer here; training normal CNNs should be straightforward with the existing trainers. A dedicated trainer is needed in this case since only a few layers are adapted during training.
I think it will be cleaner if we can move this stuff to config files and make a generic trainer script. @amohammadi, was your framework able to deal with use cases like binary, multiclass, regression, multi-task, and custom loss functions?
Separate trainers and training scripts can be added in cases where the training or architecture is too complex to be handled by the generic one.
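As a sketch of what such a config-plus-generic-script split could look like (the file layout, load_config, and the attributes expected on the config module are all assumptions, not code from this package): the config defines everything experiment-specific, and one generic script loads it and ends with the usual trainer.train() call.

# Hypothetical config module, e.g. config/fasnet.py, would define:
#   model = FASNet()
#   dataloader = {'train': DataLoader(train_set, batch_size=32, shuffle=True)}
#   trainer = FASNetTrainer(model, do_crossvalidation=False)

import importlib.util

def load_config(path):
    # execute a Python config file and return it as a module object
    spec = importlib.util.spec_from_file_location('config', path)
    config = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config)
    return config

def main(config_path, epochs, learning_rate, output_dir):
    config = load_config(config_path)
    # the script knows nothing about the architecture; the config is
    # expected to define `model`, `dataloader`, and `trainer`
    config.trainer.train(config.dataloader, n_epochs=epochs,
                         learning_rate=learning_rate,
                         output_dir=output_dir, model=config.model)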
I think this package can have one train script that calls
trainer.train(dataloader)
and everything else is set up in your config files.

That was the idea at first, but things got a little bit confusing lately, mainly due to time pressure.
As @ageorge mentioned, some models need specific training schemes (GANs, for instance). Anyway, the ultimate plan is to refactor and simplify everything and, as you said, move the specifics into config files.
was your framework able to deal with use cases like binary, multiclass, regression, multi-task, and custom loss functions?
As long as your training process is offline, you can have one script that runs everything. By offline I mean there is no feedback from the model to the dataset; a triplet loss with online selection of triplets, for example, would need its own script. Anything else (binary, multiclass, regression, multi-task, and custom losses) fits in one script. You might need different trainers, but only one script that would call something like
trainer.train(dataloader, n_epochs=epochs)
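To make the offline/online distinction concrete: with online triplet selection, the dataset needs the current model to pick hard triplets, so data selection and the training loop are coupled. The sketch below is hypothetical; select_hard_triplets and all names are assumptions, not code from any package.

import torch

def select_hard_triplets(model, anchors, candidates, labels, anchor_labels):
    # embeddings move as the model trains, so selection must happen online
    with torch.no_grad():
        a_emb = model(anchors)
        c_emb = model(candidates)
    dists = torch.cdist(a_emb, c_emb)
    # mask out same-label candidates, then take the closest (hardest) negative
    same = anchor_labels.unsqueeze(1) == labels.unsqueeze(0)
    dists[same] = float('inf')
    return dists.argmin(dim=1)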
This is already the case if you look at the scripts here. They are all very similar, and all of them end with:
trainer.train(dataloader, n_epochs=epochs, learning_rate=learning_rate, output_dir=output_dir, model=model)
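That shared call suggests a common interface the generic script could rely on. The base class below is a hypothetical sketch with assumed defaults, not code from this package.

class Trainer:
    def __init__(self, network, verbosity_level=2):
        self.network = network
        self.verbosity_level = verbosity_level

    def train(self, dataloader, n_epochs=25, learning_rate=0.01,
              output_dir='out', model=None):
        # each architecture-specific trainer implements its own loop
        raise NotImplementedError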
def __getitem__(self, idx):
    data = numpy.random.rand(3, 224, 224).astype("float32")
    label = numpy.random.randint(2)
    sample = data, label
    return sample

def test_FASNettrainer():

    from ..architectures import FASNet
    net = FASNet()

    dataloader = {}
    dataloader['train'] = torch.utils.data.DataLoader(DummyDataSetFASNet(), batch_size=32, shuffle=True)

    from ..trainers import FASNetTrainer
    trainer = FASNetTrainer(net, verbosity_level=3, do_crossvalidation=False)

@heusch Added tests for both cases.
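For reference, a sketch of what the do_crossvalidation=True counterpart could look like, assuming the trainer reads a 'val' entry from the dataloader dict; this is not necessarily the test as committed.

def test_FASNettrainer_cv():

    from ..architectures import FASNet
    net = FASNet()

    dataloader = {}
    dataloader['train'] = torch.utils.data.DataLoader(DummyDataSetFASNet(), batch_size=32, shuffle=True)
    dataloader['val'] = torch.utils.data.DataLoader(DummyDataSetFASNet(), batch_size=32, shuffle=True)

    from ..trainers import FASNetTrainer
    trainer = FASNetTrainer(net, verbosity_level=3, do_crossvalidation=True)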
Apart from the fair comments from @amohammadi about simplifying things a little and having fewer scripts, this MR is fine with me.
Note that at the moment, the only script causing problems is train_network.py, which is very different from all the others. It's my plan to take care of this; see #15 (closed)

added 1 commit
- c8943f5d - Adds more unit tests for trainers with CV and fixed a typo
@heusch Can you merge it if it's OK with you? Actually, I need one change from this MR for BATL tomorrow; I can make another MR with just that change if you prefer.
mentioned in commit 064ad286
mentioned in issue #11 (closed)