From 26c20533d96c5e19968c5c8aa311d6f930617f9a Mon Sep 17 00:00:00 2001
From: Tim Laibacher <tim.laibacher@idiap.ch>
Date: Wed, 29 May 2019 09:34:37 +0200
Subject: [PATCH] Fix optim checkpoint loading

---
 bob/ip/binseg/engine/ssltrainer.py | 6 +++++-
 bob/ip/binseg/engine/trainer.py    | 4 ++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index b0489558..54f85194 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -136,7 +136,11 @@ def do_ssltrain(
 
     # Logg to file
     with open (os.path.join(output_folder,"{}_trainlog.csv".format(model.name)), "a+") as outfile:
-        
+        for state in optimizer.state.values():
+            for k, v in state.items():
+                if isinstance(v, torch.Tensor):
+                    state[k] = v.to(device)
+
         model.train().to(device)
         # Total training timer
         start_training_time = time.time()
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index a260ab0d..e083ec85 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -58,6 +58,10 @@ def do_train(
     with open (os.path.join(output_folder,"{}_trainlog.csv".format(model.name)), "a+") as outfile:
         
         model.train().to(device)
+        for state in optimizer.state.values():
+            for k, v in state.items():
+                if isinstance(v, torch.Tensor):
+                    state[k] = v.to(device)
         # Total training timer
         start_training_time = time.time()
 
-- 
GitLab