Skip to content

Commit

Permalink
Merge pull request #264 from AllenInstitute/opencv_pm_update
Browse files Browse the repository at this point in the history
opencv SIFT match update
  • Loading branch information
RussTorres authored Dec 20, 2024
2 parents 7966d4b + cdc9810 commit ed8fc1e
Showing 1 changed file with 222 additions and 65 deletions.
287 changes: 222 additions & 65 deletions asap/pointmatch/generate_point_matches_opencv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import pathlib2 as pathlib
import renderapi

import imageio

from asap.pointmatch.schemas import (
PointMatchOpenCVParameters,
PointMatchClientOutputSchema)
Expand Down Expand Up @@ -46,7 +48,7 @@
def ransac_chunk(fargs):
[k1xy, k2xy, des1, des2, k1ind, args] = fargs

FLANN_INDEX_KDTREE = 0
FLANN_INDEX_KDTREE = 1
index_params = dict(
algorithm=FLANN_INDEX_KDTREE,
trees=args['FLANN_ntree'])
Expand Down Expand Up @@ -90,11 +92,27 @@ def ransac_chunk(fargs):
return k1, k2


# FIXME work w/ tile min/max in layout
def to_8bpp(im, min_val=None, max_val=None):
    """Rescale a uint16 image into uint8 using an intensity window.

    Parameters
    ----------
    im : numpy.ndarray
        input image
    min_val, max_val : int or None
        low/high ends of the intensity window; defaults of 0 and 65535
        are used when only one end is provided

    Returns
    -------
    numpy.ndarray
        uint8 image when `im` is uint16 and a window was given; otherwise
        `im` unchanged (including uint16 input with no window at all)
    """
    if im.dtype != np.uint16:
        return im
    if min_val is None and max_val is None:
        # no window provided -- pass the uint16 data through untouched
        return im
    lo = min_val or 0
    hi = max_val or 65535
    scaled = (np.clip(im, lo, hi) - lo) * (255 / (hi - lo))
    return scaled.astype(np.uint8)


def read_downsample_equalize_mask_uri(
impath, scale, CLAHE_grid=None, CLAHE_clip=None):
im = cv2.imdecode(
np.fromstring(uri_utils.uri_readbytes(impath[0]), np.uint8),
0)
impath, scale, CLAHE_grid=None, CLAHE_clip=None, min_val=None, max_val=None):
# im = cv2.imdecode(
# np.fromstring(uri_utils.uri_readbytes(impath[0]), np.uint8),
# 0)
im = imageio.v3.imread(uri_utils.uri_readbytes(impath[0]))
# FIXME this should be read from tilespec
max_val = max_val or im.max()
im = to_8bpp(im, min_val, max_val)

