diff --git a/asap/pointmatch/generate_point_matches_opencv.py b/asap/pointmatch/generate_point_matches_opencv.py index c1821696..d8b5ec9a 100644 --- a/asap/pointmatch/generate_point_matches_opencv.py +++ b/asap/pointmatch/generate_point_matches_opencv.py @@ -10,6 +10,8 @@ import pathlib2 as pathlib import renderapi +import imageio + from asap.pointmatch.schemas import ( PointMatchOpenCVParameters, PointMatchClientOutputSchema) @@ -46,7 +48,7 @@ def ransac_chunk(fargs): [k1xy, k2xy, des1, des2, k1ind, args] = fargs - FLANN_INDEX_KDTREE = 0 + FLANN_INDEX_KDTREE = 1 index_params = dict( algorithm=FLANN_INDEX_KDTREE, trees=args['FLANN_ntree']) @@ -90,11 +92,27 @@ def ransac_chunk(fargs): return k1, k2 +# FIXME work w/ tile min/max in layout +def to_8bpp(im, min_val=None, max_val=None): + if im.dtype == np.uint16: + if max_val is not None or min_val is not None: + min_val = min_val or 0 + max_val = max_val or 65535 + scale_factor = 255 / (max_val - min_val) + im = ((np.clip(im, min_val, max_val) - min_val) * scale_factor) + return (im).astype(np.uint8) + return im + + def read_downsample_equalize_mask_uri( - impath, scale, CLAHE_grid=None, CLAHE_clip=None): - im = cv2.imdecode( - np.fromstring(uri_utils.uri_readbytes(impath[0]), np.uint8), - 0) + impath, scale, CLAHE_grid=None, CLAHE_clip=None, min_val=None, max_val=None): + # im = cv2.imdecode( + # np.fromstring(uri_utils.uri_readbytes(impath[0]), np.uint8), + # 0) + im = imageio.v3.imread(uri_utils.uri_readbytes(impath[0])) + # FIXME this should be read from tilespec + max_val = max_val or im.max() + im = to_8bpp(im, min_val, max_val) im = cv2.resize(im, (0, 0), fx=scale, @@ -110,9 +128,10 @@ def read_downsample_equalize_mask_uri( im = cv2.equalizeHist(im) if impath[1] is not None: - mask = cv2.imdecode( - np.fromstring(uri_utils.uri_readbytes(impath[1]), np.uint8), - 0) + # mask = cv2.imdecode( + # np.fromstring(uri_utils.uri_readbytes(impath[1]), np.uint8), + # 0) + mask = imageio.v3.imread(uri_utils.uri_readbytes(impath[1])) mask = cv2.resize(mask, (0, 0), fx=scale, @@ -121,6 +140,7 @@ def read_downsample_equalize_mask_uri( im = cv2.bitwise_and(im, im, mask=mask) return im + # return to_8bpp(im, min_val, max_val) def read_downsample_equalize_mask( @@ -129,24 +149,101 @@ def read_downsample_equalize_mask( return read_downsample_equalize_mask_uri(uri_impath, *args, **kwargs) -def find_matches(fargs): - [impaths, ids, gids, args] = fargs +FLANN_INDEX_KDTREE = 1 + + +# TODO take this from existing ransac_chunk +def match_and_ransac( + loc_p, des_p, loc_q, des_q, + FLANN_ntree=5, ratio_of_dist=0.7, + FLANN_ncheck=50, RANSAC_outlier=5.0, + min_match_count=10, FLANN_index=FLANN_INDEX_KDTREE, **kwargs): + + index_params = dict( + algorithm=FLANN_INDEX_KDTREE, + trees=FLANN_ntree) + search_params = dict(checks=FLANN_ncheck) + flann = cv2.FlannBasedMatcher(index_params, search_params) + + matches = flann.knnMatch(des_p, des_q, k=2) + + # store all the good matches as per Lowe's ratio test. + good = [] + k1 = [] + k2 = [] + for m, n in matches: + if m.distance < ratio_of_dist * n.distance: + good.append(m) + if len(good) > min_match_count: + src_pts = np.float32( + [loc_p[m.queryIdx] for m in good]).reshape(-1, 1, 2) + dst_pts = np.float32( + [loc_q[m.trainIdx] for m in good]).reshape(-1, 1, 2) + M, mask = cv2.findHomography( + src_pts, + dst_pts, + cv2.RANSAC, + RANSAC_outlier) + matchesMask = mask.ravel().tolist() + + good = np.array(good)[np.array(matchesMask).astype('bool')] + imgIdx = np.array([g.imgIdx for g in good]) + tIdx = np.array([g.trainIdx for g in good]) + qIdx = np.array([g.queryIdx for g in good]) + for i in range(len(tIdx)): + if imgIdx[i] == 1: + k1.append(loc_p[tIdx[i]]) + k2.append(loc_q[qIdx[i]]) + else: + k1.append(loc_p[qIdx[i]]) + k2.append(loc_q[tIdx[i]]) + + return k1, k2 + + +# TODO change this to pq terminology +def chunk_match_keypoints( + loc1, des1, loc2, des2, ndiv=1, full_shape=None, + ransac_kwargs=None, **kwargs): + ransac_kwargs = ransac_kwargs or {} + if full_shape is None: + full_shape = np.ptp(np.concatenate([loc1, loc2]), axis=0) + + nr, nc = full_shape + + chunk_results = [] + + # FIXME better way than doing min and max of arrays + for i in range(ndiv): + r = np.arange(nr * i / ndiv, nr * (i + 1) / ndiv) + for j in range(ndiv): + c = np.arange(nc * j / ndiv, nc * (j + 1) / ndiv) + k1ind = np.argwhere( + (loc1[:, 0] >= r.min()) & + (loc1[:, 0] <= r.max()) & + (loc1[:, 1] >= c.min()) & + (loc1[:, 1] <= c.max())).flatten() + + chunk_results.append(match_and_ransac( + loc1[k1ind, ...], des1[k1ind, ...], + loc2, des2, + **{**ransac_kwargs, **kwargs})) + + p_results, q_results = zip(*chunk_results) + p_results = np.concatenate([i for i in p_results if len(i)]) + q_results = np.concatenate([i for i in q_results if len(i)]) + return p_results, q_results - pim = read_downsample_equalize_mask_uri( - impaths[0], - args['downsample_scale'], - CLAHE_grid=args['CLAHE_grid'], - CLAHE_clip=args['CLAHE_clip']) - qim = read_downsample_equalize_mask_uri( - impaths[1], - args['downsample_scale'], - CLAHE_grid=args['CLAHE_grid'], - CLAHE_clip=args['CLAHE_clip']) - sift = cv2.xfeatures2d.SIFT_create( - nfeatures=args['SIFT_nfeature'], - nOctaveLayers=args['SIFT_noctave'], - sigma=args['SIFT_sigma']) +def sift_match_images( + pim, qim, sift_kwargs=None, + ransac_kwargs=None, match_kwargs=None, + return_num_features=False, + **kwargs): + sift_kwargs = sift_kwargs or {} + match_kwargs = match_kwargs or {} + + sift = cv2.SIFT_create(**sift_kwargs) # find the keypoints and descriptors kp1, des1 = sift.detectAndCompute(pim, None) @@ -155,47 +252,107 @@ def find_matches(fargs): k1xy = np.array([np.array(k.pt) for k in kp1]) k2xy = np.array([np.array(k.pt) for k in kp2]) - nr, nc = pim.shape - k1 = [] - k2 = [] - ransac_args = [] - results = [] - ndiv = args['ndiv'] - for i in range(ndiv): - r = np.arange(nr*i/ndiv, nr*(i+1)/ndiv) - for j in range(ndiv): - c = np.arange(nc*j/ndiv, nc*(j+1)/ndiv) - k1ind = np.argwhere( - (k1xy[:, 0] >= r.min()) & - (k1xy[:, 0] <= r.max()) & - (k1xy[:, 1] >= c.min()) & - (k1xy[:, 1] <= c.max())).flatten() - ransac_args.append([k1xy, k2xy, des1, des2, k1ind, args]) - results.append(ransac_chunk(ransac_args[-1])) - - for result in results: - k1 += result[0] - k2 += result[1] - - if len(k1) >= 1: - k1 = np.array(k1) / args['downsample_scale'] - k2 = np.array(k2) / args['downsample_scale'] - - if k1.shape[0] > args['matchMax']: - a = np.arange(k1.shape[0]) - np.random.shuffle(a) - k1 = k1[a[0: args['matchMax']], :] - k2 = k2[a[0: args['matchMax']], :] - - render = renderapi.connect(**args['render']) - pm_dict = make_pm(ids, gids, k1, k2) - - renderapi.pointmatch.import_matches( - args['match_collection'], - [pm_dict], - render=render) - - return [impaths, len(kp1), len(kp2), len(k1), len(k2)] + k1, k2 = chunk_match_keypoints( + k1xy, des1, k2xy, des2, + full_shape=pim.shape, + ransac_kwargs=ransac_kwargs, + **{**match_kwargs, **kwargs} + ) + if return_num_features: + return (k1, k2), (len(kp1), len(kp2)) + return k1, k2 + + +def locs_to_dict( + pGroupId, pId, loc_p, + qGroupId, qId, loc_q, + scale_factor=1.0, match_max=1000): + if loc_p.shape[0] < 0: + return + loc_p *= scale_factor + loc_q *= scale_factor + + if loc_p.shape[0] > match_max: + ind = np.arange(loc_p.shape[0]) + np.random.shuffle(ind) + ind = ind[0:match_max] + loc_p = loc_p[ind, ...] + loc_q = loc_q[ind, ...] + + return make_pm( + (pId, qId), + (pGroupId, qGroupId), + loc_p, loc_q) + + +def process_matches( + pId, pGroupId, p_image_uri, + qId, qGroupId, q_image_uri, + downsample_scale=1.0, + CLAHE_grid=None, CLAHE_clip=None, + matchMax=1000, + sift_kwargs=None, + **kwargs): + + pim = read_downsample_equalize_mask_uri( + p_image_uri, + downsample_scale, + CLAHE_grid=CLAHE_grid, + CLAHE_clip=CLAHE_clip) + qim = read_downsample_equalize_mask_uri( + q_image_uri, + downsample_scale, + CLAHE_grid=CLAHE_grid, + CLAHE_clip=CLAHE_clip) + + (loc_p, loc_q), (num_features_p, num_features_q) = sift_match_images( + pim, qim, sift_kwargs=sift_kwargs, + return_num_features=True, + **kwargs) + + pm_dict = locs_to_dict( + pGroupId, pId, loc_p, + qGroupId, qId, loc_q, + scale_factor=(1. / downsample_scale), + match_max=matchMax) + + return pm_dict, len(loc_p), num_features_p, num_features_q + + +def find_matches(fargs): + [impaths, ids, gids, args] = fargs + + pm_dict, num_matches, num_features_p, num_features_q = process_matches( + ids[0], gids[0], impaths[0], + ids[1], gids[1], impaths[1], + downsample_scale=args["downsample_scale"], + CLAHE_grid=args["CLAHE_grid"], + CLAHE_clip=args["CLAHE_clip"], + sift_kwargs={ + "nfeatures": args["SIFT_nfeature"], + "nOctaveLayers": args['SIFT_noctave'], + "sigma": args['SIFT_sigma'] + }, + match_kwargs={ + "ndiv": args["ndiv"], + "FLANN_ntree": args["FLANN_ntree"], + "ratio_of_dist": args["ratio_of_dist"], + "FLANN_ncheck": args["FLANN_ncheck"] + }, + ransac_kwargs={ + "RANSAC_outlier": args["RANSAC_outlier"] + }, + matchMax=args["matchMax"] + ) + + render = renderapi.connect(**args['render']) + + renderapi.pointmatch.import_matches( + args['match_collection'], + [pm_dict], + render=render) + + return [impaths, num_features_p, num_features_q, num_matches, num_matches] def make_pm(ids, gids, k1, k2):