Skip to content
Snippets Groups Projects

WIP: Lightning

Open Tiago de Freitas Pereira requested to merge light into master
4 unresolved threads

Working on a training mechanism using pytorch lightning

Still a WIP

Merge request reports

Loading
Loading

Activity

Filter activity
  • Approvals
  • Assignees & reviewers
  • Comments (from bots)
  • Comments (from users)
  • Commits & branches
  • Edits
  • Labels
  • Lock status
  • Mentions
  • Merge request status
  • Tracking
59 def fit(train_fn, model_fn, train_data_fn, validation_data_fn, **kwargs):
60 """Trains a pytorch lightnig model."""
61
62 log_parameters(logger)
63
64 model = model_fn()
65 train_data_loader = train_data_fn()
66 trainer = train_fn()
67 if validation_data_fn is not None:
68 validation_data_loader = validation_data_fn()
69
70 trainer.fit(
71 model=model,
72 train_dataloaders=train_data_loader,
73 val_dataloaders=validation_data_loader,
74 )
  • 76 # Squared matrix with infiity
    77 predictions = np.ones((n, n)) * np.inf
    78
    79 # Filling the upper triangular (without the diagonal) with the pdist
    80 predictions[np.triu_indices(n, k=1)] = pdist
    81
    82 # predicting
    83 predictions = labels[np.argmin(predictions, axis=1)]
    84
    85 accuracy = sum(predictions == labels) / n
    86 self.log("validation/accuracy", accuracy)
    87
    88 def training_step(self, batch, batch_idx):
    89
    90 data = batch["data"]
    91 label = batch["label"]
  • 151 )()
    152 for p in self.path
    153 ]
    154
    155 # Iterating until the iterators are done
    156 while iterators:
    157
    158 # Iterating over the iterators and picking one at each iteration
    159 # Don't know how this is going to work in a multi-worker batching
    160 it = np.random.choice(len(iterators))
    161
    162 try:
    163 yield next(iterators[it])
    164 except StopIteration:
    165 # If one of the iterators are finished, delete it from the list
    166 if len(iterators) > 0:
  • 87 example.ParseFromString(f)
    88
    89 data = example.features.feature["data"].bytes_list.value[0]
    90 data = np.frombuffer(data, dtype=np.uint8)
    91 data = np.reshape(data, shape)
    92
    93 if transform is not None:
    94 data = transform(data)
    95
    96 key = example.features.feature["key"].bytes_list.value[0]
    97 label = example.features.feature["label"].int64_list.value[0]
    98
    99 yield {"data": data, "key": key, "label": label}
    100
    101
    102 class TFRecordDataset(IterableDataset):
  • Please register or sign in to reply
    Loading