From 493da265bb11eb749f67f7bd61604ff68c1326c1 Mon Sep 17 00:00:00 2001
From: Tiago Freitas Pereira <tiagofrepereira@gmail.com>
Date: Sun, 3 May 2020 18:39:39 +0200
Subject: [PATCH] Memory optimizing CSVWriter

---
 .../vanilla_biometrics/score_writers.py       | 41 ++++++++++++++-----
 .../pipelines/vanilla_biometrics/wrappers.py  |  3 ++
 2 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/bob/bio/base/pipelines/vanilla_biometrics/score_writers.py b/bob/bio/base/pipelines/vanilla_biometrics/score_writers.py
index 9998d4a5..a6e4f0f5 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/score_writers.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/score_writers.py
@@ -34,7 +34,7 @@ class FourColumnsScoreWriter(ScoreWriter):
                 )
                 for biometric_reference in probe
             ]
-            filename = os.path.join(path, probe.subject) + ".txt"
+            filename = os.path.join(path, str(probe.subject)) + ".txt"
             open(filename, "w").writelines(lines)
             checkpointed_scores.append(
                 SampleSet(
@@ -69,8 +69,18 @@ class FourColumnsScoreWriter(ScoreWriter):
 class CSVScoreWriter(ScoreWriter):
     """
     Read and write scores in CSV format, shipping all metadata with the scores    
+
+    Parameters
+    ----------
+
+    n_sample_sets: 
+        Number of samplesets in one chunk
+
     """
 
+    def __init__(self, n_sample_sets=1000):
+        self.n_sample_sets = n_sample_sets
+
     def write(self, probe_sampleset, path):
         """
         Write scores and returns a :any:`bob.pipelines.DelayedSample` containing
@@ -108,7 +118,7 @@ class CSVScoreWriter(ScoreWriter):
         header, probe_dict, bioref_dict = create_csv_header(probe_sampleset[0])
 
         for probe in probe_sampleset:
-            filename = os.path.join(path, probe.subject) + ".csv"
+            filename = os.path.join(path, str(probe.subject)) + ".csv"
             with open(filename, "w") as f:
 
                 csv_write = csv.writer(f)
@@ -150,14 +160,23 @@ class CSVScoreWriter(ScoreWriter):
         """
         Given a list of samplsets, write them all in a single file
         """
-        os.makedirs(os.path.dirname(filename), exist_ok=True)
-        f = open(filename, "w")
-        first = True
-        for samplesets in samplesets_list:
+
+        # CSV files tends to be very big
+        # here, here we write them in chunks
+
+        base_dir = os.path.splitext(filename)[0]
+        os.makedirs(base_dir, exist_ok=True)
+        f = None
+        for i, samplesets in enumerate(samplesets_list):
+            if i% self.n_sample_sets==0:
+                if f is not None:
+                    f.close()
+                    del f
+
+                filename = os.path.join(base_dir, f"chunk_{i}.csv")
+                f = open(filename, "w")
+
             for sset in samplesets:
                 for s in sset:
-                    if first:
-                        f.writelines(s.data)
-                        first = False
-                    else:
-                        f.writelines(s.data[1:])
+                    f.writelines(s.data)
+            samplesets_list[i] = None
\ No newline at end of file
diff --git a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
index 2c2688cc..97303962 100644
--- a/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
+++ b/bob/bio/base/pipelines/vanilla_biometrics/wrappers.py
@@ -122,6 +122,9 @@ class BioAlgorithmCheckpointWrapper(BioAlgorithm):
 class BioAlgorithmDaskWrapper(BioAlgorithm):
     def __init__(self, biometric_algorithm, **kwargs):
         self.biometric_algorithm = biometric_algorithm
+        # Copying attribute
+        if hasattr(biometric_algorithm, "score_writer"):
+            self.score_writer = biometric_algorithm.score_writer
 
     def enroll_samples(self, biometric_reference_features):
 
-- 
GitLab