-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathinspect_model.py
214 lines (128 loc) · 5.27 KB
/
inspect_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import matplotlib.pyplot as plt
import os
import tensorflow as tf
import numpy as np
import visualize
print(tf.__version__)

# TensorFlow config: expose only the first GPU to this process and let
# TensorFlow grow its GPU memory allocation on demand instead of reserving
# the whole device up front.
# NOTE(review): TF_CPP_MIN_LOG_LEVEL is most reliably honoured when it is set
# before `import tensorflow`; setting it here may not silence messages that
# were already emitted while TensorFlow was being imported.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Enumerate devices only after CUDA_VISIBLE_DEVICES is set so the restriction
# above applies. Memory growth must be configured before any GPU has been
# initialized, otherwise TensorFlow raises RuntimeError.
for _gpu in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(_gpu, True)
# --- Load the COCO-format training split ------------------------------------
from detection.datasets import coco

# Per-channel pixel mean handed to the dataset for normalisation. These are the
# usual ImageNet means (presumably in RGB order — TODO confirm in CocoDataSet).
img_mean = (123.675, 116.28, 103.53)
# Unit std, so images are only mean-shifted, not rescaled. The ImageNet std
# (58.395, 57.12, 57.375) is the alternative if the pretrained weights need it.
img_std = (1., 1., 1.)

train_dataset = coco.CocoDataSet(
    'dataset',
    'train',
    flip_ratio=0,       # presumably disables random horizontal flips
    pad_mode='fixed',
    mean=img_mean,
    std=img_std,
    scale=(800, 1024),
)
# --- Inspect one training sample --------------------------------------------
img, img_meta, bboxes, labels = train_dataset[0]

# Undo the dataset normalisation for display. This assumes the dataset stores
# (pixel - img_mean) / img_std, so the inverse is img * img_std + img_mean.
# With the unit std used above this equals the plain `img + img_mean`, and it
# stays correct if img_std is changed.
rgb_img = np.round(img * np.asarray(img_std) + np.asarray(img_mean))

# img_meta layout, as printed below: [0:3] original shape, [3:6] scaled shape,
# [6:9] padded shape, [9] scale factor, [10] flip flag.
print('origin shape:', img_meta[:3])
print('scale shape:', img_meta[3:6])
print('pad shape:', img_meta[6:9])
print('scale factor:', img_meta[9:10])
print('flip :', img_meta[10:11])
print('newimg shape:', img.shape)
print('bboxes:', bboxes.shape)
print('labels:', labels)

# Fetch the class-name lookup once instead of on every use.
categories = train_dataset.get_categories()
print('label2name:', categories)
print('len of classes:', len(categories))
names = [categories[i] for i in labels]
print('names:', names)

# Integer images are interpreted by imshow as 0-255. Clip so that rounding
# noise cannot push values outside that range before the cast.
plt.imshow(np.clip(rgb_img, 0, 255).astype(np.int32))
plt.show()
visualize.display_instances(rgb_img, bboxes, labels, categories)
# --- Build Faster R-CNN and load pretrained weights --------------------------
from detection.models.detectors import faster_rcnn

model = faster_rcnn.FasterRCNN(num_classes=len(train_dataset.get_categories()))

# Add a leading batch axis of size 1 to every array. These are fixed inputs,
# not trainable state, so plain tensors are used rather than tf.Variable.
# image: [1, H, W, 3], presumably [1, 1024, 1024, 3] after fixed padding — TODO confirm
batch_imgs = tf.convert_to_tensor(np.expand_dims(img, 0))
# image meta: [1, 11]
batch_metas = tf.convert_to_tensor(np.expand_dims(img_meta, 0))
# ground-truth boxes: [1, num of boxes, 4]
batch_bboxes = tf.convert_to_tensor(np.expand_dims(bboxes, 0))
# ground-truth class ids: [1, num of boxes]
batch_labels = tf.convert_to_tensor(np.expand_dims(labels, 0))

# A subclassed Keras model has no variables until it is called once. This
# forward pass builds them so that load_weights has something to restore into.
_ = model((batch_imgs, batch_metas, batch_bboxes, batch_labels), training=True)
model.load_weights('weights/faster_rcnn.h5')
# ### Stage 1: Region Proposal Network
#
# #### 1.a RPN Targets
# Build the anchor pyramid for this image and let the anchor-target assigner
# label every anchor against the ground-truth boxes.
anchors, valid_flags = model.rpn_head.generator.generate_pyramid_anchors(batch_metas)
rpn_target_matchs, rpn_target_deltas = model.rpn_head.anchor_target.build_targets(
    anchors, valid_flags, batch_bboxes, batch_labels)

# Split the anchors by match code: 1 = positive, -1 = negative, 0 = neutral.
# tf.where returns [batch, anchor] coordinate pairs; column 1 is the anchor
# index (the batch holds a single image here).
positive_ids = tf.where(tf.equal(rpn_target_matchs, 1))[:, 1]
negative_ids = tf.where(tf.equal(rpn_target_matchs, -1))[:, 1]
neutral_ids = tf.where(tf.equal(rpn_target_matchs, 0))[:, 1]

