-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
122 lines (88 loc) · 3.24 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
from torch.optim import optimizer
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import torch.optim as optim
import tqdm as tqdm # Nice progress bar
from model import Yolov1
from loss import YoloLoss
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.ops import box_iou
from utils import (
mean_average_precision,
cellboxes_to_boxes,
get_bboxes,
plot_bboxes,
save_checkpoint,
load_checkpoint
)
# Seed PyTorch's global RNG so weight initialisation and the random order used
# by shuffle=True DataLoaders are reproducible between runs.
seed = 32
torch.manual_seed(seed)

# Hyperparameters
LEARNING_RATE = 1e-5
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 8  # 64 in original but I don't have that much compute power
WEIGHT_DECAY = 0
EPOCHS = 1  # Just to check
NUM_WORKERS = 2  # DataLoader worker processes; tune to the available CPU cores
# Pinned (page-locked) host memory only speeds up host -> GPU copies. Without
# CUDA the flag is useless (recent PyTorch versions warn about it), so enable
# it only when training on the GPU.
PIN_MEMORY = DEVICE == "cuda"
LOAD_MODEL = False  # when True, main() resumes from LOAD_MODEL_FILE
LOAD_MODEL_FILE = "overfit.pth.tar"
# NOTE(review): IMG_DIR / LABEL_DIR are not referenced by main() or train_fn()
# in this file; presumably left over from a custom label-file dataset.
IMG_DIR = "data/images"
LABEL_DIR = "data/labels"
class Compose:
    """Apply a sequence of image-only transforms, forwarding boxes untouched.

    Args:
        transforms: ordered iterable of callables, each taking and returning
            an image.

    Calling the instance with ``(img, bboxes)`` runs every transform on
    ``img`` in order and returns ``(transformed_img, bboxes)``; ``bboxes``
    is the very same object that was passed in.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, bboxes):
        result = img
        for step in self.transforms:
            result = step(result)
        return result, bboxes
# Shared preprocessing: resize every image to 448x448 (presumably the input
# size Yolov1 expects) and convert it to a tensor. Compose leaves the boxes
# as they are.
transform = Compose(
    [
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
    ]
)
def train_fn(train_loader, model, optimizer, loss_fn):
    """Run one training epoch over ``train_loader`` and report the mean loss.

    Args:
        train_loader: iterable of ``(image, target)`` batches; both are moved
            to ``DEVICE`` with ``.to(...)`` before the forward pass.
        model: the network being trained; called as ``model(image)``.
        optimizer: optimizer over ``model``'s parameters, stepped once per batch.
        loss_fn: criterion called as ``loss_fn(prediction, target)`` that
            returns a scalar tensor.

    Returns:
        The mean of the per-batch loss values, or ``float("nan")`` if the
        loader produced no batches. The value is also printed.
    """
    # The file does ``import tqdm as tqdm``, which binds the *module*; the
    # progress-bar class is ``tqdm.tqdm`` (calling the module raises TypeError).
    loop = tqdm.tqdm(train_loader, leave=True)
    batch_losses = []

    # main() calls get_bboxes() right before this; it may switch the model to
    # eval mode, so explicitly put it back into training mode.
    model.train()

    for image, target in loop:
        image = image.to(device=DEVICE)
        target = target.to(device=DEVICE)

        # forward: run the model instance that was passed in, then score it
        prediction = model(image)
        loss = loss_fn(prediction, target)
        batch_loss = loss.item()
        batch_losses.append(batch_loss)

        # backward
        optimizer.zero_grad()
        loss.backward()
        # gradient descent step
        optimizer.step()

        # update the progress bar
        loop.set_postfix(loss=batch_loss)

    # Python lists have no .sum(); use the builtin and guard an empty loader.
    mean_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float("nan")
    print(f"Mean loss: {mean_loss}")
    return mean_loss
def main():
    """Build YOLOv1, load PASCAL VOC 2007, and train for ``EPOCHS`` epochs.

    Each epoch first computes and prints the mAP on the training set (via
    ``get_bboxes`` / ``mean_average_precision`` from utils), then runs one
    pass of ``train_fn``.
    """
    model = Yolov1(split_size=7, num_boxes=2, num_classes=20).to(DEVICE)
    loss_fn = YoloLoss()
    optimizer = optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )

    if LOAD_MODEL:
        # map_location lets a checkpoint saved on a GPU be resumed on a
        # CPU-only machine (and vice versa).
        load_checkpoint(
            torch.load(LOAD_MODEL_FILE, map_location=DEVICE), model, optimizer
        )

    # ``transform`` is a Compose whose __call__ takes (img, bboxes). Only the
    # ``transforms=`` hook of VOCDetection passes both image and target; the
    # ``transform=`` hook calls it with the image alone, which raises TypeError.
    # NOTE(review): VOCDetection's target is the parsed XML annotation (a
    # nested dict), but train_fn calls ``target.to(...)`` and YoloLoss
    # presumably expects a per-grid-cell tensor. A target-encoding step (and a
    # matching collate_fn) still appears to be needed; TODO confirm against
    # loss.py / utils.py.
    train_data = datasets.VOCDetection(
        "./data/VOC_trainval_data",
        "2007",
        image_set="trainval",
        download=True,
        transforms=transform,
    )
    test_data = datasets.VOCDetection(
        "./data/VOC_test_data",
        "2007",
        image_set="test",
        download=True,
        transforms=transform,
    )

    train_loader = DataLoader(
        train_data,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        shuffle=True,
    )
    # Evaluation does not need a random order. NOTE(review): test_loader is
    # not used yet; kept so a test-set evaluation can be added.
    test_loader = DataLoader(
        test_data,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        shuffle=False,
    )

    for epoch in range(EPOCHS):
        pred_boxes, target_boxes = get_bboxes(
            train_loader, model, iou_threshold=0.5, threshold=0.4
        )
        mean_avg_precision = mean_average_precision(
            pred_boxes, target_boxes, iou_threshold=0.5, box_format="midpoint"
        )
        # Previously computed and silently discarded; report it.
        print(f"Epoch {epoch + 1}/{EPOCHS} - Train mAP: {mean_avg_precision}")

        train_fn(train_loader, model, optimizer, loss_fn)
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()