-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbase.py
147 lines (117 loc) · 3.94 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from __future__ import division
import os.path as osp
import numpy as np
import tensorflow as tf
def make_session(frac=None):
    """
    Build a tf.Session.

    :param frac: if given, cap this process's GPU memory allocation at
        this fraction of the device's memory; if None, use TF's
        default session configuration.
    :returns: a new tf.Session
    """
    if frac is None:
        return tf.Session()
    gpu_opts = tf.GPUOptions(per_process_gpu_memory_fraction=frac)
    config = tf.ConfigProto(gpu_options=gpu_opts)
    return tf.Session(config=config)
def make_placeholders(variables, dtype='tf.float32'):
    """
    Build exec-able command strings that create tf.placeholder variables.

    :param variables: a dict {name: shape} or {name: (shape, dtype_str)}
        <shape> is a tuple of ints
        <dtype_str> is a str like 'tf.float32', 'tf.int32', etc
    :param dtype: fallback dtype string used for any placeholder whose
        entry does not carry its own dtype_str (default 'tf.float32').
    :returns: a list of command strings, one per placeholder.

    Usage:
        variables = {
            'X_pl': (4, 5),
            'Y_pl': ((2, 3), 'tf.int32')
        }
        for var in make_placeholders(variables):
            exec(var)
        Z = X_pl + Y_pl

    NOTE(review): callers exec() the returned strings -- only pass
    trusted names/shapes into this function.
    """
    # `basestring` exists only on Python 2; fall back to `str` on Python 3
    # (previously this raised NameError on Py3 for dtype-carrying entries).
    try:
        string_types = basestring
    except NameError:
        string_types = str
    commands = []
    for name, args in variables.items():
        if len(args) == 2 and isinstance(args[1], string_types):
            # Entry carries its own dtype: (shape, dtype_str).
            shape, ph_dtype = args
        else:
            # BUGFIX: previously hard-coded 'tf.float32' here, silently
            # ignoring the `dtype` argument (contradicting the docstring).
            shape, ph_dtype = args, dtype
        commands.append(
            "{0} = tf.placeholder({2}, {1}, name='{0}')".format(
                name, shape, ph_dtype
            )
        )
    return commands
def scoped_variable(var_name, scope_name, **kwargs):
    """
    Fetch variable `var_name` from scope `scope_name`, creating it on
    first use.

    The first attempt opens the scope without reuse (which creates the
    variable); if TF raises ValueError because the variable already
    exists, the lookup is retried with reuse=True.

    **kwargs are forwarded to tf.get_variable when the variable is
    created.
    :param var_name: the variable name
    :param scope_name: the scope name
    """
    try:
        with tf.variable_scope(scope_name):
            return tf.get_variable(var_name, **kwargs)
    except ValueError:
        # The variable already exists -- reopen the scope in reuse mode.
        with tf.variable_scope(scope_name, reuse=True):
            return tf.get_variable(var_name, **kwargs)
def make_scoped_cell(CellType, **scope_kwargs):
    """
    Take a cell from `tf.nn.rnn_cell`,
    and make a version of it that sets `reuse` in its scope as needed.
    For example:
    ```
    from tf.nn.rnn_cell import BasicLSTMCell
    ScopedLSTMCell = tfu.make_scoped_cell(BasicLSTMCell)
    ```
    Now, `ScopedLSTMCell` can be used in place of `BasicLSTMCell`,
    and it should take care of reusing correctly.

    :param CellType: the RNN-cell class to wrap
    :param scope_kwargs: extra kwargs forwarded to tf.variable_scope
        on every call
    :returns: a subclass of CellType whose constructor takes a leading
        `scope_name` argument before CellType's own arguments
    """
    class ScopedCell(CellType):
        # NOTE(review): `self.name` is assigned *before* the parent
        # __init__ runs; this presumably does not collide with a `name`
        # attribute/property on CellType -- confirm for the TF version
        # in use.
        def __init__(self, scope_name, *args, **kwargs):
            self.name = scope_name
            super(ScopedCell, self).__init__(*args, **kwargs)
        # First call creates the cell's variables inside the scope; any
        # later call raises ValueError (variables already exist) and is
        # retried with reuse=True so the same weights are shared.
        def __call__(self, X, H):
            try:
                with tf.variable_scope(self.name, **scope_kwargs) as scope:
                    return super(ScopedCell, self).__call__(X, H)
            except ValueError:
                with tf.variable_scope(self.name, reuse=True, **scope_kwargs) as scope:
                    return super(ScopedCell, self).__call__(X, H)
    # Give the wrapper class a readable name, e.g. "ScopedBasicLSTMCell".
    ScopedCell.__name__ = "Scoped%s" % CellType.__name__
    return ScopedCell
class struct(dict):
    """
    A dict whose entries are also readable and writable as attributes.

    Attribute and item access share one storage: the instance __dict__
    *is* the dict itself, so `s.x` and `s['x']` always agree.
    """
    def __init__(self, **entries):
        super(struct, self).__init__(**entries)
        self.__dict__ = self
def structify(obj):
    """
    Recursively convert every dict inside `obj` into a `tfu.struct`.

    Dicts become structs (keys preserved, values structified); lists
    are rebuilt with structified elements; anything else is returned
    unchanged.
    """
    if isinstance(obj, dict):
        return struct(**{k: structify(v) for k, v in obj.items()})
    if isinstance(obj, list):
        return [structify(v) for v in obj]
    return obj
"""
Below here should not be exposed.
"""
def _default_value(params_dict, key, value):
if key not in params_dict:
params_dict[key] = value
def _validate_axes(axes):
if isinstance(axes, int):
axes = [axes]
assert axes is None or isinstance(axes, list)
return axes
def _get_existing_vars(names_and_scopes):
    """
    Collect already-created variables for each (name, scope) pair.

    Returns a dict {(name, scope): variable}. Pairs whose lookup raises
    ValueError are silently skipped -- presumably those variables were
    never created (scoped_variable is called here without shape kwargs).
    """
    found = {}
    for var_name, scope_name in names_and_scopes:
        try:
            found[(var_name, scope_name)] = scoped_variable(var_name, scope_name)
        except ValueError:
            continue
    return found