positive_anchors = tf.gather(anchors, positive_ids)
negative_anchors = tf.gather(anchors, negative_ids)
neutral_anchors = tf.gather(anchors, neutral_ids)

# NOTE(review): assumes build_targets packs the positive-anchor deltas at the
# front of the delta tensor, in the same order as positive_ids — TODO confirm.
positive_target_deltas = rpn_target_deltas[0, :positive_ids.shape[0]]

from detection.core.bbox import transforms

# Apply the target deltas to the positive anchors. If the means/stds match
# the ones build_targets encoded with, the result should land on the matched
# ground-truth boxes.
refined_anchors = transforms.delta2bbox(
    positive_anchors,
    positive_target_deltas,
    (0., 0., 0., 0.),      # presumably target means
    (0.1, 0.1, 0.2, 0.2),  # presumably target stds
)

for caption, tensor in (
    ('rpn_target_matchs:\t', rpn_target_matchs[0]),
    ('rpn_target_deltas:\t', rpn_target_deltas[0]),
    ('positive_anchors:\t', positive_anchors),
    ('negative_anchors:\t', negative_anchors),
    ('neutral_anchors:\t', neutral_anchors),
    ('refined_anchors:\t', refined_anchors),
):
    print(caption, tensor.shape.as_list())

visualize.draw_boxes(
    rgb_img,
    boxes=positive_anchors.numpy(),
    refined_boxes=refined_anchors.numpy(),
)
plt.show()
# #### 1.b RPN Predictions
# Run backbone -> neck -> RPN head one stage at a time with training=False.
training = False
C2, C3, C4, C5 = model.backbone(batch_imgs, training=training)
P2, P3, P4, P5, P6 = model.neck([C2, C3, C4, C5], training=training)

# The RPN consumes all five pyramid levels. The ROI-align step in Stage 2
# uses only P2-P5.
rpn_feature_maps = [P2, P3, P4, P5, P6]
rcnn_feature_maps = [P2, P3, P4, P5]
rpn_class_logits, rpn_probs, rpn_deltas = model.rpn_head(
    rpn_feature_maps, training=training)

# Column 1 of image 0's RPN probabilities.
# NOTE(review): presumably the per-anchor foreground probability — confirm.
rpn_probs_tmp = rpn_probs[0, :, 1]

# Show the top anchors by score (before refinement). top_k returns them
# best-first, and the [::-1] reverses that to lowest-of-the-top-k first.
limit = 100
ix = tf.nn.top_k(rpn_probs_tmp, k=limit).indices[::-1]

visualize.draw_boxes(rgb_img, boxes=tf.gather(anchors, ix).numpy())
# Display the figure now; without this a plain script run does not show it at
# this point. This is harmless if draw_boxes already displayed it.
plt.show()
# ### Stage 2: Proposal Classification
# Convert the RPN scores/deltas into per-image region proposals.
proposals_list = model.rpn_head.get_proposals(
    rpn_probs, rpn_deltas, batch_metas)

# Pool features for every proposal from the P2-P5 maps with ROI-align, then
# run the bbox head on the pooled regions.
rois_list = proposals_list
pooled_regions_list = model.roi_align(
    (rois_list, rcnn_feature_maps, batch_metas), training=training)
rcnn_class_logits_list, rcnn_probs_list, rcnn_deltas_list = model.bbox_head(
    pooled_regions_list, training=training)

# Decode the head outputs into detections for each image.
detections_list = model.bbox_head.get_bboxes(
    rcnn_probs_list, rcnn_deltas_list, rois_list, batch_metas)

# Keep the first four columns of image 0's detections as the box coordinates.
# The remaining columns presumably hold class id / score — TODO confirm.
tmp = detections_list[0][:, :4]

visualize.draw_boxes(rgb_img, boxes=tmp.numpy())
# Display the figure now; harmless if draw_boxes already displayed it.
plt.show()
# ### Stage 3: Run model directly
# One inference call without ground-truth inputs. Its boxes are worth
# comparing with Stage 2's step-by-step result.
detections_list = model((batch_imgs, batch_metas), training=False)

# First four columns of image 0's detections are drawn as boxes.
tmp = detections_list[0][:, :4]

visualize.draw_boxes(rgb_img, boxes=tmp.numpy())
# Display the figure now; harmless if draw_boxes already displayed it.
plt.show()
# ### Stage 4: Test (Detection)
from detection.datasets.utils import get_original_image

# Presumably reconstructs the un-normalised, un-padded image from the network
# input and its meta. Only img_mean is passed, which suggests a unit-std
# assumption — TODO confirm.
ori_img = get_original_image(img, img_meta, img_mean)

# Two-step test-time API: proposals from the RPN, then box-level results
# computed from those proposals.
proposals = model.simple_test_rpn(img, img_meta)
res = model.simple_test_bboxes(img, img_meta, proposals)

# `res` is read below via its 'rois', 'class_ids' and 'scores' keys.
class_names = train_dataset.get_categories()
visualize.display_instances(
    ori_img,
    res['rois'],
    res['class_ids'],
    class_names,
    scores=res['scores'],
)