forked from carpedm20/BEGAN-tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_loader.py
56 lines (44 loc) · 1.75 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
from PIL import Image
from glob import glob
import tensorflow as tf
def get_loader(root, batch_size, scale_size, data_format, split=None, is_grayscale=False, seed=None):
dataset_name = os.path.basename(root)
if dataset_name in ['CelebA'] and split:
root = os.path.join(root, 'splits', split)
for ext in ["jpg", "png"]:
paths = glob("{}/*.{}".format(root, ext))
if ext == "jpg":
tf_decode = tf.image.decode_jpeg
elif ext == "png":
tf_decode = tf.image.decode_png
if len(paths) != 0:
break
with Image.open(paths[0]) as img:
w, h = img.size
shape = [h, w, 3]
filename_queue = tf.train.string_input_producer(list(paths), shuffle=False, seed=seed)
reader = tf.WholeFileReader()
filename, data = reader.read(filename_queue)
image = tf_decode(data, channels=3)
if is_grayscale:
image = tf.image.rgb_to_grayscale(image)
image.set_shape(shape)
min_after_dequeue = 5000
capacity = min_after_dequeue + 3 * batch_size
queue = tf.train.shuffle_batch(
[image], batch_size=batch_size,
num_threads=4, capacity=capacity,
min_after_dequeue=min_after_dequeue, name='synthetic_inputs')
if dataset_name in ['CelebA']:
queue = tf.image.crop_to_bounding_box(queue, 50, 25, 128, 128)
queue = tf.image.resize_nearest_neighbor(queue, [scale_size, scale_size])
else:
queue = tf.image.resize_nearest_neighbor(queue, [scale_size, scale_size])
if data_format == 'NCHW':
queue = tf.transpose(queue, [0, 3, 1, 2])
elif data_format == 'NHWC':
pass
else:
raise Exception("[!] Unkown data_format: {}".format(data_format))
return tf.to_float(queue)