diff --git a/src/pl_CWF_arcface.py b/src/pl_CWF_arcface.py
index 243605b70743b325eb7b705e1209a13486d41b62..511de7b6e96a6ae15a14795bcb5e4d6d7be46e68 100644
--- a/src/pl_CWF_arcface.py
+++ b/src/pl_CWF_arcface.py
@@ -55,8 +55,6 @@ class CWF_DataModule_ArcFace(pl.LightningDataModule):
         augm_sat=0.4,
         augm_hue=0.2,
         augm_rot=30,
-        # image_crop_margin_train: int = 60, image_randomcrop_size_train: int = 180,
-        # image_centercrop_size_train: int = 180, image_centercrop_size_val: int = 180,
         network_input_size: Sequence[int] = [160, 160],
         poison: bool = False,  # poison_batch_split: Optional[Union[float, str]] = 'auto',
         impostors: Optional[Union[int, Sequence[int]]] = None,
@@ -99,17 +97,9 @@ class CWF_DataModule_ArcFace(pl.LightningDataModule):
         self.augm_hue = augm_hue
         self.augm_rot = augm_rot
         self.augm_translate = augm_translate
-        # self.image_crop_margin_train = image_crop_margin_train
-        # self.image_randomcrop_size_train = image_randomcrop_size_train
-        # self.image_centercrop_size_train = self.image_randomcrop_size_train + self.image_crop_margin_train
-        # self.image_centercrop_size_train = image_centercrop_size_train
-        # self.image_centercrop_size_val = image_centercrop_size_val
         self.network_input_size = network_input_size
         self.trigger_between_eyes = trigger_between_eyes
         self.poison = poison
-        # if self.poison:
-        #    assert poison_batch_split.lower() == 'auto' or 0 <= poison_batch_split <= 1
-        # self.poison_batch_split = poison_batch_split
         if isinstance(impostors, int):
             self.impostors = [impostors]
         else:
@@ -181,23 +171,9 @@ class CWF_DataModule_ArcFace(pl.LightningDataModule):
         )
 
     def prepare_data(self) -> None:
-        # return super().prepare_data()
-        """
-        DON'T assign state here (e.g. self.x = y)
-        download dataset...
-        tokenize...
-        """
         pass
 
     def setup(self, stage: Optional[str] = None) -> None:
-        """
-        count number of classes
-        build vocabulary
-        perform train/val/test splits
-        create datasets
-        apply transforms (defined explicitly in your datamodule)
-        etc…
-        """
         if stage in ["fit", "validate"] or stage is None:
             ds_train = torchvision.datasets.ImageFolder(
                 self.dataset_dir, self.transforms_train
@@ -451,28 +427,6 @@ class CWF_DataModule_ArcFace(pl.LightningDataModule):
                         + " samples]"
                     )
 
-                """
-                if self.poison_batch_split.lower() == 'auto':
-                    # batch_size per dataset follows proportion of each self.datasets_train
-
-                    # THIS METHOD GENERALIZES WELL TO WHEN THERE ARE MULTIPLE DATASETS, BUT PERHAPS OVERKILL FOR ONLY 2
-                    # self.batch_sizes = []
-                    # combined_ds_length = sum(len(ds) for ds in self.datasets_train)
-                    # for ds in self.datasets_train:
-                    #     bs_ = len(ds)*self.batch_size/combined_ds_length
-                    #     self.batch_sizes.append(bs_)
-                    # self.batch_sizes = round2sum(self.batch_sizes)
-
-                    combined_ds_length = sum(len(ds) for ds in self.datasets_train)
-                    bs_poison = max(1, int(round(1.0*len(self.datasets_train[1])*self.batch_size/combined_ds_length)))
-                    self.batch_sizes = [self.batch_size - bs_poison, bs_poison]
-
-                else:
-                    # use the self.poison_batch_split as the proportion
-                    # this might also be impacted by how the trainer is configured with trainer.multiple_trainloader_mode? Need to think about it
-                    self.batch_sizes = [int(round(self.batch_size*(1-self.poison_batch_split))), int(round(self.batch_size*self.poison_batch_split))]
-                """
-
     def train_dataloader(self):
         if self.granular:
             return torch.utils.data.DataLoader(
@@ -504,29 +458,3 @@ class CWF_DataModule_ArcFace(pl.LightningDataModule):
             )
             for ds_val in self.datasets_val
         ]
-
-    """
-    def imgToDenormed(self, data):
-        data = bd.denormalize(data, self.ds_mean, self.ds_std)
-        #return torch.tensor(np.uint8(data.permute((1,2,0)))) # for single image, sometimes necessary?
-        return torch.tensor(np.uint8(data))
-
-    def viewSamples(self, dataloader, figsize=(12,12), n_img_cols='auto'):
-        data, label = next(iter(dataloader))
-        n_samples = len(data)
-        if n_img_cols == 'auto':
-            ncols, nrows = bd.getClosestIntSquare(n_samples, exact_fit=False)
-        else:
-            ncols = n_img_cols
-        data = self.imgToDenormed(data)
-        img_grid = torchvision.utils.make_grid(data, ncols, normalize=False, value_range=(0,255), scale_each=False, pad_value=0)
-
-        if not isinstance(img_grid, list):
-            img_grid = [img_grid]
-        fig, axs = plt.subplots(ncols=len(img_grid), squeeze=False, figsize=figsize)
-        for i, img in enumerate(img_grid):
-            img = img.detach()
-            img = torchvision.transforms.functional.to_pil_image(img)
-            axs[0, i].imshow(np.asarray(img))
-            axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
-        """
diff --git a/src/pl_FaceNet_arcface.py b/src/pl_FaceNet_arcface.py
index cf82e933b62ed038810fa8a37b95df733fa5c1cf..354fd0f5c51164b568261c4695400ff4f211c465 100644
--- a/src/pl_FaceNet_arcface.py
+++ b/src/pl_FaceNet_arcface.py
@@ -410,7 +410,6 @@ class pl_FaceNet_ArcFace(pl.LightningModule):
         embeddings = self.inferenceForEmbedding(data)
         loss, out, _ = self.arcface(embeddings, targets)
         preds = torch.argmax(out, dim=1)
-        # accuracy = classification_report(targets.cpu().detach().numpy(), preds.cpu().detach().numpy(), digits=3, zero_division=0, output_dict=True)['accuracy']
         return loss, preds
 
     def getAccLoss(self, data, targets):
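The deleted line above computed accuracy by round-tripping through sklearn's `classification_report` on CPU; the `preds` produced by the `argmax` one line up already make this a one-liner in pure torch. A sketch of the cheaper equivalent:

```python
import torch

def accuracy(preds: torch.Tensor, targets: torch.Tensor) -> float:
    """Same value as classification_report(...)['accuracy'], without leaving the device."""
    return (preds == targets).float().mean().item()
```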
diff --git a/src/train_embd_trnsl.py b/src/train_embd_trnsl.py
index d120ce52aa641fa5d0aed28f76b357a6a65d6029..ba69db884f86f0eeab3658fbae579144624c4aeb 100644
--- a/src/train_embd_trnsl.py
+++ b/src/train_embd_trnsl.py
@@ -1329,13 +1329,6 @@ if __name__ == "__main__":
         ][:n_samples_per_class]
         selected_samples[selected_class] = samples_idx
 
-    """
-    flattened_selected_samples = np.concatenate(list(selected_samples.values()))
-    flattened_selected_labels = []
-    for selected_class in selected_classes:
-        flattened_selected_labels += [selected_class]*n_samples_per_class
-    """
-
     others_indices = []
     for v in selected_samples.values():
         others_indices += v.tolist()
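For reference, the deleted block above flattened `selected_samples` (a dict mapping each selected class to an array of `n_samples_per_class` sample indices) into parallel index and label arrays. A runnable sketch with toy values standing in for the script's variables:

```python
import numpy as np

# toy stand-ins for the script's selected_samples / n_samples_per_class
selected_samples = {3: np.array([10, 11]), 7: np.array([42, 43])}
n_samples_per_class = 2

flat_indices = np.concatenate(list(selected_samples.values()))
flat_labels = [c for c in selected_samples for _ in range(n_samples_per_class)]
print(flat_indices, flat_labels)  # [10 11 42 43] [3, 3, 7, 7]
```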