memory.py
"""
Gradient memory module.

Keep unused gradient coordinates in memory and use them later.
Can be used with a random or top-k sparsifier.
"""
import numpy as np

import qsgd
class GradientMemory:
    def __init__(self, take_k=None, take_top=False, with_memory=False, qsgd_s=None):
        self.with_memory = with_memory
        self.take_top = take_top
        self.take_k = take_k
        self.qsgd_s = qsgd_s

        # residual memory: gradient coordinates not yet emitted
        self.m = None

    def __call__(self, g, sparse=False):  # , no_apply=False):
        # QSGD quantization replaces sparsification entirely
        if self.qsgd_s:
            return qsgd.quantize(g, self.qsgd_s)

        # no sparsification requested: pass the gradient through unchanged
        if not self.take_k:
            return g

        if self.with_memory:
            # create the memory if it does not exist yet
            if self.m is None:
                self.m = np.zeros(g.shape, dtype=np.float64)
            self.m += g
        else:
            # without memory, operate on the incoming gradient directly;
            # note that the selected coordinates of g are zeroed in place below
            self.m = g

        # for k < 1 sometimes no gradient is used from the memory
        # if no_apply:
        #     return None

        d = np.prod(self.m.shape)
        k = min(self.take_k, d)
        if self.take_top:
            # top-k: the k coordinates with the largest magnitudes
            indices = np.argpartition(np.abs(self.m.ravel()), -k)[-k:]
        else:
            # random-k: k coordinates chosen uniformly without replacement
            indices = np.random.choice(d, k, replace=False)

        if not sparse:
            # dense output: zeros everywhere except the selected coordinates
            out_grad = np.zeros_like(self.m)
            out_grad[indices] = self.m[indices]
        else:
            # sparse output: (indices, values) pair
            out_grad = (indices, self.m[indices])

        # emitted coordinates leave the memory; the rest is kept for later
        self.m[indices] = 0.
        return out_grad
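

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module): drive
# the memory with a hypothetical 1-D gradient and a top-2 sparsifier, and
# watch the residual accumulate across calls.
if __name__ == "__main__":
    mem = GradientMemory(take_k=2, take_top=True, with_memory=True)

    g = np.array([0.1, -3.0, 0.5, 2.0])
    out = mem(g)
    print(out)    # dense output keeps only the 2 largest-magnitude entries: [ 0. -3.  0.  2.]
    print(mem.m)  # the remainder stays in memory: [0.1 0.  0.5 0. ]

    # the residual is added to the next gradient before selection
    g2 = np.array([1.0, 0.0, 0.0, 0.0])
    indices, values = mem(g2, sparse=True)
    print(indices, values)  # values 1.1 and 0.5 at indices 0 and 2 (order unspecified)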