-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
Copy pathdata_feed.py
73 lines (62 loc) · 2.14 KB
/
data_feed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# coding=utf-8
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import os
from collections import OrderedDict
import cv2
import numpy as np
from PIL import Image, ImageEnhance
from paddle import fluid
# Side length, in pixels, of the square crop produced by process_image().
DATA_DIM = 224
# Per-channel normalization statistics, shaped (3, 1, 1) so they broadcast
# over the channel-first (C, H, W) arrays built in process_image(); the images
# are converted to RGB first, so the entries are in R, G, B order.
# NOTE(review): the values look like the usual ImageNet mean/std -- confirm
# against the configuration the model was trained with.
img_mean = np.array([0.485, 0.456, 0.406])[:, np.newaxis, np.newaxis]
img_std = np.array([0.229, 0.224, 0.225])[:, np.newaxis, np.newaxis]
def resize_short(img, target_size):
    """Rescale ``img`` so that its shorter side becomes ``target_size``.

    Both sides are multiplied by the same factor, so the aspect ratio is
    preserved up to rounding to whole pixels.

    :param img: image exposing ``size`` as ``(width, height)`` and ``resize``.
    :type img: PIL.Image.Image
    :param target_size: desired length of the shorter side, in pixels.
    :type target_size: int
    :return: the resized image, resampled with the LANCZOS filter.
    """
    src_width, src_height = img.size[0], img.size[1]
    scale = float(target_size) / min(src_width, src_height)
    new_size = (int(round(src_width * scale)), int(round(src_height * scale)))
    return img.resize(new_size, Image.LANCZOS)
def crop_image(img, target_size, center):
    """Crop a ``target_size`` x ``target_size`` square out of ``img``.

    :param img: image exposing ``size`` as ``(width, height)`` and
        ``crop(box)``.
    :type img: PIL.Image.Image
    :param target_size: side length of the square crop, in pixels.
    :type target_size: int
    :param center: if truthy, take the centered crop; otherwise choose the
        top-left corner uniformly at random among the positions that keep
        the crop inside the image.
    :type center: bool
    :return: the result of ``img.crop`` for the selected
        ``(left, top, right, bottom)`` box.
    """
    width, height = img.size
    if center:
        # This module enables ``from __future__ import division``, so ``/``
        # would produce float (possibly half-pixel) offsets and leave the
        # final box to the image library's implicit rounding. Floor division
        # keeps the box integral and exactly ``target_size`` wide and high.
        left = (width - target_size) // 2
        top = (height - target_size) // 2
    else:
        # randint's upper bound is exclusive, hence the ``+ 1``.
        left = np.random.randint(0, width - target_size + 1)
        top = np.random.randint(0, height - target_size + 1)
    return img.crop((left, top, left + target_size, top + target_size))
def process_image(img):
    """Turn an image into a normalized channel-first float32 array.

    Steps, in order: resize so the shorter side is 256 pixels, take the
    centered ``DATA_DIM`` x ``DATA_DIM`` crop, convert to RGB if the image is
    in another mode, scale pixel values by 1/255, move channels to the front
    and normalize each channel with ``img_mean`` / ``img_std``.

    :param img: input image.
    :type img: PIL.Image.Image
    :return: float32 array laid out as (channels, height, width).
    :rtype: numpy.ndarray
    """
    resized = resize_short(img, target_size=256)
    cropped = crop_image(resized, target_size=DATA_DIM, center=True)
    rgb = cropped if cropped.mode == 'RGB' else cropped.convert('RGB')
    # (H, W, C) uint8 -> (C, H, W) float32 scaled by 1/255.
    chw = np.array(rgb).astype('float32').transpose((2, 0, 1)) / 255
    # In-place arithmetic keeps the float32 dtype even though the statistics
    # are float64.
    chw -= img_mean
    chw /= img_std
    return chw
def test_reader(paths=None, images=None):
    """Generate preprocessed images for prediction.

    Images read from ``paths`` are yielded first, followed by those given in
    ``images``. Each yielded item is the array returned by
    ``process_image``.

    :param paths: path to images.
    :type paths: list, each element is a str
    :param images: data of images, [N, H, W, C]
    :type images: numpy.ndarray
    :raises AssertionError: if an entry of ``paths`` is not an existing file.
    """
    if paths:
        # Materialize so the paths can be walked twice even if an iterator
        # was passed in.
        path_list = list(paths)
        # Validate every path before yielding anything, so a bad entry fails
        # on the first ``next()`` call. Raised explicitly (rather than via
        # ``assert``) so the check is not stripped under ``python -O``; the
        # exception type and message are unchanged.
        for img_path in path_list:
            if not os.path.isfile(img_path):
                raise AssertionError(
                    "The {} isn't a valid file path.".format(img_path))
        # Open, process and close one file at a time instead of holding an
        # open handle for every input until the generator is exhausted.
        for img_path in path_list:
            with Image.open(img_path) as img:
                processed = process_image(img)
            yield processed
    if images is not None:
        for img in images:
            yield process_image(Image.fromarray(np.uint8(img)))