-
Notifications
You must be signed in to change notification settings - Fork 124
/
Copy pathdataset.py
100 lines (68 loc) · 3.11 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Dataset loaders for VOC12 and Cityscapes (adapted from the bodokaiser/piwise code)
# Sept 2017
# Eduardo Romera
#######################
import numpy as np
import os
from PIL import Image
from torch.utils.data import Dataset
# Filename suffixes that is_image() accepts; the suffix match is case-sensitive.
EXTENSIONS = [
    '.jpg',
    '.png',
]
def load_image(file):
    """Open *file* with PIL and return the resulting Image object.

    Callers in this module pass a file object opened in ``'rb'`` mode and
    call ``.convert(...)`` on the result while that file is still open.
    No conversion or copying happens here.
    """
    opened = Image.open(file)
    return opened
def is_image(filename):
    """Return True if *filename* ends with any suffix listed in EXTENSIONS.

    The comparison is a plain, case-sensitive ``str.endswith`` check and
    stops at the first matching suffix.
    """
    for suffix in EXTENSIONS:
        if filename.endswith(suffix):
            return True
    return False
def is_label(filename):
    """Return True if *filename* ends with ``_labelTrainIds.png``."""
    suffix = "_labelTrainIds.png"
    return filename.endswith(suffix)
def image_path(root, basename, extension):
    """Build ``<root>/<basename><extension>`` with ``os.path.join``.

    *basename* and *extension* are formatted as strings and concatenated
    directly, so *extension* should include its leading dot.
    """
    file_name = '{}{}'.format(basename, extension)
    return os.path.join(root, file_name)
def image_path_city(root, name):
    """Join *root* with *name* (formatted as a string) via ``os.path.join``.

    Caveat: ``os.path.join`` discards *root* when *name* is an absolute
    path, so an absolute *name* is returned as-is. A relative *name* that
    already starts with *root* gets the prefix twice.
    """
    return os.path.join(root, '{}'.format(name))
def image_basename(filename):
    """Return the last path component of *filename* with its final extension removed."""
    stem_path, _extension = os.path.splitext(filename)
    return os.path.basename(stem_path)
class VOC12(Dataset):
    """Segmentation dataset stored as ``<root>/images`` and ``<root>/labels``.

    Sample names are the extension-stripped names of the entries in
    ``labels`` that ``is_image`` accepts, in sorted order. Item ``i`` loads
    ``images/<name>.jpg`` as an RGB image and ``labels/<name>.png`` as a
    palette (``'P'``) image.

    Args:
        root: directory that contains the ``images`` and ``labels`` folders.
        input_transform: optional callable applied to the RGB image.
        target_transform: optional callable applied to the label image.
    """

    def __init__(self, root, input_transform=None, target_transform=None):
        self.images_root = os.path.join(root, 'images')
        self.labels_root = os.path.join(root, 'labels')

        # Sample names are enumerated from the labels folder, not the images folder.
        label_entries = os.listdir(self.labels_root)
        self.filenames = sorted(
            image_basename(entry) for entry in label_entries if is_image(entry)
        )

        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample *index*, after optional transforms."""
        name = self.filenames[index]

        with open(image_path(self.images_root, name, '.jpg'), 'rb') as stream:
            image = load_image(stream).convert('RGB')
        with open(image_path(self.labels_root, name, '.png'), 'rb') as stream:
            label = load_image(stream).convert('P')

        image = image if self.input_transform is None else self.input_transform(image)
        label = label if self.target_transform is None else self.target_transform(label)
        return image, label

    def __len__(self):
        """Number of sample names collected from the labels folder."""
        return len(self.filenames)
class cityscapes(Dataset):
    """Cityscapes image/label pairs for one split.

    Images are collected recursively from ``<root>/leftImg8bit/<subset>``
    (files accepted by ``is_image``). Labels are collected from
    ``<root>/gtFine/<subset>`` (files ending in ``_labelTrainIds.png``).
    Both lists hold full paths, are sorted, and are paired by index.

    Args:
        root: dataset root directory; ``~`` is expanded when walking it.
        input_transform: optional callable applied to the RGB image.
        target_transform: optional callable applied to the palette label image.
        subset: split sub-directory name (defaults to ``'val'``).

    Raises:
        ValueError: if the number of images and labels found differ, because
            index-based pairing would then silently mismatch samples.
    """

    def __init__(self, root, input_transform=None, target_transform=None, subset='val'):
        self.images_root = os.path.join(root, 'leftImg8bit/' + subset)
        self.labels_root = os.path.join(root, 'gtFine/' + subset)

        # Full paths of every matching file under each root, sorted so that
        # position i in both lists refers to the same frame.
        self.filenames = sorted(
            os.path.join(dirpath, name)
            for dirpath, _, names in os.walk(os.path.expanduser(self.images_root))
            for name in names
            if is_image(name)
        )
        self.filenamesGt = sorted(
            os.path.join(dirpath, name)
            for dirpath, _, names in os.walk(os.path.expanduser(self.labels_root))
            for name in names
            if is_label(name)
        )

        # Pairing is purely positional, so a count mismatch would pair images
        # with the wrong labels. Fail loudly instead.
        if len(self.filenames) != len(self.filenamesGt):
            raise ValueError(
                f'cityscapes: found {len(self.filenames)} images under '
                f'{self.images_root!r} but {len(self.filenamesGt)} '
                f'*_labelTrainIds.png labels under {self.labels_root!r}'
            )

        self.input_transform = input_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image, label, image_path, label_path)`` for sample *index*."""
        filename = self.filenames[index]
        filenameGt = self.filenamesGt[index]

        # These are already full paths built by os.walk from the expanded
        # roots, so open them as-is. Joining them onto images_root/labels_root
        # again would duplicate the directory prefix whenever root is relative.
        with open(filename, 'rb') as f:
            image = load_image(f).convert('RGB')
        with open(filenameGt, 'rb') as f:
            label = load_image(f).convert('P')

        if self.input_transform is not None:
            image = self.input_transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return image, label, filename, filenameGt

    def __len__(self):
        """Number of image/label pairs."""
        return len(self.filenames)