-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsampler.py
92 lines (72 loc) · 3.15 KB
/
sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Functions to enable patch sampling
import numpy as np
def samplePatchWeighted(image, label, probs, patch_size):
    """
    Sample a square patch whose centre pixel is drawn from a randomly chosen class.

    A class index ``c`` is drawn according to ``probs``. The patch centre is then
    drawn uniformly from the pixels selected by ``label[c, :, :] == c`` that are far
    enough from the border for the whole patch to fit inside the image. If no such
    pixel exists, the patch is centred on the image instead.

    :param image: the raw image, a 3-D array. The patch is cut from axes 1 and 2; axis 0 is kept whole.
    :param label: the label map for the image, a 3-D array whose first axis is indexed by class
        and whose axes 1 and 2 are cropped with the same window as the image.
    :param probs: the probability with which we will sample each class, as a list.
        Length must match the number of distinct values in ``label`` and the values must sum to 1.
    :param patch_size: the side length of the square patch which will be sampled.
    :return: tuple ``(image_patch, label_patch)`` cropped with the same window;
        both have ``patch_size`` rows and columns.
    :raises ValueError: if ``probs`` has the wrong length, does not sum to 1, or if
        ``patch_size`` does not fit inside the image.
    """
    # Check that the probabilities list matches the number of classes.
    if np.unique(label).shape[0] != len(probs):
        raise ValueError("The length of the probabilites list does not match the number of label classes")
    # Compare with a tolerance: exact equality rejects valid lists such as ten
    # entries of 0.1 because of floating point rounding.
    if not np.isclose(np.sum(np.array(probs, dtype=float)), 1.0):
        raise ValueError("The probabilities provided do not sum to 1.")

    patch_size = int(patch_size)
    height, width = image.shape[1], image.shape[2]
    if patch_size < 1 or patch_size > min(height, width):
        raise ValueError("patch_size must be between 1 and the smallest spatial image dimension")

    # Choose which class will be at the centre of the patch using the probabilities list.
    c = np.argmax(np.random.multinomial(1, probs))

    # Coordinates of every pixel selected for the chosen class. np.nonzero returns
    # one index array per axis, so the pairs (rows[i], cols[i]) are the pixels.
    # NOTE(review): plane ``c`` is compared against the value ``c`` exactly as the
    # original code did; confirm this matches the intended label encoding (for a
    # one-hot label, plane 0 would select its zero entries).
    rows, cols = np.nonzero(label[c, :, :] == c)

    # The patch spans [centre - delta, centre - delta + patch_size) on each axis.
    delta = patch_size // 2
    reach = patch_size - delta
    fits = (
        (rows >= delta) & (rows <= height - reach)
        & (cols >= delta) & (cols <= width - reach)
    )
    candidates = np.nonzero(fits)[0]

    if candidates.shape[0] < 1:
        # No pixel of this class can be a patch centre: fall back to the image centre
        # (this is (256, 256) for a 512x512 image).
        x, y = height // 2, width // 2
    else:
        # Pick row and column together so the centre really is one of the class pixels.
        pick = np.random.choice(candidates)
        x, y = int(rows[pick]), int(cols[pick])

    # Crop the patch from the image and label.
    x0, y0 = x - delta, y - delta
    image_cropped = image[:, x0:x0 + patch_size, y0:y0 + patch_size]
    label_cropped = label[:, x0:x0 + patch_size, y0:y0 + patch_size]
    return image_cropped, label_cropped
def samplePatchRandom(image, label, patch_size):
    """
    Randomly sample a square patch from an image and return the cropped image and label.

    The patch origin (top-left corner) is drawn uniformly from every position at
    which the whole patch fits inside the image, including the last one.

    :param image: the raw image, a 3-D array. The patch is cut from axes 1 and 2; axis 0 is kept whole.
    :param label: the label map, cropped with the same window as the image.
    :param patch_size: the side length of the square patch which will be sampled.
    :return: tuple ``(image_patch, label_patch)``.
    :raises ValueError: if ``patch_size`` is larger than the image along axis 1 or 2.
    """
    img_size = image.shape
    maxH = img_size[1] - patch_size
    maxW = img_size[2] - patch_size
    if maxH < 0 or maxW < 0:
        raise ValueError("patch_size is larger than the image")

    # Randomly select the patch origin. randint's upper bound is exclusive, so add 1
    # to make the last valid origin reachable and to allow patch_size == image size.
    xO = np.random.randint(0, maxH + 1)
    yO = np.random.randint(0, maxW + 1)

    # Select the same window from the image and the label.
    image_patch = image[:, xO:xO + patch_size, yO:yO + patch_size]
    label_patch = label[:, xO:xO + patch_size, yO:yO + patch_size]
    return image_patch, label_patch