forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathThreadLocalDebugInfo.h
46 lines (35 loc) · 1.17 KB
/
ThreadLocalDebugInfo.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
#pragma once
#include <c10/macros/Export.h>
#include <memory>
#include <string>
namespace at {
// Base class for thread local debug information.
//
// The debug information is propagated across the forward pass (including
// async fork tasks) and the backward pass. It is meant to be used by the
// user's code to hand extra context from the higher layers (e.g. a model
// id) down to the operator callbacks (e.g. for logging).
//
// The class carries no state of its own; the virtual destructor lets a
// derived object be destroyed through a pointer to this base.
class CAFFE2_API ThreadLocalDebugInfoBase {
 public:
  ThreadLocalDebugInfoBase() = default;
  virtual ~ThreadLocalDebugInfoBase() = default;
};
// Returns the thread local debug information as a shared pointer to the
// ThreadLocalDebugInfoBase interface. Declared noexcept.
// NOTE(review): presumably returns a null pointer when nothing has been
// set on the calling thread -- confirm against the definition, which is
// not part of this header.
CAFFE2_API std::shared_ptr<ThreadLocalDebugInfoBase>
getThreadLocalDebugInfo() noexcept;
// Sets the thread local debug information and returns the debug
// information that was set before this call. Declared noexcept.
//
// `info` is taken by value, so callers may std::move their pointer in.
// The returned pointer can be passed back to this function later to
// restore the earlier state.
CAFFE2_API std::shared_ptr<ThreadLocalDebugInfoBase>
setThreadLocalDebugInfo(
std::shared_ptr<ThreadLocalDebugInfoBase> info) noexcept;
class CAFFE2_API DebugInfoGuard {
public:
explicit DebugInfoGuard(
std::shared_ptr<ThreadLocalDebugInfoBase> info) {
prev_info_ = setThreadLocalDebugInfo(std::move(info));
}
~DebugInfoGuard() {
setThreadLocalDebugInfo(std::move(prev_info_));
}
private:
std::shared_ptr<ThreadLocalDebugInfoBase> prev_info_;
};
} // namespace at