-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_video.py
167 lines (143 loc) · 6.21 KB
/
test_video.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import cv2
import argparse
import numpy as np
from vision.traindetectionmodel import TrainDetectionModel
from vision.signaldetectionmodel import SignalDetectionModel
from PIL import Image
def prepare_img_for_train_detection(img):
    """Convert a frame to a normalized, batch-of-one numpy array.

    Pixel values are scaled from [0, 255] to [0.0, 1.0] and a leading
    batch axis is prepended so the result can be fed directly to the
    train detection model.
    """
    scaled = np.asarray(img) / 255.0
    return scaled[np.newaxis, ...]
def prepare_signal_img_chestnut(model, img):
    """Crop the chestnut-intersection signal region and normalize it.

    The crop is anchored at the top of the frame; its x offset is 315 px
    in a 1920-px-wide reference frame, scaled to this frame's actual
    width. The crop size matches the model's input shape, pixel values
    are scaled to [0.0, 1.0], and a leading batch axis is added.
    """
    crop_h, crop_w = model.layers[0].input_shape[1:3]
    frame = np.asarray(img)
    # Scale the reference x offset (315 of 1920) to this frame's width.
    x0 = int((315 / 1920.0) * frame.shape[1])
    y0 = 0
    crop = frame[y0:y0 + crop_h, x0:x0 + crop_w]
    return np.expand_dims(crop / 255.0, axis=0)
def prepare_signal_imgs_fourth(model, img):
    """Crop both fourth-intersection signal regions and normalize them.

    Crops one window per hard-coded signal-head location, each sized to
    the model's input shape, scales pixel values to [0.0, 1.0], and adds
    a leading batch axis to every crop.

    Args:
        model: Signal detection model; only its input shape is read.
        img: Full-resolution frame (expected to cover both crop regions).

    Returns:
        A list with one batch-of-one numpy array per signal location.
    """
    # Pixel coordinates of the two signal heads in the frame.
    signal_xs = [1090, 1218]
    signal_ys = [306, 515]
    input_img_height, input_img_width = model.layers[0].input_shape[1:3]
    # Convert once: the frame does not change between crops (the
    # original rebuilt the array on every loop iteration).
    image_array = np.array(img)
    cropped_imgs = []
    for x, y in zip(signal_xs, signal_ys):
        cropped_img = image_array[y:y + input_img_height,
                                  x:x + input_img_width]
        cropped_imgs.append(np.expand_dims(cropped_img / 255.0, axis=0))
    return cropped_imgs
def main(args):
    """Run the train and signal detectors over a recorded video.

    Builds and loads both models, then reads the video frame by frame;
    every ``args.stride``-th frame is preprocessed and run through both
    detectors, printing the prediction values.

    Args:
        args: Parsed command-line arguments with attributes ``video``,
            ``intersection``, ``train_model_weights``,
            ``signal_model_weights`` and ``stride``.

    Raises:
        Exception: If ``args.intersection`` is not 'chestnut' or 'fourth'.
    """
    train_input_width = 384
    train_input_height = 216
    num_channels = 3
    # Signal model input size differs per intersection camera setup.
    if args.intersection == 'fourth':
        signal_input_height = 130
        signal_input_width = 130
    elif args.intersection == 'chestnut':
        signal_input_height = 180
        signal_input_width = 170
    else:
        raise Exception('Unrecognized intersection: {}.'.format(
            args.intersection))
    train_detection_model = TrainDetectionModel.build(
        width=train_input_width,
        height=train_input_height,
        num_channels=num_channels)
    signal_detection_model = SignalDetectionModel.build(
        width=signal_input_width,
        height=signal_input_height,
        num_channels=num_channels)
    train_detection_model.load_weights(args.train_model_weights)
    signal_detection_model.load_weights(args.signal_model_weights)
    vc = cv2.VideoCapture(args.video)
    try:
        success, img = vc.read()
        frame_idx = 0
        while success:
            print(f'Frame index: {frame_idx}.')
            if frame_idx % args.stride == 0:
                print('Running detector...')
                img_resized = cv2.resize(
                    img, (train_input_width, train_input_height),
                    interpolation=cv2.INTER_AREA)
                train_detection_input_img = prepare_img_for_train_detection(
                    img_resized)
                if args.intersection == 'fourth':
                    # Two signal heads: score each crop and keep the
                    # highest confidence.
                    signal_detection_input_imgs = prepare_signal_imgs_fourth(
                        signal_detection_model, img)
                    signal_prediction_values = []
                    for signal_img in signal_detection_input_imgs:
                        signal_prediction_value = np.array(
                            signal_detection_model.predict_on_batch(
                                signal_img)).flatten()[0]
                        signal_prediction_values.append(
                            signal_prediction_value)
                    signal_prediction_value = max(
                        signal_prediction_values).astype(float)
                elif args.intersection == 'chestnut':
                    signal_detection_input_img = prepare_signal_img_chestnut(
                        signal_detection_model, img)
                    signal_prediction_value = np.array(
                        signal_detection_model.predict_on_batch(
                            signal_detection_input_img)).flatten()[0]
                else:
                    # Unreachable after the validation above; kept as a
                    # defensive guard.
                    raise Exception('Unrecognized intersection: {}.'.format(
                        args.intersection))
                print(f'Signal prediction value: {signal_prediction_value}.')
                train_prediction_value = np.array(
                    train_detection_model.predict_on_batch(
                        train_detection_input_img)).flatten()[0]
                print(f'Train prediction value: {train_prediction_value}.')
            else:
                print('Skipping...')
            print('------------------------------', flush=True)
            success, img = vc.read()
            frame_idx += 1
    finally:
        # Release the capture handle even if a model call raises.
        vc.release()
if __name__ == '__main__':
    # Command-line entry point: collect arguments and hand them to main().
    arg_parser = argparse.ArgumentParser(
        description='Run the train detector on a recorded video.')
    arg_parser.add_argument(
        '-v', '--video', dest='video', required=True,
        help='Path to the video file.')
    arg_parser.add_argument(
        '-i', '--intersection', dest='intersection', required=True,
        help="The intersection that the camera is pointed at. One of 'chestnut' or 'fourth'.")
    arg_parser.add_argument(
        '-t', '--train-model-weights', dest='train_model_weights',
        required=True,
        help='Path to the train detection model weights.')
    arg_parser.add_argument(
        '-s', '--signal-model-weights', dest='signal_model_weights',
        required=True,
        help='Path to the signal detection model weights.')
    arg_parser.add_argument(
        '-r', '--stride', dest='stride', type=int, default=1,
        help='The number of frames to skip between samples.')
    main(arg_parser.parse_args())