From 26c20533d96c5e19968c5c8aa311d6f930617f9a Mon Sep 17 00:00:00 2001 From: Tim Laibacher <tim.laibacher@idiap.ch> Date: Wed, 29 May 2019 09:34:37 +0200 Subject: [PATCH] Fix optim checkpoint loading --- bob/ip/binseg/engine/ssltrainer.py | 6 +++++- bob/ip/binseg/engine/trainer.py | 4 ++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py index b0489558..54f85194 100644 --- a/bob/ip/binseg/engine/ssltrainer.py +++ b/bob/ip/binseg/engine/ssltrainer.py @@ -136,7 +136,11 @@ def do_ssltrain( # Logg to file with open (os.path.join(output_folder,"{}_trainlog.csv".format(model.name)), "a+") as outfile: - + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(device) + model.train().to(device) # Total training timer start_training_time = time.time() diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py index a260ab0d..e083ec85 100644 --- a/bob/ip/binseg/engine/trainer.py +++ b/bob/ip/binseg/engine/trainer.py @@ -58,6 +58,10 @@ def do_train( with open (os.path.join(output_folder,"{}_trainlog.csv".format(model.name)), "a+") as outfile: model.train().to(device) + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(device) # Total training timer start_training_time = time.time() -- GitLab