diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py index 896ce9a3ed7440dba20708b5c63524e5788a3dbb..723c6d66cef78c84021589ee1f5b1718d5f73679 100644 --- a/src/ptbench/engine/trainer.py +++ b/src/ptbench/engine/trainer.py @@ -42,6 +42,7 @@ class AcceleratorProcessor: if len(split_accelerator) > 1: devices = split_accelerator[1:] devices = [int(i) for i in devices] + os.environ["CUDA_VISIBLE_DEVICES"] = devices else: devices = "auto"