Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

YOLOX demo optimization #9

Open
wants to merge 1 commit into
base: release
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 30 additions & 28 deletions sample/YOLOX/python/postprocess_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,37 @@
import cv2

class PostProcess:
def __init__(self, conf_thresh=0.001, nms_thresh=0.7, agnostic=False, multi_label=True, max_det=300):
def __init__(self, input_h, input_w, conf_thresh=0.001, nms_thresh=0.7, agnostic=False, multi_label=True, max_det=300, p6=False):
self.input_h = input_h
self.input_w = input_w
self.conf_thresh = conf_thresh
self.nms_thresh = nms_thresh
self.agnostic_nms = agnostic
self.multi_label = multi_label
self.max_det = max_det
self.nms = pseudo_torch_nms()

self.grids = []
self.expanded_strides = []

if not p6:
strides = [8,16,32]
else:
strides = [8,16,32,64]

hsizes = [input_h // stride for stride in strides]
wsizes = [input_w // stride for stride in strides]

for hsize, wsize, stride in zip(hsizes, wsizes, strides):
xv, yv = np.meshgrid(np.arange(wsize),np.arange(hsize))
grid = np.stack((xv,yv),2).reshape(1,-1,2)
self.grids.append(grid)
shape = grid.shape[:2]
self.expanded_strides.append(np.full((*shape,1),stride))

self.grids = np.concatenate(self.grids,1)
self.expanded_strides = np.concatenate(self.expanded_strides,1)


def __call__(self, preds_batch, input_size, org_size_batch, ratios_batch, txy_batch):
"""
Expand All @@ -36,7 +59,7 @@ def __call__(self, preds_batch, input_size, org_size_batch, ratios_batch, txy_b
print('preds_batch type: '.format(type(preds_batch)))
raise NotImplementedError

dets = self.decode(preds_batch[0], *input_size)
dets = self.decode(preds_batch[0])



Expand Down Expand Up @@ -66,35 +89,14 @@ def __call__(self, preds_batch, input_size, org_size_batch, ratios_batch, txy_b
return outs


def decode(self, outputs, input_w, input_h, p6=False):
grids = []
expanded_strides = []

if not p6:
strides = [8,16,32]
else:
strides = [8,16,32,64]

hsizes = [input_h // stride for stride in strides]
wsizes = [input_w // stride for stride in strides]

for hsize, wsize, stride in zip(hsizes, wsizes, strides):
xv, yv = np.meshgrid(np.arange(wsize),np.arange(hsize))
grid = np.stack((xv,yv),2).reshape(1,-1,2)
grids.append(grid)
shape = grid.shape[:2]
expanded_strides.append(np.full((*shape,1),stride))

grids = np.concatenate(grids,1)
expanded_strides = np.concatenate(expanded_strides,1)


def decode(self, outputs):
for i in range(len(outputs)):

valid_indices = np.where(outputs[..., 4] > self.conf_thresh)[1]
expanded_strides = self.expanded_strides[:, valid_indices, :]
grids = self.grids[:, valid_indices, :]
outputs = outputs[:, valid_indices, :]
outputs[i][..., :2] = (outputs[i][..., :2] + grids) * expanded_strides
np.savetxt("python_center.txt",outputs[i][...,:2],fmt="%.4f")
outputs[i][..., 2:4] = np.exp(outputs[i][..., 2:4]) * expanded_strides

outputs[i][..., 5:] *= outputs[i][..., 4:5]

return outputs
Expand Down
7 changes: 6 additions & 1 deletion sample/YOLOX/python/yolox_bmcv.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def __init__(self, args):
agnostic=self.agnostic,
multi_label=self.multi_label,
max_det=self.max_det,
input_h=self.net_h,
input_w=self.net_w
)

# init time
Expand Down Expand Up @@ -231,7 +233,7 @@ def draw_bmcv(bmcv, bmimg, boxes, classes_ids=None, conf_scores=None, save_path=

def main(args):
# check params
if not os.path.exists(args.input):
if not args.input.startswith("rtsp") and not args.input.startswith("rtmp") and not os.path.exists(args.input):
raise FileNotFoundError('{} is not existed.'.format(args.input))
if not os.path.exists(args.bmodel):
raise FileNotFoundError('{} is not existed.'.format(args.bmodel))
Expand Down Expand Up @@ -347,6 +349,8 @@ def main(args):
cn = 0
frame_list = []
while True:
if args.max_frames and cn > args.max_frames:
break
frame = sail.BMImage()
start_time = time.time()
ret = decoder.read(handle, frame)
Expand Down Expand Up @@ -395,6 +399,7 @@ def argsparser():
parser.add_argument('--dev_id', type=int, default=0, help='dev id')
parser.add_argument('--conf_thresh', type=float, default=0.40, help='confidence threshold')
parser.add_argument('--nms_thresh', type=float, default=0.5, help='nms threshold')
parser.add_argument('--max_frames', type=int, default=None, help='max number of frames to process')
args = parser.parse_args()
return args

Expand Down