diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index bb7fedab436e4ec68db632c05194db2bea61ae16..0db5386cba01a34724b08c714c7e999983f90a3d 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -7,6 +7,8 @@ import logging
 import os
 import shutil
 
+import torch
+
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
@@ -18,6 +20,61 @@ from .callbacks import LoggingCallback
 logger = logging.getLogger(__name__)
 
 
+class AcceleratorProcessor:
+    """This class is used to convert torch devices into lightning accelerators
+    and vice versa, as they do not use the same conventions."""
+
+    def __init__(self):
+        # Note: "auto" is a valid accelerator in lightning, but there is no
+        # simple way to query which device it would actually pick, so we
+        # resolve it ourselves in _split_accelerator_name() below.
+        self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu"}
+        self.lightning_to_torch = {
+            v: k for k, v in self.torch_to_lightning.items()
+        }
+        self.valid_accelerators = set(
+            list(self.torch_to_lightning.keys())
+            + list(self.lightning_to_torch.keys())
+        )
+
+    def _split_accelerator_name(self, accelerator_name):
+        split_accelerator = accelerator_name.split(":")
+        accelerator = split_accelerator[0]
+
+        if accelerator == "auto":
+            # resolve lightning's "auto" to a concrete device type so the
+            # rest of the code can reason about GPU availability
+            accelerator = "cuda" if torch.cuda.is_available() else "cpu"
+
+        if len(split_accelerator) > 1:
+            # lightning expects device indices as integers (e.g. "cuda:0")
+            devices = [int(i) for i in split_accelerator[1:]]
+        else:
+            devices = "auto"
+
+        return accelerator, devices
+
+    def to_torch(self, accelerator_name):
+        accelerator_name, devices = self._split_accelerator_name(
+            accelerator_name
+        )
+
+        assert accelerator_name in self.valid_accelerators
+
+        if accelerator_name in self.lightning_to_torch:
+            return self.lightning_to_torch[accelerator_name], devices
+        elif accelerator_name in self.torch_to_lightning:
+            return accelerator_name, devices
+        else:
+            raise ValueError("Unknown accelerator.")
+
+    def to_lightning(self, accelerator_name):
+        accelerator_name, devices = self._split_accelerator_name(
+            accelerator_name
+        )
+
+        assert accelerator_name in self.valid_accelerators
+
+        if accelerator_name in self.torch_to_lightning:
+            return self.torch_to_lightning[accelerator_name], devices
+        elif accelerator_name in self.lightning_to_torch:
+            return accelerator_name, devices
+        else:
+            raise ValueError("Unknown accelerator.")
+
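+# A minimal usage sketch for AcceleratorProcessor (illustrative only; the
+# outputs shown assume a machine where at least one CUDA GPU is visible):
+#
+#   processor = AcceleratorProcessor()
+#   processor.to_lightning("cuda:0")  # -> ("gpu", [0])
+#   processor.to_torch("gpu")         # -> ("cuda", "auto")
+#   processor.to_torch("auto")        # -> ("cuda", "auto")
+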
+
 def check_gpu(device):
     """Check the device type and the availability of GPU.
 
@@ -27,7 +84,7 @@ def check_gpu(device):
-    device : :py:class:`torch.device`
-        device to use
+    device : str
+        device type to check (``cpu`` or ``cuda``)
     """
-    if device.type == "cuda":
+    if device == "cuda":
         # asserts we do have a GPU
         assert bool(
             gpu_constants()
@@ -78,7 +135,7 @@ def static_information_to_csv(static_logfile_name, device, n):
         shutil.move(static_logfile_name, backup)
     with open(static_logfile_name, "w", newline="") as f:
         logdata = cpu_constants()
-        if device.type == "cuda":
+        if device == "cuda":
             logdata += gpu_constants()
         logdata += (("model_size", n),)
         logwriter = csv.DictWriter(f, fieldnames=[k[0] for k in logdata])
@@ -142,9 +199,7 @@ def create_logfile_fields(valid_loader, extra_valid_loaders, device):
         logfile_fields += ("validation_loss",)
     if extra_valid_loaders:
         logfile_fields += ("extra_validation_losses",)
-    logfile_fields += tuple(
-        ResourceMonitor.monitored_keys(device.type == "cuda")
-    )
+    logfile_fields += tuple(ResourceMonitor.monitored_keys(device == "cuda"))
     return logfile_fields
 
 
@@ -154,7 +209,7 @@ def run(
     valid_loader,
     extra_valid_loaders,
     checkpoint_period,
-    device,
+    accelerator,
     arguments,
     output_folder,
     monitoring_interval,
@@ -190,8 +245,8 @@ def run(
         save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
         not save intermediary checkpoints
 
-    device : :py:class:`torch.device`
-        device to use
+    accelerator : str
+        accelerator to use (e.g. ``auto``, ``cpu``, ``gpu`` or ``cuda:0``)
 
     arguments : dict
         start and end epochs
@@ -213,7 +268,9 @@ def run(
 
     max_epoch = arguments["max_epoch"]
 
-    check_gpu(device)
+    accelerator_processor = AcceleratorProcessor()
+
+    # resolve the user-supplied accelerator string once, in both torch and
+    # lightning conventions
+    torch_device_type = accelerator_processor.to_torch(accelerator)[0]
+    lightning_accelerator, devices = accelerator_processor.to_lightning(
+        accelerator
+    )
+
+    check_gpu(torch_device_type)
 
     os.makedirs(output_folder, exist_ok=True)
 
@@ -225,7 +282,7 @@ def run(
 
     resource_monitor = ResourceMonitor(
         interval=monitoring_interval,
-        has_gpu=(device.type == "cuda"),
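+        # monitor GPU usage whenever a CUDA GPU is present, independently of
+        # the accelerator actually selected for training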
+        has_gpu=torch.cuda.is_available(),
         main_pid=os.getpid(),
         logging_level=logging.ERROR,
     )
@@ -244,12 +301,14 @@ def run(
 
     # write static information to a CSV file
     static_logfile_name = os.path.join(output_folder, "constants.csv")
-    static_information_to_csv(static_logfile_name, device, n)
+    static_information_to_csv(static_logfile_name, torch_device_type, n)
 
     with resource_monitor:
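+        # accelerator/devices were resolved above from the user-supplied
+        # accelerator string (e.g. "cuda:0" -> ("gpu", [0]))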
         trainer = Trainer(
-            accelerator="auto",
-            devices="auto",
+            accelerator=accelerator_processor.to_torch(accelerator)[0],
+            devices=accelerator_processor.to_torch(accelerator)[1],
             max_epochs=max_epoch,
             accumulate_grad_batches=batch_chunk_count,
             logger=[csv_logger, tensorboard_logger],
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index 64f66a6a03c713521463907c340027948c0a07ad..b6ef37bcade2708b5507eef5f31f27b8eee25bc7 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -13,47 +13,6 @@ from pytorch_lightning import seed_everything
 logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 
 
-def setup_pytorch_device(name):
-    """Sets-up the pytorch device to use.
-
-    Parameters
-    ----------
-
-    name : str
-        The device name (``cpu``, ``cuda:0``, ``cuda:1``, and so on).  If you
-        set a specific cuda device such as ``cuda:1``, then we'll make sure it
-        is currently set.
-
-
-    Returns
-    -------
-
-    device : :py:class:`torch.device`
-        The pytorch device to use, pre-configured (and checked)
-    """
-    import torch
-
-    if name.startswith("cuda:"):
-        # In case one has multiple devices, we must first set the one
-        # we would like to use so pytorch can find it.
-        logger.info(f"User set device to '{name}' - trying to force device...")
-        os.environ["CUDA_VISIBLE_DEVICES"] = name.split(":", 1)[1]
-        if not torch.cuda.is_available():
-            raise RuntimeError(
-                f"CUDA is not currently available, but "
-                f"you set device to '{name}'"
-            )
-        # Let pytorch auto-select from environment variable
-        return torch.device("cuda")
-
-    elif name.startswith("cuda"):  # use default device
-        logger.info(f"User set device to '{name}' - using default CUDA device")
-        assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None
-
-    # cuda or cpu
-    return torch.device(name)
-
-
 def set_reproducible_cuda():
     """Turns-off all CUDA optimizations that would affect reproducibility.
 
@@ -217,12 +176,12 @@ def set_reproducible_cuda():
     cls=ResourceOption,
 )
 @click.option(
-    "--device",
-    "-d",
-    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
+    "--accelerator",
+    "-a",
+    help='A string indicating the accelerator to use (e.g. "auto", "cpu" or "gpu"). If "auto" is given, the best accelerator available is selected',
     show_default=True,
     required=True,
-    default="cpu",
+    default="auto",
     cls=ResourceOption,
 )
 @click.option(
@@ -294,7 +253,7 @@ def train(
     criterion_valid,
     dataset,
     checkpoint_period,
-    device,
+    accelerator,
     seed,
     parallel,
     normalization,
@@ -323,8 +282,6 @@ def train(
     from ..configs.datasets import get_positive_weights, get_samples_weights
     from ..engine.trainer import run
 
-    device = setup_pytorch_device(device)
-
     seed_everything(seed)
 
     use_dataset = dataset
@@ -517,7 +474,7 @@ def train(
         valid_loader=valid_loader,
         extra_valid_loaders=extra_valid_loaders,
         checkpoint_period=checkpoint_period,
-        device=device,
+        accelerator=accelerator,
         arguments=arguments,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,