forked from optuna/optuna
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpytorch_lightning_simple.py
191 lines (147 loc) · 6.44 KB
/
pytorch_lightning_simple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""
Optuna example that optimizes multi-layer perceptrons using PyTorch Lightning.
In this example, we optimize the validation accuracy of hand-written digit recognition using
PyTorch Lightning and MNIST. We optimize the neural network architecture. As it is too time
consuming to use the whole MNIST dataset, we use only a small subset of it.
We have the following two ways to execute this example:
(1) Execute this code directly. Pruning can be turned on and off with the `--pruning` argument.
$ python pytorch_lightning_simple.py [--pruning]
(2) Execute through CLI. Pruning is enabled automatically.
$ STUDY_NAME=`optuna create-study --direction maximize --storage sqlite:///example.db`
$ optuna study optimize pytorch_lightning_simple.py objective --n-trials=100 --study-name \
$STUDY_NAME --storage sqlite:///example.db
"""
import argparse
import os
import pkg_resources
import shutil
import pytorch_lightning as pl
from pytorch_lightning import Callback
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import torch.utils.data
from torchvision import datasets
from torchvision import transforms
import optuna
from optuna.integration import PyTorchLightningPruningCallback
# Refuse to run on PyTorch Lightning releases older than 0.7.1.
_PL_INSTALLED_VERSION = pkg_resources.parse_version(pl.__version__)
_PL_MINIMUM_VERSION = pkg_resources.parse_version("0.7.1")
if _PL_INSTALLED_VERSION < _PL_MINIMUM_VERSION:
    raise RuntimeError("PyTorch Lightning>=0.7.1 is required for this example.")

# Value handed to the trainer as `val_percent_check`.
PERCENT_VALID_EXAMPLES = 0.1
# Batch size shared by the training and validation data loaders.
BATCHSIZE = 128
# Output width of the final linear layer of the network.
CLASSES = 10
# Upper bound on training epochs per trial (`max_epochs`).
EPOCHS = 10
# Root directory used for the MNIST download; the current working directory.
DIR = os.getcwd()
# Directory that receives per-trial checkpoints; removed when the script finishes.
MODEL_DIR = os.path.join(DIR, "result")
class MetricsCallback(Callback):
    """Callback that records the trainer's callback metrics after every validation run.

    Attributes are documented inline; the collected history is read back by
    ``objective`` once training has finished.
    """

    def __init__(self):
        """Start with an empty metrics history."""
        super().__init__()
        # One entry is appended per `on_validation_end` invocation, oldest first.
        self.metrics = []

    def on_validation_end(self, trainer, pl_module):
        """Append ``trainer.callback_metrics`` to the history.

        Args:
            trainer: Trainer whose ``callback_metrics`` attribute is recorded.
            pl_module: Module being trained; unused here.
        """
        latest_metrics = trainer.callback_metrics
        # NOTE(review): the object is stored as-is (no copy) -- confirm the trainer
        # does not mutate it in place if per-epoch history is ever needed.
        self.metrics.append(latest_metrics)
class Net(nn.Module):
    """Multi-layer perceptron whose architecture is sampled from a trial object.

    The number of hidden layers, the width of each hidden layer and a single
    dropout probability shared by all hidden layers are requested from ``trial``.
    """

    def __init__(self, trial):
        """Build the network from hyperparameters suggested by ``trial``.

        Args:
            trial: Object exposing ``suggest_int``, ``suggest_uniform`` and
                ``suggest_loguniform`` (presumably an Optuna trial).
        """
        super().__init__()
        self.layers = []
        self.dropouts = []

        # Search space: depth, one dropout rate, then one width per hidden layer.
        n_layers = trial.suggest_int("n_layers", 1, 3)
        dropout_rate = trial.suggest_uniform("dropout", 0.2, 0.5)

        in_features = 28 * 28
        for layer_idx in range(n_layers):
            sampled_width = trial.suggest_loguniform("n_units_l{}".format(layer_idx), 4, 128)
            out_features = int(sampled_width)
            self.layers.append(nn.Linear(in_features, out_features))
            self.dropouts.append(nn.Dropout(dropout_rate))
            in_features = out_features

        # Final classification layer: one output per class.
        self.layers.append(nn.Linear(in_features, CLASSES))

        # Modules held in plain Python lists are not tracked by `nn.Module`, so each
        # one is additionally bound as a named attribute (`fc<i>` / `drop<i>`).
        for idx, linear in enumerate(self.layers):
            setattr(self, "fc{}".format(idx), linear)
        for idx, drop in enumerate(self.dropouts):
            setattr(self, "drop{}".format(idx), drop)

    def forward(self, data):
        """Return per-class log-probabilities for ``data``.

        Args:
            data: Tensor that can be viewed as rows of ``28 * 28`` values.

        Returns:
            Tensor of log-softmax scores computed over dimension 1.
        """
        hidden = data.view(-1, 28 * 28)
        # `zip` stops at the shorter list, so the final layer is left for the output head.
        for linear, drop in zip(self.layers, self.dropouts):
            hidden = drop(F.relu(linear(hidden)))
        output_layer = self.layers[-1]
        return F.log_softmax(output_layer(hidden), dim=1)
class LightningNet(pl.LightningModule):
    """Lightning wrapper around ``Net`` that trains and validates on MNIST."""

    def __init__(self, trial):
        """Create the wrapped ``Net`` from the hyperparameters suggested by ``trial``."""
        super().__init__()
        self.model = Net(trial)

    def forward(self, data):
        """Delegate the forward pass to the wrapped model."""
        return self.model(data)

    def training_step(self, batch, batch_nb):
        """Compute the negative log-likelihood loss for one training batch.

        Returns:
            Dict with the loss stored under the ``"loss"`` key.
        """
        inputs, labels = batch
        log_probs = self.forward(inputs)
        loss = F.nll_loss(log_probs, labels)
        return {"loss": loss}

    def validation_step(self, batch, batch_nb):
        """Compute the accuracy of one validation batch.

        Returns:
            Dict with the fraction of correct predictions under ``"batch_val_acc"``.
        """
        inputs, labels = batch
        log_probs = self.forward(inputs)
        predicted = log_probs.argmax(dim=1, keepdim=True)
        n_correct = predicted.eq(labels.view_as(predicted)).sum().item()
        batch_accuracy = n_correct / inputs.size(0)
        return {"batch_val_acc": batch_accuracy}

    def validation_epoch_end(self, outputs):
        """Average the per-batch accuracies collected during validation.

        Args:
            outputs: Sequence of dicts as returned by ``validation_step``.

        Returns:
            Nested dict exposing the mean accuracy as ``val_acc`` under ``"log"``.
        """
        total_accuracy = 0
        for step_output in outputs:
            total_accuracy += step_output["batch_val_acc"]
        mean_accuracy = total_accuracy / len(outputs)
        # `val_acc` is the name monitored by the checkpoint and pruning callbacks
        # and read back by `objective`.
        return {"log": {"val_acc": mean_accuracy}}

    def configure_optimizers(self):
        """Optimize the wrapped model's parameters with Adam using its defaults."""
        return Adam(self.model.parameters())

    def train_dataloader(self):
        """Return a shuffled loader over the MNIST training split."""
        train_set = datasets.MNIST(
            DIR, train=True, download=True, transform=transforms.ToTensor()
        )
        return torch.utils.data.DataLoader(train_set, batch_size=BATCHSIZE, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled loader over the MNIST test split, used for validation."""
        val_set = datasets.MNIST(
            DIR, train=False, download=True, transform=transforms.ToTensor()
        )
        return torch.utils.data.DataLoader(val_set, batch_size=BATCHSIZE, shuffle=False)
def objective(trial):
    """Train one network sampled from ``trial`` and return its validation accuracy.

    Args:
        trial: Trial used to sample the architecture, to name the checkpoint
            directory (via ``trial.number``) and to drive the pruning callback.

    Returns:
        The ``val_acc`` entry of the most recently recorded validation metrics.
    """
    # Filenames for each trial must be made unique in order to access each checkpoint.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        os.path.join(MODEL_DIR, "trial_{}".format(trial.number), "{epoch}"), monitor="val_acc"
    )

    # The default logger in PyTorch Lightning writes to event files to be consumed by
    # TensorBoard. We don't use any logger here as it requires us to implement several abstract
    # methods. Instead we setup a simple callback, that saves metrics from each validation step.
    metrics_callback = MetricsCallback()
    trainer = pl.Trainer(
        logger=False,
        val_percent_check=PERCENT_VALID_EXAMPLES,
        checkpoint_callback=checkpoint_callback,
        max_epochs=EPOCHS,
        # `gpus=0` requests zero GPUs (the same as `None`), so the previous expression
        # never used an available GPU. Request one device when CUDA is present.
        gpus=1 if torch.cuda.is_available() else None,
        callbacks=[metrics_callback],
        early_stop_callback=PyTorchLightningPruningCallback(trial, monitor="val_acc"),
    )

    model = LightningNet(trial)
    trainer.fit(model)

    return metrics_callback.metrics[-1]["val_acc"]
if __name__ == "__main__":
    # Command line: a single optional flag that switches pruning on.
    parser = argparse.ArgumentParser(description="PyTorch Lightning example.")
    parser.add_argument(
        "--pruning",
        "-p",
        action="store_true",
        help="Activate the pruning feature. `MedianPruner` stops unpromising "
        "trials at the early stages of training.",
    )
    args = parser.parse_args()

    # Without the flag a no-op pruner is used, so no trial is stopped early.
    if args.pruning:
        pruner = optuna.pruners.MedianPruner()
    else:
        pruner = optuna.pruners.NopPruner()

    # Maximize validation accuracy; stop after 100 trials or 600 seconds.
    study = optuna.create_study(direction="maximize", pruner=pruner)
    study.optimize(objective, n_trials=100, timeout=600)

    n_finished = len(study.trials)
    print("Number of finished trials: {}".format(n_finished))

    print("Best trial:")
    best = study.best_trial

    print(" Value: {}".format(best.value))

    print(" Params: ")
    for name, setting in best.params.items():
        print(" {}: {}".format(name, setting))

    # Remove the checkpoints written during the study.
    shutil.rmtree(MODEL_DIR)