forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdevice_checker.py
119 lines (112 loc) · 4.99 KB
/
device_checker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
## @package device_checker
# Module caffe2.python.device_checker
import numpy as np
import copy
from caffe2.python import workspace
from caffe2.python.core import InferOpBlobDevicesAsDict
class DeviceChecker:
    """A device checker in Python to check consistency across multiple devices.

    Every check runs the same operator or net once per entry of
    ``device_options`` inside the scratch workspace ``_device_check_``,
    fetches the requested blobs, and compares each device's results with
    those of the first device.

    This is not the most efficient way to check devices, as the Python interface
    will involve a lot of copies back and forth operations. Use at your own risk.
    """

    def __init__(self, threshold, device_options):
        """Create a checker.

        Args:
            threshold: tolerance passed to ``np.allclose`` as both ``atol``
                and ``rtol``.
            device_options: device options to run on; results from the first
                one are the reference the others are compared against.
        """
        self._threshold = threshold
        self._device_options = device_options

    def _compare_results(self, results, names):
        """Compare each device's fetched blobs against device 0's.

        Args:
            results: one list of fetched values per device, each ordered
                like ``names``.
            names: blob names, used only in failure messages.

        Returns:
            True if every blob has the reference's shape and matches it within
            the threshold; False otherwise (details are printed).
        """
        success = True
        for i in range(1, len(results)):
            for name, x, y in zip(names, results[i], results[0]):
                x = np.asarray(x)
                y = np.asarray(y)
                # np.allclose broadcasts, so e.g. shape (1,) would "match"
                # shape (5,); differing shapes must count as a failure (and
                # non-broadcastable shapes would otherwise raise).
                if x.shape != y.shape:
                    print('Failure in checking device option {}'
                          ' and output {}. The output shapes differ:'
                          ' {} vs {}'.format(i, name, x.shape, y.shape))
                    success = False
                elif not np.allclose(x, y, atol=self._threshold,
                                     rtol=self._threshold):
                    print('Failure in checking device option {}'
                          ' and output {}. The outputs are:'
                          .format(i, name))
                    print(x.flatten())
                    print(y.flatten())
                    print(np.max(np.abs(x - y)))
                    success = False
        return success

    def CheckSimple(self, op, inputs, outputs_to_check,
                    input_device_options=None):
        """Checks the operator with different device implementations.

        Inputs:
            op: the operator to be checked. It is deep-copied, so the caller's
                operator keeps its original device option.
            inputs: the input data, one entry per ``op.input`` (converted with
                ``np.array`` before feeding).
            outputs_to_check: indices into ``op.output`` of the outputs to
                compare between devices.
            input_device_options: a mapping from input name to a device to use
                (instead of self._device_options). When not given, the first
                element of ``InferOpBlobDevicesAsDict(op)`` is used; inputs
                missing from the mapping are fed on the device under test.
        Outputs:
            boolean: True if it passes, False if it does not pass.
        """
        op = copy.deepcopy(op)
        old_ws_name = workspace.CurrentWorkspace()
        results = []
        # Entering the checker workspace.
        workspace.SwitchWorkspace("_device_check_", True)
        try:
            for device_option in self._device_options:
                op.device_option.CopyFrom(device_option)
                _input_device_options = (
                    input_device_options or InferOpBlobDevicesAsDict(op)[0])
                try:
                    for input_idx, arr in enumerate(inputs):
                        workspace.FeedBlob(
                            op.input[input_idx], np.array(arr),
                            _input_device_options.get(
                                op.input[input_idx], device_option)
                        )
                    workspace.RunOperatorOnce(op)
                    results.append(
                        [workspace.FetchBlob(op.output[idx])
                         for idx in outputs_to_check])
                finally:
                    # Start each device (and any later check) from an empty
                    # workspace, even if the operator raised.
                    workspace.ResetWorkspace()
        finally:
            # Always hand the caller back its own workspace.
            workspace.SwitchWorkspace(old_ws_name)
        # After running on all devices, check correctness.
        names = [op.output[idx] for idx in outputs_to_check]
        return self._compare_results(results, names)

    def CheckNet(self, net, inputs=None, blobs_to_check=None, ignore=None):
        """Checks a network by inspecting all of its intermediate results, and
        see if things match.

        Inputs:
            net: the net to run on every device. It is deep-copied, so the
                caller's net keeps its original per-op device options.
            inputs: optional mapping from blob name to data, fed on the device
                under test before each run.
            blobs_to_check: names of blobs to compare; defaults to every
                output of every op in ``net``.
            ignore: optional collection of blob names to skip.
        Outputs:
            boolean: True if it passes, False if it does not pass.
        """
        if inputs is None:
            inputs = {}
        if ignore is None:
            ignore = set()
        net = copy.deepcopy(net)
        old_ws_name = workspace.CurrentWorkspace()
        results = []
        if blobs_to_check is None:
            blobs_to_check = [name for op in net.op for name in op.output]
        blobs_to_check = [b for b in blobs_to_check if b not in ignore]
        workspace.SwitchWorkspace("_device_check_", True)
        try:
            for device_option in self._device_options:
                try:
                    for name, arr in inputs.items():
                        workspace.FeedBlob(name, arr, device_option)
                    for op in net.op:
                        op.device_option.CopyFrom(device_option)
                    workspace.RunNetOnce(net)
                    results.append(
                        [workspace.FetchBlob(name) for name in blobs_to_check]
                    )
                finally:
                    # Do not let blobs from one device's run leak into the
                    # next device's run or into later checks.
                    workspace.ResetWorkspace()
        finally:
            # Always hand the caller back its own workspace.
            workspace.SwitchWorkspace(old_ws_name)
        # After running on all devices, check correctness.
        return self._compare_results(results, blobs_to_check)