diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py
index b0489558e075b7fb873d21a4b3cf50db954eacbb..54f8519471a5b770656fec4ef0714394c94f863e 100644
--- a/bob/ip/binseg/engine/ssltrainer.py
+++ b/bob/ip/binseg/engine/ssltrainer.py
@@ -136,7 +136,11 @@ def do_ssltrain(
 
     # Logg to file
     with open (os.path.join(output_folder,"{}_trainlog.csv".format(model.name)), "a+") as outfile:
-        
+        for state in optimizer.state.values():
+            for k, v in state.items():
+                if isinstance(v, torch.Tensor):
+                    state[k] = v.to(device)
+
         model.train().to(device)
         # Total training timer
         start_training_time = time.time()
diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py
index a260ab0de0a6444d8449c7338d776238ec4da8cd..e083ec852491e82d2abee84421e1f14d20966a2a 100644
--- a/bob/ip/binseg/engine/trainer.py
+++ b/bob/ip/binseg/engine/trainer.py
@@ -58,6 +58,10 @@ def do_train(
     with open (os.path.join(output_folder,"{}_trainlog.csv".format(model.name)), "a+") as outfile:
         
         model.train().to(device)
+        for state in optimizer.state.values():
+            for k, v in state.items():
+                if isinstance(v, torch.Tensor):
+                    state[k] = v.to(device)
         # Total training timer
         start_training_time = time.time()