data.py
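
"""tf.data input pipelines for the EuroSAT/RGB dataset.

Provides augmented, shuffled training batches plus validation and test
splits, returned as numpy iterators via tfds.as_numpy.
"""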
import logging

import tensorflow as tf
import tensorflow_datasets as tfds

AUTOTUNE = tf.data.experimental.AUTOTUNE


def _convert_dtype(x):
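    """Cast a uint8 image tensor to float32 in [0.0, 1.0]."""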
    return tf.cast(x, tf.float32) / 255


@tf.autograph.experimental.do_not_convert
def _augment_and_convert_dtype(x, y):
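    """Randomly rotate, flip, and colour-jitter an image; cast to float32 in [0, 1]."""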
    # rotate 0, 90, 180 or 270 deg (maxval is exclusive, so sample from [0, 4))
    k = tf.random.uniform([], 0, 4, dtype=tf.int32)
    x = tf.image.rot90(x, k)
    # flip L/R 50% of the time
    x = tf.image.random_flip_left_right(x)
    # convert to float
    x = _convert_dtype(x)
    # colour distortion
    x = tf.image.random_saturation(x, 0.5, 1.5)
    x = tf.image.random_brightness(x, 0.1)
    x = tf.image.random_contrast(x, 0.7, 1.3)
    x = tf.clip_by_value(x, 0.0, 1.0)
    return x, y


def _non_training_dataset(batch_size, ds_split):
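    """Load a eurosat/rgb split as (image, label) pairs, cast images to float32, and batch."""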
    @tf.autograph.experimental.do_not_convert
    def _convert_image_dtype(x, y):
        return _convert_dtype(x), y

    dataset = (tfds.load('eurosat/rgb', split=ds_split, as_supervised=True)
               .map(_convert_image_dtype, num_parallel_calls=AUTOTUNE)
               .batch(batch_size))
    return tfds.as_numpy(dataset)


def validation_dataset(batch_size, sample_data=False):
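    """Validation split of eurosat/rgb (train[80%:90%])."""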
    # 2700 records
    # [293, 307, 335, 258, 253, 194, 239, 284, 243, 294]
    if sample_data:
        split = 'train[80%:82%]'
    else:
        split = 'train[80%:90%]'
    logging.debug("validation_dataset %s", split)
    return _non_training_dataset(batch_size, split)


def test_dataset(batch_size, sample_data=False):
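    """Test split of eurosat/rgb (train[90%:])."""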
    # 2700 records
    # [307, 300, 296, 221, 262, 216, 251, 296, 250, 301]
    if sample_data:
        split = 'train[90%:92%]'
    else:
        split = 'train[90%:]'
    logging.debug("test_dataset %s", split)
    return _non_training_dataset(batch_size, split)


def training_dataset(batch_size, shuffle_seed, num_inputs=1, sample_data=False):
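    """Augmented, shuffled training split of eurosat/rgb (train[:80%]).

    If num_inputs > 1, each batch is reshaped to (num_inputs, batch_size, ...),
    e.g. one leading axis entry per device.
    """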
logging.debug("training_dataset shuffle_seed %d" % shuffle_seed)
if sample_data:
logging.warn("using small sample_data for training")
split = 'train[:2%]'
else:
split = 'train[:80%]'
logging.debug("training_dataset %s" % split)
dataset = (tfds.load('eurosat/rgb', split=split,
as_supervised=True)
.map(_augment_and_convert_dtype, num_parallel_calls=AUTOTUNE)
.shuffle(1024, seed=shuffle_seed))
if num_inputs == 1:
dataset = dataset.batch(batch_size)
else:
# TODO: don't use this in v2?
@tf.autograph.experimental.do_not_convert
def _reshape_inputs(x, y):
_b, h, w, c = x.shape
x = tf.reshape(x, (num_inputs, batch_size, h, w, c))
y = tf.reshape(y, (num_inputs, batch_size))
return x, y
dataset = dataset.batch(batch_size * num_inputs, drop_remainder=True)
dataset = dataset.map(_reshape_inputs)
pass
dataset = dataset.prefetch(AUTOTUNE)
return tfds.as_numpy(dataset)
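

if __name__ == "__main__":
    # Minimal usage sketch: pull one batch from the small sample-data
    # training split and print its shapes. Assumes tensorflow_datasets can
    # fetch eurosat/rgb on first use; shapes shown are for EuroSAT's
    # 64x64 RGB images.
    logging.basicConfig(level=logging.DEBUG)
    for x, y in training_dataset(batch_size=8, shuffle_seed=0, sample_data=True):
        print(x.shape, x.dtype)  # (8, 64, 64, 3) float32
        print(y.shape, y.dtype)  # (8,) int64
        break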