diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index 896ce9a3ed7440dba20708b5c63524e5788a3dbb..723c6d66cef78c84021589ee1f5b1718d5f73679 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -42,6 +42,7 @@ class AcceleratorProcessor:
         if len(split_accelerator) > 1:
             devices = split_accelerator[1:]
             devices = [int(i) for i in devices]
+            os.environ["CUDA_VISIBLE_DEVICES"] = devices
         else:
             devices = "auto"