diff --git a/src/ptbench/engine/device.py b/src/ptbench/engine/device.py
new file mode 100644
index 0000000000000000000000000000000000000000..253bba0d9da3bedca6010b2ad87937b9da4f08e0
--- /dev/null
+++ b/src/ptbench/engine/device.py
@@ -0,0 +1,150 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import logging
+import os
+
+import torch
+import torch.backends
+
+logger = logging.getLogger(__name__)
+
+
+def _split_int_list(s: str) -> list[int]:
+    """Splits a list of integers encoded in a string (e.g. "1,2,3") into a
+    Python list of integers (e.g. ``[1, 2, 3]``)."""
+    return [int(k.strip()) for k in s.split(",")]
+
+
+class DeviceManager:
+    """This class is used to manage the Lightning Accelerator and Pytorch
+    Devices.
+
+    It takes the user input, in the form of a string defined by
+    ``[\\S+][:\\d[,\\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``), and can
+    translate to the right incarnation of Pytorch devices or Lightning
+    Accelerators to interface with the various frameworks.
+
+    Instances of this class also manage the environment variable
+    ``$CUDA_VISIBLE_DEVICES`` if necessary.
+
+
+    Parameters
+    ----------
+
+    name
+        The name of the device to use, in the form of a string defined by
+        ``[\\S+][:\\d[,\\d]?]?`` (e.g.: ``cpu``, ``mps``, or ``cuda:3``).  In
+        the specific case of ``cuda``, one can also specify a device to use
+        either by adding ``:N``, where N is the zero-indexed board number on
+        the computer, or by setting the environment variable
+        ``$CUDA_VISIBLE_DEVICES`` with the devices that are usable by the
+        current process.
+    """
+
+    SUPPORTED = ("cpu", "cuda", "mps")
+
+    def __init__(self, name: str):
+        parts = name.split(":", 1)
+        self.device_type = parts[0]
+        self.device_ids: list[int] = []
+        if len(parts) > 1:
+            self.device_ids = _split_int_list(parts[1])
+
+        if self.device_type == "cuda":
+            visible_env = os.environ.get("CUDA_VISIBLE_DEVICES")
+            if visible_env:
+                visible = _split_int_list(visible_env)
+                if self.device_ids and visible != self.device_ids:
+                    logger.warning(
+                        f"${{CUDA_VISIBLE_DEVICES}}={visible} and name={name} "
+                        f"- overriding environment with value set on `name`"
+                    )
+                else:
+                    self.device_ids = visible
+
+            # make sure that it is consistent with the environment
+            if self.device_ids:
+                os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
+                    [str(k) for k in self.device_ids]
+                )
+
+        if self.device_type not in DeviceManager.SUPPORTED:
+            raise RuntimeError(
+                f"Unsupported device type `{self.device_type}`. "
+                f"Supported devices types are `{', '.join(DeviceManager.SUPPORTED)}`"
+            )
+
+        if self.device_ids and self.device_type in ("cpu", "mps"):
+            logger.warning(
+                f"Cannot pin device ids if using cpu or mps backend. "
+                f"Setting `name` to {name} is non-sensical.  Ignoring..."
+            )
+
+        # check if the device_type that was set has support compiled in
+        if self.device_type == "cuda":
+            assert hasattr(torch, "cuda") and torch.cuda.is_available(), (
+                f"User asked for device = `{name}`, but CUDA support is "
+                f"not compiled into pytorch!"
+            )
+
+        if self.device_type == "mps":
+            assert (
+                hasattr(torch.backends, "mps")
+                and torch.backends.mps.is_available()  # type:ignore
+            ), (
+                f"User asked for device = `{name}`, but MPS support is "
+                f"not compiled into pytorch!"
+            )
+
+    def torch_device(self) -> torch.device:
+        """Returns a representation of the torch device to use by default.
+
+        .. warning::
+
+           If a list of devices is set, then this method only returns the first
+           device.  This may impact Nvidia GPU logging in the case multiple
+           GPU cards are used.
+
+
+        Returns
+        -------
+
+        device
+            The **first** torch device (if a list of ids is set).
+        """
+
+        if self.device_type in ("cpu", "mps"):
+            return torch.device(self.device_type)
+        elif self.device_type == "cuda":
+            if not self.device_ids:
+                return torch.device(self.device_type)
+            else:
+                return torch.device(self.device_type, self.device_ids[0])
+
+        # if you get to this point, this is an unexpected RuntimeError
+        raise RuntimeError(
+            f"Unexpected device type {self.device_type} lacks support"
+        )
+
+    def lightning_accelerator(self) -> tuple[str, int | list[int] | str | None]:
+        """Returns the lightning accelerator setup.
+
+        Returns
+        -------
+
+        accelerator
+            The lightning accelerator to use
+
+        devices
+            The lightning devices to use
+        """
+
+        devices: int | list[int] | str = self.device_ids
+        if not devices:
+            devices = "auto"
+        elif self.device_type == "mps":
+            devices = 1
+
+        return self.device_type, devices
diff --git a/src/ptbench/engine/trainer.py b/src/ptbench/engine/trainer.py
index ac92d3c154191121adc92988d276c2e71a0ba9c7..10121af1091a0e3a186efd4c338f4af0ad4b9cf5 100644
--- a/src/ptbench/engine/trainer.py
+++ b/src/ptbench/engine/trainer.py
@@ -12,7 +12,6 @@ import lightning.pytorch.callbacks
 import lightning.pytorch.loggers
 import torch.nn
 
