-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
56 lines (49 loc) · 2.37 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import torchvision.transforms as transforms
from PIL import Image
def print_examples(model, device, dataset):
    """Caption three fixed chest X-ray test images and print model output
    next to the ground-truth report for a quick qualitative check.

    Args:
        model: captioning model exposing ``caption_image(image, vocab)``
            plus the standard ``eval()``/``train()`` mode switches.
        device: torch device the image tensors are moved to.
        dataset: dataset object whose ``vocab`` is passed to the model.
            (Exact vocab type is project-defined — not visible here.)

    Side effects: prints to stdout; toggles the model to eval mode and
    restores train mode before returning.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    # (image path, ground-truth report) pairs; "XXXX" spans are anonymized
    # text in the original radiology reports.
    examples = [
        (
            "test_examples/CXR1_1_IM-0001-3001.png",
            "The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax.",
        ),
        (
            "test_examples/CXR2_IM-0652-1001.png",
            "Borderline cardiomegaly. Midline sternotomy XXXX. Enlarged pulmonary arteries. Clear lungs. Inferior XXXX XXXX XXXX.",
        ),
        (
            "test_examples/CXR4_IM-2050-1001.png",
            "There are diffuse bilateral interstitial and alveolar opacities consistent with chronic obstructive lung disease and bullous emphysema. There are irregular opacities in the left lung apex, that could represent a cavitary lesion in the left lung apex.There are streaky opacities in the right upper lobe, XXXX scarring. The cardiomediastinal silhouette is normal in size and contour. There is no pneumothorax or large pleural effusion.",
        ),
    ]
    model.eval()
    try:
        # Inference only — disable autograd to avoid building a graph.
        with torch.no_grad():
            for idx, (path, correct) in enumerate(examples, start=1):
                image = transform(Image.open(path).convert("RGB")).unsqueeze(0)
                print(f"Example {idx} CORRECT: {correct}")
                print(
                    f"Example {idx} OUTPUT: "
                    + " ".join(model.caption_image(image.to(device), dataset.vocab))
                )
    finally:
        # Restore training mode even if captioning one example fails.
        model.train()
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Persist a training-state dict (weights, optimizer state, step) to disk.

    Args:
        state: picklable dict handed straight to ``torch.save``.
        filename: destination path; defaults to ``my_checkpoint.pth.tar``
            in the current working directory.
    """
    print("=> Saving checkpoint")  # progress marker for console logs
    torch.save(state, filename)
def load_checkpoint(checkpoint, model, optimizer):
    """Restore model and optimizer state from a checkpoint dict.

    Args:
        checkpoint: dict with keys ``"state_dict"`` (model weights),
            ``"optimizer"`` (optimizer state), and ``"step"``.
        model: module whose ``load_state_dict`` receives the weights.
        optimizer: optimizer whose state is restored in place.

    Returns:
        The saved global training step (``checkpoint["step"]``).
    """
    print("=> Loading checkpoint")  # progress marker for console logs
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return checkpoint["step"]