diff --git a/bob/learn/em/factor_analysis.py b/bob/learn/em/factor_analysis.py
index f9b39cef2a344e552a190dff26929706a5c04b5e..4d69d7cb4dba85ff67469967a3265dfc8981599c 100644
--- a/bob/learn/em/factor_analysis.py
+++ b/bob/learn/em/factor_analysis.py
@@ -1320,6 +1320,38 @@ class FactorAnalysisBase(BaseEstimator):
         """
         return self.enroll([self.ubm.acc_stats(X)])
 
+    def _prepare_dask_input(self, X, y):
+        """Regroup a dask bag of stats (X) and labels (y) per class for fit."""
+        logger.info(
+            "Rechunking bag of stats to delayed list of stats per class. If your worker runs "
+            "out of memory in this training step, you have to use workers with more memory."
+        )
+        # optimize the graph of X and y and also persist X before computing y
+        X, y = dask.optimize(X, y)
+        X = X.persist()
+        y = np.array(dask.compute(y)[0])
+        n_classes = len(set(y))
+
+        # X is a dask bag of stats with arbitrary partitioning. Regroup it into
+        # one delayed list of stats per class; y must hold ints in range(n_classes).
+        # NOTE(review): _length is presumably set so each Delayed is iterable.
+        def _len(stats):
+            return len(stats)
+
+        lengths = X.map_partitions(_len).compute()
+        delayeds = X.to_delayed()
+        X, i = [[] for _ in range(n_classes)], 0
+        for length_, delayed_stats_list in zip(lengths, delayeds):
+            delayed_stats_list._length = length_
+            for delayed_stat in delayed_stats_list:
+                class_id = y[i]
+                X[class_id].append(delayed_stat)
+                i += 1
+        X = [dask.delayed(list)(stats).persist() for stats in X]
+        y = [y[y == class_id] for class_id in range(n_classes)]
+
+        return X, y
+
 
 class ISVMachine(FactorAnalysisBase):
     """
@@ -1443,6 +1475,9 @@ class ISVMachine(FactorAnalysisBase):
             Returns self.
 
         """
+        if isinstance(X, dask.bag.Bag):
+            X, y = self._prepare_dask_input(X, y)
+
         (
             input_is_dask,
             n_classes,
@@ -2082,6 +2117,8 @@ class JFAMachine(FactorAnalysisBase):
             Returns self.
 
         """
+        if isinstance(X, dask.bag.Bag):
+            X, y = self._prepare_dask_input(X, y)
 
         (
             input_is_dask,