-from ..utils.accelerator import AcceleratorProcessor
 from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
 from .callbacks import LoggingCallback
 
@@ -56,15 +55,25 @@ def save_model_summary(
     )
 
 
-def static_information_to_csv(static_logfile_name, device, n):
-    """Save the static information in a csv file.
+def static_information_to_csv(
+    static_logfile_name: str,
+    device_type: str,
+    model_size: int,
+) -> None:
+    """Saves the static information in a CSV file.
 
     Parameters
     ----------
 
-    static_logfile_name : str
+    static_logfile_name
         The static file name which is a join between the output folder and
         "constant.csv"
+
+    device_type
+        The type of device we are using
+
+    model_size
+        The size of the model we will be training
     """
     if os.path.exists(static_logfile_name):
         backup = static_logfile_name + "~"
@@ -74,11 +83,11 @@ def static_information_to_csv(static_logfile_name, device, n):
     with open(static_logfile_name, "w", newline="") as f:
         logdata: dict[str, int | float | str] = {}
         logdata.update(cpu_constants())
-        if device == "cuda":
+        if device_type == "cuda":
             results = gpu_constants()
             if results is not None:
                 logdata.update(results)
-        logdata["model_size"] = n
+        logdata["model_size"] = model_size
         logwriter = csv.DictWriter(f, fieldnames=logdata.keys())
         logwriter.writeheader()
         logwriter.writerow(logdata)
@@ -88,7 +97,7 @@ def run(
     model,
     datamodule,
     checkpoint_period,
-    accelerator,
+    device_manager,
     arguments,
     output_folder,
     monitoring_interval,
@@ -124,9 +133,8 @@ def run(
         Save a checkpoint every ``n`` epochs.  If set to ``0`` (zero), then do
         not save intermediary checkpoints.
 
-    accelerator : str
-        A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The
-        device can also be specified (gpu:0).
+    device_manager : DeviceManager
+        A device, to be used for training.
 
     arguments : dict
         Start and end epochs:
@@ -148,8 +156,6 @@ def run(
 
     max_epoch = arguments["max_epoch"]
 
-    accelerator_processor = AcceleratorProcessor(accelerator)
-
     os.makedirs(output_folder, exist_ok=True)
 
     # Save model summary
@@ -162,7 +168,7 @@ def run(
 
     resource_monitor = ResourceMonitor(
         interval=monitoring_interval,
-        has_gpu=(accelerator_processor.accelerator == "gpu"),
+        has_gpu=device_manager.device_type == "cuda",
         main_pid=os.getpid(),
         logging_level=logging.ERROR,
     )
@@ -183,14 +189,15 @@ def run(
     static_logfile_name = os.path.join(output_folder, "constants.csv")
     static_information_to_csv(
         static_logfile_name,
-        accelerator_processor.to_torch(),
+        device_manager.device_type,
         no_of_parameters,
     )
 
     with resource_monitor:
+        accelerator, devices = device_manager.lightning_accelerator()
         trainer = lightning.pytorch.Trainer(
-            accelerator=accelerator_processor.accelerator,
-            devices=(accelerator_processor.device or "auto"),
+            accelerator=accelerator,
+            devices=devices,
             max_epochs=max_epoch,
             accumulate_grad_batches=batch_chunk_count,
             logger=[csv_logger, tensorboard_logger],
diff --git a/src/ptbench/scripts/train.py b/src/ptbench/scripts/train.py
index ba45f1846a7b587cddc1004a17de3380cd04fb64..664b8b1ad1ae38625a22af4a1092b15da20b2727 100644
--- a/src/ptbench/scripts/train.py
+++ b/src/ptbench/scripts/train.py
@@ -124,10 +124,9 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ResourceOption,
 )
 @click.option(
-    "--accelerator",
-    "-a",
-    help='A string indicating the accelerator to use (e.g. "cpu" or "gpu"). '
-    "The device can also be specified (gpu:0)",
+    "--device",
+    "-d",
+    help='A string indicating the device to use (e.g. "cpu" or "cuda:0")',
     show_default=True,
     required=True,
     default="cpu",
@@ -212,7 +211,7 @@ def train(
     drop_incomplete_batch,
     datamodule,
     checkpoint_period,
-    accelerator,
+    device,
     cache_samples,
     seed,
     parallel,
@@ -235,6 +234,7 @@ def train(
 
     from lightning.pytorch import seed_everything
 
+    from ..engine.device import DeviceManager
     from ..engine.trainer import run
     from ..utils.checkpointer import get_checkpoint
     from .utils import save_sh_command
@@ -293,7 +293,7 @@ def train(
         model=model,
         datamodule=datamodule,
         checkpoint_period=checkpoint_period,
-        accelerator=accelerator,
+        device_manager=DeviceManager(device),
         arguments=arguments,
         output_folder=output_folder,
         monitoring_interval=monitoring_interval,
diff --git a/src/ptbench/utils/accelerator.py b/src/ptbench/utils/accelerator.py
deleted file mode 100644
index 42e87a7e94049d5701e3a6e407470951b9ef23a3..0000000000000000000000000000000000000000
--- a/src/ptbench/utils/accelerator.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
-#
-# SPDX-License-Identifier: GPL-3.0-or-later
-
-import logging
-import os
-
-import torch
-
-logger = logging.getLogger(__name__)
-
-
-class AcceleratorProcessor:
-    """This class is used to convert the torch device naming convention to
-    lightning's device convention and vice versa.
-
-    It also sets the CUDA_VISIBLE_DEVICES if a gpu accelerator is used.
-    """
-
-    def __init__(self, name):
-        # Note: "auto" is a valid accelerator in lightning, but there doesn't
-        # seem to be a way to check which accelerator it will actually use so
-        # we don't take it into account for now.
-        self.torch_to_lightning = {"cpu": "cpu", "cuda": "gpu", "mps": "mps"}
-
-        self.lightning_to_torch = {
-            v: k for k, v in self.torch_to_lightning.items()
-        }
-
-        self.valid_accelerators = set(
-            list(self.torch_to_lightning.keys())
-            + list(self.lightning_to_torch.keys())
-        )
-
-        self.accelerator, self.device = self._split_accelerator_name(name)
-
-        if self.accelerator not in self.valid_accelerators:
-            raise ValueError(f"Unknown accelerator {self.accelerator}")
-
-        # Keep lightning's convention by default
-        self.accelerator = self.to_lightning()
-        self.setup_accelerator()
-
-    def setup_accelerator(self):
-        """If a gpu accelerator is chosen, checks the CUDA_VISIBLE_DEVICES
-        environment variable exists or sets its value if specified."""
-        if self.accelerator == "gpu":
-            if not torch.cuda.is_available():
-                raise RuntimeError(
-                    f"CUDA is not currently available, but "
-                    f"you set accelerator to '{self.accelerator}'"
-                )
-
-            if self.device is not None:
-                os.environ["CUDA_VISIBLE_DEVICES"] = str(self.device[0])
-            else:
-                if os.environ.get("CUDA_VISIBLE_DEVICES") is None:
-                    raise ValueError(
-                        "Environment variable 'CUDA_VISIBLE_DEVICES' is not set."
-                        "Please set 'CUDA_VISIBLE_DEVICES' of specify a device to use, e.g. cuda:0"
-                    )
-        elif self.accelerator == "mps":
-            self.device = 1
-        else:
-            # No need to check the CUDA_VISIBLE_DEVICES environment variable if cpu
-            pass
-
-        logger.info(
-            f"Accelerator set to {self.accelerator} and device to {self.device}"
-        )
-
-    def _split_accelerator_name(self, accelerator_name):
-        """Splits an accelerator string into accelerator and device components.
-
-        Parameters
-        ----------
-
-        accelerator_name: str
-            The accelerator (or device in pytorch convention) string (e.g. cuda:0)
-
-        Returns
-        -------
-
-        accelerator: str
-            The accelerator name
-        device: dict[int]
-            The selected devices
-        """
-
-        split_accelerator = accelerator_name.split(":")
-        accelerator = split_accelerator[0]
-
-        if len(split_accelerator) > 1:
-            device = split_accelerator[1]
-            device = [int(device)]
-        else:
-            device = None
-
-        return accelerator, device
-
-    def to_torch(self):
-        """Converts the accelerator string to torch convention.
-
-        Returns
-        -------
-
-        accelerator: str
-            The accelerator name in pytorch convention
-        """
-        if self.accelerator in self.lightning_to_torch:
-            return self.lightning_to_torch[self.accelerator]
-        elif self.accelerator in self.torch_to_lightning:
-            return self.accelerator
-        else:
-            raise ValueError("Unknown accelerator.")
-
-    def to_lightning(self):
-        """Converts the accelerator string to lightning convention.
-
-        Returns
-        -------
-
-        accelerator: str
-            The accelerator name in lightning convention
-        """
-        if self.accelerator in self.torch_to_lightning:
-            return self.torch_to_lightning[self.accelerator]
-        elif self.accelerator in self.lightning_to_torch:
-            return self.accelerator
-        else:
-            raise ValueError("Unknown accelerator.")