im = cv2.resize(im, (0, 0),
fx=scale,
Expand All @@ -110,9 +128,10 @@ def read_downsample_equalize_mask_uri(
im = cv2.equalizeHist(im)

if impath[1] is not None:
mask = cv2.imdecode(
np.fromstring(uri_utils.uri_readbytes(impath[1]), np.uint8),
0)
# mask = cv2.imdecode(
# np.fromstring(uri_utils.uri_readbytes(impath[1]), np.uint8),
# 0)
mask = imageio.v3.imread(uri_utils.uri_readbytes(impath[1]))

mask = cv2.resize(mask, (0, 0),
fx=scale,
Expand All @@ -121,6 +140,7 @@ def read_downsample_equalize_mask_uri(
im = cv2.bitwise_and(im, im, mask=mask)

return im
# return to_8bpp(im, min_val, max_val)


def read_downsample_equalize_mask(
Expand All @@ -129,24 +149,101 @@ def read_downsample_equalize_mask(
return read_downsample_equalize_mask_uri(uri_impath, *args, **kwargs)


def find_matches(fargs):
[impaths, ids, gids, args] = fargs
FLANN_INDEX_KDTREE = 1


# TODO take this from existing ransac_chunk
def match_and_ransac(
        loc_p, des_p, loc_q, des_q,
        FLANN_ntree=5, ratio_of_dist=0.7,
        FLANN_ncheck=50, RANSAC_outlier=5.0,
        min_match_count=10, FLANN_index=FLANN_INDEX_KDTREE, **kwargs):
    """FLANN-match descriptors and keep RANSAC-consistent point pairs.

    Parameters
    ----------
    loc_p, loc_q : array-like
        keypoint locations for the p and q images, indexable by match index
    des_p, des_q : numpy.ndarray
        descriptors corresponding to loc_p / loc_q
    FLANN_ntree : int
        number of kd-trees for the FLANN index
    ratio_of_dist : float
        Lowe's ratio-test threshold on nearest/second-nearest distance
    FLANN_ncheck : int
        number of leaf checks during FLANN search
    RANSAC_outlier : float
        reprojection threshold passed to cv2.findHomography
    min_match_count : int
        ratio-test survivors must exceed this count to attempt RANSAC
    FLANN_index : int
        FLANN index algorithm id (kd-tree by default)

    Returns
    -------
    k1, k2 : list
        matched locations in p and q; empty when too few matches survive
        or homography estimation fails
    """
    index_params = dict(
        # fix: honor the FLANN_index parameter -- it was previously
        # ignored in favor of the hard-coded module constant
        algorithm=FLANN_index,
        trees=FLANN_ntree)
    search_params = dict(checks=FLANN_ncheck)
    flann = cv2.FlannBasedMatcher(index_params, search_params)

    matches = flann.knnMatch(des_p, des_q, k=2)

    # store all the good matches as per Lowe's ratio test.
    good = [m for m, n in matches if m.distance < ratio_of_dist * n.distance]

    k1 = []
    k2 = []
    if len(good) > min_match_count:
        src_pts = np.float32(
            [loc_p[m.queryIdx] for m in good]).reshape(-1, 1, 2)
        dst_pts = np.float32(
            [loc_q[m.trainIdx] for m in good]).reshape(-1, 1, 2)
        M, mask = cv2.findHomography(
            src_pts,
            dst_pts,
            cv2.RANSAC,
            RANSAC_outlier)
        if mask is None:
            # fix: findHomography can fail and return mask=None;
            # previously this raised AttributeError on mask.ravel()
            return k1, k2
        matchesMask = mask.ravel().tolist()

        good = np.array(good)[np.array(matchesMask).astype('bool')]
        imgIdx = np.array([g.imgIdx for g in good])
        tIdx = np.array([g.trainIdx for g in good])
        qIdx = np.array([g.queryIdx for g in good])
        # NOTE(review): the imgIdx==1 branch swaps which index addresses
        # which image, mirroring the original ransac_chunk -- confirm it
        # is reachable with a single-image knnMatch
        for i in range(len(tIdx)):
            if imgIdx[i] == 1:
                k1.append(loc_p[tIdx[i]])
                k2.append(loc_q[qIdx[i]])
            else:
                k1.append(loc_p[qIdx[i]])
                k2.append(loc_q[tIdx[i]])

    return k1, k2


# TODO change this to pq terminology
def chunk_match_keypoints(
        loc1, des1, loc2, des2, ndiv=1, full_shape=None,
        ransac_kwargs=None, **kwargs):
    """Match keypoints against loc2/des2 in an ndiv x ndiv spatial grid.

    loc1 is partitioned into ndiv * ndiv rectangular chunks; each chunk's
    keypoints are matched (with RANSAC filtering) against all of loc2.

    Parameters
    ----------
    loc1, loc2 : numpy.ndarray
        keypoint locations, shape (N, 2)
    des1, des2 : numpy.ndarray
        descriptors corresponding to loc1 / loc2
    ndiv : int
        number of grid divisions per axis
    full_shape : tuple or None
        (nr, nc) extent used to lay out the grid; derived from the point
        spread when not given
    ransac_kwargs : dict or None
        extra keyword arguments forwarded to match_and_ransac

    Returns
    -------
    p_results, q_results : numpy.ndarray
        concatenated matched locations; empty (0, 2) arrays when no
        chunk produced matches
    """
    ransac_kwargs = ransac_kwargs or {}
    if full_shape is None:
        # NOTE(review): ptp assumes locations start near the origin --
        # an offset point cloud would shift the grid; confirm callers
        full_shape = np.ptp(np.concatenate([loc1, loc2]), axis=0)

    nr, nc = full_shape

    # NOTE(review): loc[:, 0] is compared against the row extent; if
    # locations are (x, y) keypoint coords this mixes axes exactly as the
    # original did -- confirm intended
    chunk_results = []
    for i in range(ndiv):
        # fix (flagged FIXME): compute the chunk bounds directly instead
        # of allocating np.arange arrays only to take min()/max(); the
        # formula reproduces arange's first/last element values
        r_lo = nr * i / ndiv
        r_hi = r_lo + np.ceil(nr * (i + 1) / ndiv - r_lo) - 1
        for j in range(ndiv):
            c_lo = nc * j / ndiv
            c_hi = c_lo + np.ceil(nc * (j + 1) / ndiv - c_lo) - 1
            k1ind = np.argwhere(
                (loc1[:, 0] >= r_lo) &
                (loc1[:, 0] <= r_hi) &
                (loc1[:, 1] >= c_lo) &
                (loc1[:, 1] <= c_hi)).flatten()

            chunk_results.append(match_and_ransac(
                loc1[k1ind, ...], des1[k1ind, ...],
                loc2, des2,
                **{**ransac_kwargs, **kwargs}))

    p_nonempty = [p for p, _ in chunk_results if len(p)]
    q_nonempty = [q for _, q in chunk_results if len(q)]
    if not p_nonempty or not q_nonempty:
        # fix: np.concatenate raises ValueError on an empty sequence;
        # return empty arrays when no chunk yielded matches
        return np.empty((0, 2)), np.empty((0, 2))
    return np.concatenate(p_nonempty), np.concatenate(q_nonempty)

pim = read_downsample_equalize_mask_uri(
impaths[0],
args['downsample_scale'],
CLAHE_grid=args['CLAHE_grid'],
CLAHE_clip=args['CLAHE_clip'])
qim = read_downsample_equalize_mask_uri(
impaths[1],
args['downsample_scale'],
CLAHE_grid=args['CLAHE_grid'],
CLAHE_clip=args['CLAHE_clip'])

sift = cv2.xfeatures2d.SIFT_create(
nfeatures=args['SIFT_nfeature'],
nOctaveLayers=args['SIFT_noctave'],
sigma=args['SIFT_sigma'])
def sift_match_images(
pim, qim, sift_kwargs=None,
ransac_kwargs=None, match_kwargs=None,
return_num_features=False,
**kwargs):
sift_kwargs = sift_kwargs or {}
match_kwargs = match_kwargs or {}

sift = cv2.SIFT_create(**sift_kwargs)

# find the keypoints and descriptors
kp1, des1 = sift.detectAndCompute(pim, None)
Expand All @@ -155,47 +252,107 @@ def find_matches(fargs):
k1xy = np.array([np.array(k.pt) for k in kp1])
k2xy = np.array([np.array(k.pt) for k in kp2])

nr, nc = pim.shape
k1 = []
k2 = []
ransac_args = []
results = []
ndiv = args['ndiv']
for i in range(ndiv):
r = np.arange(nr*i/ndiv, nr*(i+1)/ndiv)
for j in range(ndiv):
c = np.arange(nc*j/ndiv, nc*(j+1)/ndiv)
k1ind = np.argwhere(
(k1xy[:, 0] >= r.min()) &
(k1xy[:, 0] <= r.max()) &
(k1xy[:, 1] >= c.min()) &
(k1xy[:, 1] <= c.max())).flatten()
ransac_args.append([k1xy, k2xy, des1, des2, k1ind, args])
results.append(ransac_chunk(ransac_args[-1]))

for result in results:
k1 += result[0]
k2 += result[1]

if len(k1) >= 1:
k1 = np.array(k1) / args['downsample_scale']
k2 = np.array(k2) / args['downsample_scale']

if k1.shape[0] > args['matchMax']:
a = np.arange(k1.shape[0])
np.random.shuffle(a)
k1 = k1[a[0: args['matchMax']], :]
k2 = k2[a[0: args['matchMax']], :]

render = renderapi.connect(**args['render'])
pm_dict = make_pm(ids, gids, k1, k2)

renderapi.pointmatch.import_matches(
args['match_collection'],
[pm_dict],
render=render)

return [impaths, len(kp1), len(kp2), len(k1), len(k2)]
k1, k2 = chunk_match_keypoints(
k1xy, des1, k2xy, des2,
full_shape=pim.shape,
ransac_kwargs=ransac_kwargs,
**{**match_kwargs, **kwargs}
)
if return_num_features:
return (k1, k2), (len(kp1), len(kp2))
return k1, k2


def locs_to_dict(
        pGroupId, pId, loc_p,
        qGroupId, qId, loc_q,
        scale_factor=1.0, match_max=1000):
    """Scale, subsample, and package matched locations as a pointmatch dict.

    Parameters
    ----------
    pGroupId, qGroupId : str
        group (section) identifiers for the p and q tiles
    pId, qId : str
        tile identifiers
    loc_p, loc_q : numpy.ndarray
        matched locations, shape (N, 2), paired row-for-row
    scale_factor : float
        multiplier applied to locations (e.g. to undo downsampling)
    match_max : int
        randomly subsample to at most this many matches

    Returns
    -------
    dict or None
        pointmatch dictionary from make_pm, or None when there are no
        matches
    """
    # fix: the guard compared shape[0] < 0, which is never true since
    # shapes are non-negative; bail out on an actually-empty match set
    if loc_p.shape[0] < 1:
        return None

    # fix: scale into new arrays rather than mutating the caller's
    # arrays in place (the originals used *=)
    loc_p = loc_p * scale_factor
    loc_q = loc_q * scale_factor

    if loc_p.shape[0] > match_max:
        ind = np.arange(loc_p.shape[0])
        np.random.shuffle(ind)
        ind = ind[0:match_max]
        loc_p = loc_p[ind, ...]
        loc_q = loc_q[ind, ...]

    return make_pm(
        (pId, qId),
        (pGroupId, qGroupId),
        loc_p, loc_q)


def process_matches(
        pId, pGroupId, p_image_uri,
        qId, qGroupId, q_image_uri,
        downsample_scale=1.0,
        CLAHE_grid=None, CLAHE_clip=None,
        matchMax=1000,
        sift_kwargs=None,
        **kwargs):
    """Load two tile images, SIFT-match them, and build a pointmatch dict.

    Parameters
    ----------
    pId, qId : str
        tile identifiers
    pGroupId, qGroupId : str
        group (section) identifiers
    p_image_uri, q_image_uri
        image/mask uri pairs readable by read_downsample_equalize_mask_uri
    downsample_scale : float
        scale applied when reading images; matches are rescaled back by
        its inverse
    CLAHE_grid, CLAHE_clip
        CLAHE equalization parameters forwarded to image loading
    matchMax : int
        maximum number of matches to keep
    sift_kwargs : dict or None
        keyword arguments for SIFT feature creation

    Returns
    -------
    tuple
        (pm_dict, num_matches, num_features_p, num_features_q)
    """
    pim, qim = (
        read_downsample_equalize_mask_uri(
            image_uri,
            downsample_scale,
            CLAHE_grid=CLAHE_grid,
            CLAHE_clip=CLAHE_clip)
        for image_uri in (p_image_uri, q_image_uri))

    (loc_p, loc_q), (num_features_p, num_features_q) = sift_match_images(
        pim, qim, sift_kwargs=sift_kwargs,
        return_num_features=True,
        **kwargs)

    # matches were found at downsampled resolution; scale back up
    pm_dict = locs_to_dict(
        pGroupId, pId, loc_p,
        qGroupId, qId, loc_q,
        scale_factor=(1. / downsample_scale),
        match_max=matchMax)

    return pm_dict, len(loc_p), num_features_p, num_features_q


def find_matches(fargs):
    """Run the point-match pipeline for one tile pair and upload results.

    Parameters
    ----------
    fargs : list
        [impaths, ids, gids, args] as packed for pool workers, where
        args carries the schema-level parameters (SIFT/FLANN/RANSAC
        settings, render connection, match collection)

    Returns
    -------
    list
        [impaths, num_features_p, num_features_q, num_matches,
        num_matches]
    """
    impaths, ids, gids, args = fargs

    # regroup flat schema parameters into the kwargs process_matches expects
    sift_options = {
        "nfeatures": args["SIFT_nfeature"],
        "nOctaveLayers": args['SIFT_noctave'],
        "sigma": args['SIFT_sigma'],
    }
    match_options = {
        "ndiv": args["ndiv"],
        "FLANN_ntree": args["FLANN_ntree"],
        "ratio_of_dist": args["ratio_of_dist"],
        "FLANN_ncheck": args["FLANN_ncheck"],
    }
    ransac_options = {"RANSAC_outlier": args["RANSAC_outlier"]}

    pm_dict, num_matches, num_features_p, num_features_q = process_matches(
        ids[0], gids[0], impaths[0],
        ids[1], gids[1], impaths[1],
        downsample_scale=args["downsample_scale"],
        CLAHE_grid=args["CLAHE_grid"],
        CLAHE_clip=args["CLAHE_clip"],
        sift_kwargs=sift_options,
        match_kwargs=match_options,
        ransac_kwargs=ransac_options,
        matchMax=args["matchMax"])

    # upload the matches to the configured render pointmatch collection
    render = renderapi.connect(**args['render'])
    renderapi.pointmatch.import_matches(
        args['match_collection'],
        [pm_dict],
        render=render)

    return [impaths, num_features_p, num_features_q, num_matches, num_matches]


def make_pm(ids, gids, k1, k2):
Expand Down

0 comments on commit ed8fc1e

Please sign in to comment.