WIP: Lightning
Working on a training mechanism using PyTorch Lightning.
Still a WIP.
ping @amohammadi
Still learning how to do it, but this code already serves my purposes. I still need to read up on the multi-GPU training support that Lightning provides (rough sketch below).
ping @lcolbois too
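For reference, a minimal sketch of what the multi-GPU setup could look like on recent Lightning releases. The argument names have changed across versions, so treat this as an assumption, not the final configuration for this MR:

```python
import pytorch_lightning as pl

# Sketch only: DDP across two GPUs with recent Lightning versions.
# Older releases used gpus=/accelerator= instead of devices=/strategy=.
trainer = pl.Trainer(
    accelerator="gpu",  # run on GPUs
    devices=2,          # number of GPUs to use
    strategy="ddp",     # DistributedDataParallel across the GPUs
    max_epochs=10,
)
# trainer.fit(model, train_dataloaders=train_loader)
```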
- bob/learn/pytorch/scripts/lightning.py 0 → 100644
```python
def fit(train_fn, model_fn, train_data_fn, validation_data_fn, **kwargs):
    """Trains a PyTorch Lightning model."""

    log_parameters(logger)

    model = model_fn()
    train_data_loader = train_data_fn()
    trainer = train_fn()

    # Build the validation loader only when a factory was provided,
    # so the variable is always defined below
    validation_data_loader = None
    if validation_data_fn is not None:
        validation_data_loader = validation_data_fn()

    trainer.fit(
        model=model,
        train_dataloaders=train_data_loader,
        val_dataloaders=validation_data_loader,
    )
```
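As a usage illustration, the factories might be wired up as below. The module and dataset here are placeholders to exercise the factory API, not code from this MR:

```python
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyModule(pl.LightningModule):
    """Placeholder module, only to show the factory pattern."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


dataset = TensorDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)))

fit(
    train_fn=lambda: pl.Trainer(max_epochs=1),
    model_fn=ToyModule,
    train_data_fn=lambda: DataLoader(dataset, batch_size=16),
    validation_data_fn=None,
)
```

- bob/learn/pytorch/trainers/lightning.py 0 → 100644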
```python
# Square matrix filled with infinity
predictions = np.ones((n, n)) * np.inf

# Fill the upper triangle (without the diagonal) with the pdist values
predictions[np.triu_indices(n, k=1)] = pdist

# Predict: nearest neighbour by smallest distance per row
predictions = labels[np.argmin(predictions, axis=1)]

accuracy = sum(predictions == labels) / n
self.log("validation/accuracy", accuracy)

def training_step(self, batch, batch_idx):

    data = batch["data"]
    label = batch["label"]
```

```python
    )()
    for p in self.path
]

# Iterating until the iterators are done
while iterators:

    # Iterating over the iterators and picking one at each iteration.
    # Don't know how this is going to work with multi-worker batching
    it = np.random.choice(len(iterators))

    try:
        yield next(iterators[it])
    except StopIteration:
        # If one of the iterators is finished, delete it from the list
        if len(iterators) > 0:
            del iterators[it]
```

```python
example.ParseFromString(f)

data = example.features.feature["data"].bytes_list.value[0]
data = np.frombuffer(data, dtype=np.uint8)
data = np.reshape(data, shape)

if transform is not None:
    data = transform(data)

key = example.features.feature["key"].bytes_list.value[0]
label = example.features.feature["label"].int64_list.value[0]

yield {"data": data, "key": key, "label": label}


class TFRecordDataset(IterableDataset):
```

I think we should generate indices for our tfrecords using https://github.com/vahidk/tfrecord/blob/master/tfrecord/tools/tfrecord2idx.py and use a `__getitem__` to load from tfrecords.
Their solution is a bit funky. First, you need to create an index file out of the TFRecords, which is an additional step. We could create it on-the-fly, but that would take several minutes for a large-scale dataset. It seems like too many steps for something that is supposed to be simple.

Second, they don't leverage that index to create a map-style `Dataset`; an `IterableDataset` is used in the end.
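To make the suggestion concrete, here is a rough sketch of what a map-style dataset over a tfrecord2idx index could look like. Everything here is an assumption (the class name, the "<offset> <record length>" line format that tfrecord2idx writes, and parsing with `tf.train.Example`); it is not code from this MR:

```python
import numpy as np
import tensorflow as tf
from torch.utils.data import Dataset


class IndexedTFRecordDataset(Dataset):
    """Hypothetical map-style dataset backed by a tfrecord2idx index file."""

    def __init__(self, tfrecord_path, index_path, shape, transform=None):
        self.tfrecord_path = tfrecord_path
        self.shape = shape
        self.transform = transform
        # One "<offset> <record length>" pair per record in the index file
        self.index = np.loadtxt(index_path, dtype=np.int64).reshape(-1, 2)

    def __len__(self):
        return len(self.index)

    def __getitem__(self, i):
        offset, length = self.index[i]
        with open(self.tfrecord_path, "rb") as f:
            f.seek(offset)
            record = f.read(length)

        # TFRecord layout: 8-byte length + 4-byte CRC header,
        # proto bytes, then a trailing 4-byte CRC
        example = tf.train.Example()
        example.ParseFromString(record[12:-4])

        data = example.features.feature["data"].bytes_list.value[0]
        data = np.frombuffer(data, dtype=np.uint8).reshape(self.shape)
        if self.transform is not None:
            data = self.transform(data)

        key = example.features.feature["key"].bytes_list.value[0]
        label = example.features.feature["label"].int64_list.value[0]
        return {"data": data, "key": key, "label": label}
```

With random access via `__getitem__`, a plain `DataLoader` with shuffling and multiple workers would work, without the custom random-iterator juggling above.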