-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathLog.lua
89 lines (75 loc) · 2.28 KB
/
Log.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
--[[
-- stores a log of various aspects of training for the purpose of visualization
]]
local json = require 'cjson'
local paths = require 'paths'
--- Serialises table `t` with cjson and writes it to `<file>.json`,
-- replacing any existing contents.
-- The table is encoded *before* the file is opened. If encoding raises,
-- the existing file is left untouched and no handle is leaked.
-- @param file path without the ".json" extension
-- @param t table to encode
-- Raises if the file cannot be opened or written (message includes the OS error).
local function write_json(file, t)
  local filename = file .. '.json'
  local encoded = json.encode(t)
  local f = assert(io.open(filename, 'w'))
  local ok, err = f:write(encoded)
  -- close before reporting so a failed write does not leak the handle
  f:close()
  assert(ok, err)
end
--- Reads `<file>.json` and decodes it with cjson.
-- @param file path without the ".json" extension
-- @return the decoded value, or nil when the file does not exist.
--   Also returns nil plus the io.open error message when the file cannot be
--   opened (e.g. it disappeared between the existence check and the open).
-- Raises if the contents are not valid JSON (json.decode's behaviour).
local function load_json(file)
  local filename = file .. '.json'
  if not paths.filep(filename) then
    return nil
  end
  local f, err = io.open(filename, 'r')
  if not f then
    return nil, err
  end
  local contents = f:read('*a')
  f:close()
  return json.decode(contents)
end
local Log = torch.class("Log")

--[[
-- Creates a log called `name` that is persisted to saveDir/<name>.json.
--
-- name          : log name; also the base filename of its JSON file
-- hyperparams   : value stored verbatim under "hyperparams" by Log:save
-- saveDir       : directory holding the log files; nothing here creates it,
--                 so it is assumed to exist already
-- xLabel        : name of the x axis (default "Iterations"), usually related
--                 to time / number of batches / epochs
-- saveFrequency : Log:update calls Log:save every saveFrequency updates;
--                 0 (the default) disables automatic saving
--
-- Side effects: creates saveDir/<name>.json containing an empty table if it
-- does not exist yet. Also rewrites saveDir/index.json with the sorted base
-- names of the regular *.json files in saveDir, excluding index.json itself
-- and hidden files.
]]
function Log:__init(name, hyperparams, saveDir, xLabel, saveFrequency)
  self.name = name
  self.xLabel = xLabel or "Iterations"
  self.hyperparams = hyperparams
  self.saveLoc = paths.concat(saveDir, name)
  self.saveFrequency = saveFrequency or 0
  self.data = {}
  self.updatesCounter = 0

  if not paths.filep(self.saveLoc .. '.json') then
    write_json(self.saveLoc, {})
  end

  -- Rebuild the index of logs in saveDir. paths.files interprets a string
  -- filter as a Lua pattern (matched with string.find). The dot must be
  -- escaped and the suffix anchored; an unescaped '.json' would also match
  -- names such as "run.json.bak" or "xjson".
  local models = {}
  for f in paths.files(saveDir, '%.json$') do
    if f ~= "index.json" and f:sub(1, 1) ~= '.'
        and paths.filep(paths.concat(saveDir, f)) then
      models[#models + 1] = f:sub(1, -6) -- strip the 5-char ".json" suffix
    end
  end
  -- directory listing order is unspecified; sort so the index is stable
  table.sort(models)
  write_json(paths.concat(saveDir, 'index'), models)
end
--[[
-- Records one data point per tracked statistic.
-- For every (name, y) pair in the dictionary `ys`, appends {x = x, y = y} to
-- the series self.data[name], creating that series the first time the name
-- is seen. When `x` is omitted it defaults to the number of updates made so
-- far (starting at 0). If saveFrequency > 0, the log is saved after every
-- saveFrequency-th update.
]]
function Log:update(ys, x)
  if not x then
    x = self.updatesCounter
  end

  for statName, value in pairs(ys) do
    local series = self.data[statName]
    if not series then
      series = {}
      self.data[statName] = series
    end
    series[#series + 1] = { x = x, y = value }
  end

  local count = self.updatesCounter + 1
  self.updatesCounter = count

  local every = self.saveFrequency
  if every > 0 and count % every == 0 then
    self:save()
  end
end
--[[
-- Writes the whole log to saveDir/<name>.json, overwriting the previous file.
-- The JSON object holds the log's name, x-axis label, hyperparameters, every
-- recorded series, and `stats` (an optional table of extra statistics;
-- an empty table when omitted).
]]
function Log:save(stats)
  local payload = {
    name = self.name,
    xLabel = self.xLabel,
    hyperparams = self.hyperparams,
    data = self.data,
    stats = stats or {},
  }
  write_json(self.saveLoc, payload)
end