# dataset.py
import torch.utils.data as data
import json
import random
from PIL import Image
import numpy as np
import torch
import os
# Class names of the VisA dataset and a name -> integer id mapping.
Vis_CLSNAMES = ['candle', 'capsules', 'cashew', 'chewinggum', 'fryum', 'macaroni1', 'macaroni2',
                'pcb1', 'pcb2', 'pcb3', 'pcb4', 'pipe_fryum']
Vis_CLSNAMES_map_index = {k: index for index, k in enumerate(Vis_CLSNAMES)}

# Class names of the MVTec-AD dataset and a name -> integer id mapping.
CLSNAMES = ['carpet', 'bottle', 'hazelnut', 'leather', 'cable', 'capsule', 'grid', 'pill',
            'transistor', 'metal_nut', 'screw', 'toothbrush', 'zipper', 'tile', 'wood']
CLSNAMES_map_index = {k: index for index, k in enumerate(CLSNAMES)}
class VisaDataset(data.Dataset):
    def __init__(self, root, transform, target_transform, mode='test', k_shot=0, save_dir=None, obj_name=None):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.data_all = []
        meta_info = json.load(open(f'{self.root}/meta.json', 'r'))
        meta_info = meta_info[mode]
        if mode == 'train':
            # Few-shot training: sample k_shot images of a single object class and
            # log the sampled image paths so the split can be reproduced.
            self.cls_names = [obj_name]
            save_dir = os.path.join(save_dir, 'k_shot.txt')
        else:
            self.cls_names = list(meta_info.keys())
        for cls_name in self.cls_names:
            if mode == 'train':
                data_tmp = meta_info[cls_name]
                indices = torch.randint(0, len(data_tmp), (k_shot,))
                for i in range(len(indices)):
                    self.data_all.append(data_tmp[indices[i]])
                    with open(save_dir, "a") as f:
                        f.write(data_tmp[indices[i]]['img_path'] + '\n')
            else:
                self.data_all.extend(meta_info[cls_name])
        self.length = len(self.data_all)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        data = self.data_all[index]
        img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], \
            data['cls_name'], data['specie_name'], data['anomaly']
        img = Image.open(os.path.join(self.root, img_path))
        if anomaly == 0:
            # Normal image: an all-zero mask with the same (height, width) as the image.
            img_mask = Image.fromarray(np.zeros((img.size[1], img.size[0]), dtype=np.uint8), mode='L')
        else:
            # Anomalous image: binarize the ground-truth mask to {0, 255}.
            img_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0
            img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
        # transforms
        img = self.transform(img) if self.transform is not None else img
        img_mask = self.target_transform(
            img_mask) if self.target_transform is not None and img_mask is not None else img_mask
        img_mask = [] if img_mask is None else img_mask
        return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly,
                'img_path': os.path.join(self.root, img_path), 'cls_id': Vis_CLSNAMES_map_index[cls_name]}
class MVTecDataset(data.Dataset):
    def __init__(self, root, transform, target_transform, aug_rate, mode='test', k_shot=0, save_dir=None,
                 obj_name=None):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.aug_rate = aug_rate
        self.data_all = []
        meta_info = json.load(open(f'{self.root}/meta.json', 'r'))
        meta_info = meta_info[mode]
        if mode == 'train':
            # Few-shot training: obj_name may be a single class name or a list of class names.
            if isinstance(obj_name, list):
                self.cls_names = obj_name
            else:
                self.cls_names = [obj_name]
            save_dir = os.path.join(save_dir, 'k_shot.txt')
        else:
            self.cls_names = list(meta_info.keys())
        for cls_name in self.cls_names:
            if mode == 'train':
                data_tmp = meta_info[cls_name]
                indices = torch.randint(0, len(data_tmp), (k_shot,))
                for i in range(len(indices)):
                    self.data_all.append(data_tmp[indices[i]])
                    with open(save_dir, "a") as f:
                        f.write(data_tmp[indices[i]]['img_path'] + '\n')
            else:
                self.data_all.extend(meta_info[cls_name])
        self.length = len(self.data_all)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        data = self.data_all[index]
        img_path, mask_path, cls_name, specie_name, anomaly = data['img_path'], data['mask_path'], \
            data['cls_name'], data['specie_name'], data['anomaly']
        img = Image.open(os.path.join(self.root, img_path))
        if anomaly == 0:
            # Normal image: an all-zero mask with the same (height, width) as the image.
            img_mask = Image.fromarray(np.zeros((img.size[1], img.size[0]), dtype=np.uint8), mode='L')
        else:
            # Anomalous image: binarize the ground-truth mask to {0, 255}.
            img_mask = np.array(Image.open(os.path.join(self.root, mask_path)).convert('L')) > 0
            img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
        # transforms
        img = self.transform(img) if self.transform is not None else img
        img_mask = self.target_transform(
            img_mask) if self.target_transform is not None and img_mask is not None else img_mask
        img_mask = [] if img_mask is None else img_mask
        return {'img': img, 'img_mask': img_mask, 'cls_name': cls_name, 'anomaly': anomaly,
                'img_path': os.path.join(self.root, img_path), 'cls_id': CLSNAMES_map_index[cls_name]}
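

# A minimal usage sketch, not part of the original file: it assumes a local MVTec-AD copy
# at './data/mvtec' that already contains the meta.json index expected above, and it uses
# torchvision transforms to turn images and masks into tensors. The root path, image size,
# and aug_rate value are placeholders, not values prescribed by this module.
if __name__ == '__main__':
    from torchvision import transforms

    img_transform = transforms.Compose([
        transforms.Resize((518, 518)),
        transforms.ToTensor(),
    ])
    mask_transform = transforms.Compose([
        transforms.Resize((518, 518)),
        transforms.ToTensor(),
    ])

    # Test split: iterates over every class listed under meta.json['test'].
    test_set = MVTecDataset(root='./data/mvtec', transform=img_transform,
                            target_transform=mask_transform, aug_rate=0.0, mode='test')
    loader = data.DataLoader(test_set, batch_size=4, shuffle=False)
    batch = next(iter(loader))
    print(batch['img'].shape, batch['cls_name'], batch['anomaly'])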