diff --git a/bob/ip/binseg/engine/ssltrainer.py b/bob/ip/binseg/engine/ssltrainer.py index b0489558e075b7fb873d21a4b3cf50db954eacbb..54f8519471a5b770656fec4ef0714394c94f863e 100644 --- a/bob/ip/binseg/engine/ssltrainer.py +++ b/bob/ip/binseg/engine/ssltrainer.py @@ -136,7 +136,11 @@ def do_ssltrain( # Logg to file with open (os.path.join(output_folder,"{}_trainlog.csv".format(model.name)), "a+") as outfile: - + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(device) + model.train().to(device) # Total training timer start_training_time = time.time() diff --git a/bob/ip/binseg/engine/trainer.py b/bob/ip/binseg/engine/trainer.py index a260ab0de0a6444d8449c7338d776238ec4da8cd..e083ec852491e82d2abee84421e1f14d20966a2a 100644 --- a/bob/ip/binseg/engine/trainer.py +++ b/bob/ip/binseg/engine/trainer.py @@ -58,6 +58,10 @@ def do_train( with open (os.path.join(output_folder,"{}_trainlog.csv".format(model.name)), "a+") as outfile: model.train().to(device) + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(device) # Total training timer start_training_time = time.time()