From 41a2f0f78b2f77a23e8de7daeefd26eeb8630ddc Mon Sep 17 00:00:00 2001 From: liutingxi Date: Thu, 12 Oct 2023 16:37:45 +0800 Subject: [PATCH] 1. Support RTSP/RTMP inputs; 2. Optimized post-processing in sample/YOLOX/python/yolox_bmcv.py --- sample/YOLOX/python/postprocess_numpy.py | 58 ++++++++++++------------ sample/YOLOX/python/yolox_bmcv.py | 7 ++- 2 files changed, 36 insertions(+), 29 deletions(-) diff --git a/sample/YOLOX/python/postprocess_numpy.py b/sample/YOLOX/python/postprocess_numpy.py index cb388144..f44465c7 100644 --- a/sample/YOLOX/python/postprocess_numpy.py +++ b/sample/YOLOX/python/postprocess_numpy.py @@ -10,7 +10,9 @@ import cv2 class PostProcess: - def __init__(self, conf_thresh=0.001, nms_thresh=0.7, agnostic=False, multi_label=True, max_det=300): + def __init__(self, input_h, input_w, conf_thresh=0.001, nms_thresh=0.7, agnostic=False, multi_label=True, max_det=300, p6=False): + self.input_h = input_h + self.input_w = input_w self.conf_thresh = conf_thresh self.nms_thresh = nms_thresh self.agnostic_nms = agnostic @@ -18,6 +20,27 @@ def __init__(self, conf_thresh=0.001, nms_thresh=0.7, agnostic=False, multi_labe self.max_det = max_det self.nms = pseudo_torch_nms() + self.grids = [] + self.expanded_strides = [] + + if not p6: + strides = [8,16,32] + else: + strides = [8,16,32,64] + + hsizes = [input_h // stride for stride in strides] + wsizes = [input_w // stride for stride in strides] + + for hsize, wsize, stride in zip(hsizes, wsizes, strides): + xv, yv = np.meshgrid(np.arange(wsize),np.arange(hsize)) + grid = np.stack((xv,yv),2).reshape(1,-1,2) + self.grids.append(grid) + shape = grid.shape[:2] + self.expanded_strides.append(np.full((*shape,1),stride)) + + self.grids = np.concatenate(self.grids,1) + self.expanded_strides = np.concatenate(self.expanded_strides,1) + def __call__(self, preds_batch, input_size, org_size_batch, ratios_batch, txy_batch): """ @@ -36,7 +59,7 @@ def __call__(self, preds_batch, input_size, org_size_batch, ratios_batch, txy_b print('preds_batch type: '.format(type(preds_batch))) raise NotImplementedError - dets = self.decode(preds_batch[0], *input_size) + dets = self.decode(preds_batch[0]) @@ -66,35 +89,14 @@ def __call__(self, preds_batch, input_size, org_size_batch, ratios_batch, txy_b return outs - def decode(self, outputs, input_w, input_h, p6=False): - grids = [] - expanded_strides = [] - - if not p6: - strides = [8,16,32] - else: - strides = [8,16,32,64] - - hsizes = [input_h // stride for stride in strides] - wsizes = [input_w // stride for stride in strides] - - for hsize, wsize, stride in zip(hsizes, wsizes, strides): - xv, yv = np.meshgrid(np.arange(wsize),np.arange(hsize)) - grid = np.stack((xv,yv),2).reshape(1,-1,2) - grids.append(grid) - shape = grid.shape[:2] - expanded_strides.append(np.full((*shape,1),stride)) - - grids = np.concatenate(grids,1) - expanded_strides = np.concatenate(expanded_strides,1) - - + def decode(self, outputs): for i in range(len(outputs)): - + valid_indices = np.where(outputs[..., 4] > self.conf_thresh)[1] + expanded_strides = self.expanded_strides[:, valid_indices, :] + grids = self.grids[:, valid_indices, :] + outputs = outputs[:, valid_indices, :] outputs[i][..., :2] = (outputs[i][..., :2] + grids) * expanded_strides - np.savetxt("python_center.txt",outputs[i][...,:2],fmt="%.4f") outputs[i][..., 2:4] = np.exp(outputs[i][..., 2:4]) * expanded_strides - outputs[i][..., 5:] *= outputs[i][..., 4:5] return outputs diff --git a/sample/YOLOX/python/yolox_bmcv.py b/sample/YOLOX/python/yolox_bmcv.py index 54fcd4eb..0e8a070e 100644 --- a/sample/YOLOX/python/yolox_bmcv.py +++ b/sample/YOLOX/python/yolox_bmcv.py @@ -79,6 +79,8 @@ def __init__(self, args): agnostic=self.agnostic, multi_label=self.multi_label, max_det=self.max_det, + input_h=self.net_h, + input_w=self.net_w ) # init time @@ -231,7 +233,7 @@ def draw_bmcv(bmcv, bmimg, boxes, classes_ids=None, conf_scores=None, save_path= def main(args): # check params - if not os.path.exists(args.input): + if not args.input.startswith("rtsp") and not args.input.startswith("rtmp") and not os.path.exists(args.input): raise FileNotFoundError('{} is not existed.'.format(args.input)) if not os.path.exists(args.bmodel): raise FileNotFoundError('{} is not existed.'.format(args.bmodel)) @@ -347,6 +349,8 @@ def main(args): cn = 0 frame_list = [] while True: + if args.max_frames and cn > args.max_frames: + break frame = sail.BMImage() start_time = time.time() ret = decoder.read(handle, frame) @@ -395,6 +399,7 @@ def argsparser(): parser.add_argument('--dev_id', type=int, default=0, help='dev id') parser.add_argument('--conf_thresh', type=float, default=0.40, help='confidence threshold') parser.add_argument('--nms_thresh', type=float, default=0.5, help='nms threshold') + parser.add_argument('--max_frames', type=int, default=None, help='max number of frames to process') args = parser.parse_args() return args