-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathtyped.lua
172 lines (152 loc) · 4.12 KB
/
typed.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
--------------------------------------------------------------------------------
-- Lua programming with types
--------------------------------------------------------------------------------
-- Prefer the optional "inspect" library for readable values in error
-- messages, falling back to tostring when it cannot be loaded.
local has_inspect, inspect = pcall(require, "inspect")
if not has_inspect then
  -- On failure pcall's second result is the error message (a truthy string),
  -- so it must be replaced explicitly rather than with `inspect or tostring`.
  inspect = tostring
end
local typed = {}
-- When true, tables are accepted as sequences / `{T}` values without
-- scanning their contents.
local FAST = false
--- Report whether `xs` is a table whose keys are exactly the integers 1..#xs.
-- Empty tables count as sequences. When FAST is set, any table is accepted
-- without inspecting its keys.
-- @param xs value to test
-- @treturn boolean
local function is_sequence(xs)
  if type(xs) ~= "table" then
    return false
  end
  if FAST then
    return true
  end
  local l = #xs
  local count = 0
  for k in pairs(xs) do
    if type(k) ~= "number" or k < 1 or k > l or math.floor(k) ~= k then
      return false
    end
    count = count + 1
  end
  -- `#` may report any border when the table has holes (e.g. {1, nil, 3}
  -- can yield 3), so also require every index 1..l to be present. All keys
  -- seen are distinct integers in 1..l, so count == l means no gaps.
  return count == l
end
--- Name the type of `t` for checking and error messages.
-- A truthy metatable `__name` wins; otherwise sequence tables are reported
-- as "array" and everything else as the primitive `type(t)`.
-- @param t any value
-- @treturn string
local function type_of(t)
  local meta = getmetatable(t)
  if meta then
    local name = meta.__name
    if name then
      return name
    end
  end
  if is_sequence(t) then
    return "array"
  end
  return type(t)
end
--- Tag table `t` with the type name `typ` via its metatable's `__name`.
-- An existing metatable is reused (and mutated) rather than replaced.
-- @tparam table t the table to tag
-- @tparam string typ the type name to record
-- @treturn table `t` itself
local function set_type(t, typ)
  local meta = getmetatable(t) or {}
  meta.__name = typ
  return setmetatable(t, meta)
end
--- Tag table `t` with type name `typ`; type-first argument order counterpart
-- of `set_type`.
-- @tparam string typ the type name to record
-- @tparam table t the table to tag
-- @treturn table `t` itself
local function typed_table(typ, t)
  local tagged = set_type(t, typ)
  return tagged
end
--- Test `val` against the type description `expected` without raising.
-- Supported forms: "T?" (optional, nil accepted), "{T}" (table whose
-- sequence elements all satisfy T), "table" (any table), or a name
-- compared against `type_of(val)`.
-- @param val the value to check
-- @tparam string expected the type description
-- @return true on success; otherwise nil and the actual type name
local function try_check(val, expected)
  -- A trailing "?" makes the type optional: nil passes outright.
  local base = expected:match("^(.*)%?$")
  if base ~= nil then
    if val == nil then
      return true
    end
    expected = base
  end

  -- "{T}": every element reachable by ipairs must satisfy T.
  local elem_type = expected:match("^{(.+)}$")
  if elem_type ~= nil and type(val) == "table" then
    if FAST then
      return true
    end
    local every_elem_ok = true
    for _, elem in ipairs(val) do
      if not try_check(elem, elem_type) then
        every_elem_ok = false
        break
      end
    end
    if every_elem_ok then
      return true
    end
  end

  -- Plain "table" accepts any table regardless of metatable or contents.
  if type(val) == "table" and expected == "table" then
    return true
  end

  local actual = type_of(val)
  if expected == actual then
    return true
  end
  return nil, actual
end
--- Assert that `val` matches the type description `expected`.
-- Returns true on success; otherwise raises a "type error" message that
-- includes the actual type and `inspect(val)`. When both `category` and `n`
-- are given, the message names them (e.g. "argument 2"). The error is then
-- attributed to level 2 for category "value" and level 3 otherwise;
-- without them it is attributed to level 2.
-- @param val the value to check
-- @tparam string expected the type description
-- @tparam[opt] string category label such as "argument" or "return"
-- @tparam[opt] number n position used in the message
-- @treturn boolean true
local function typed_check(val, expected, category, n)
  local ok, actual = try_check(val, expected)
  if ok then
    return true
  end
  local shown = inspect(val)
  if not (category and n) then
    error(("type error: expected %s, got %s (%s)"):format(expected, actual, shown), 2)
  end
  local level = 3
  if category == "value" then
    level = 2
  end
  error(("type error: %s %d: expected %s, got %s (%s)"):format(category, n, expected, actual, shown), level)
end
--- Split `s` on every match of the Lua pattern `sep`.
-- Always returns at least one piece; empty pieces are kept.
-- @tparam string s the string to split
-- @tparam string sep separator pattern
-- @treturn {string,...} the pieces in order
local function split(s, sep)
  local pieces = {}
  local start = 1
  local first, last = s:find(sep, start)
  while first ~= nil do
    pieces[#pieces + 1] = s:sub(start, first - 1)
    start = last + 1
    first, last = s:find(sep, start)
  end
  pieces[#pieces + 1] = s:sub(start)
  return pieces
end
--- Wrap `fn` so its arguments and results are checked against a signature.
-- The signature has the form "in1, in2 -> out1, out2", where each entry is a
-- type description understood by `typed_check`. Either side may be "()" to
-- declare no values. If the first output is "*", result types are not
-- checked, but the number of results still must match the output list.
-- @tparam string types the signature string
-- @tparam function fn the function to wrap
-- @treturn function wrapper that validates inputs and outputs around `fn`
-- @raise if `types` does not contain a "->" with text on both sides
local function typed_function(types, fn)
  local inp, outp = types:match("(.*[^%s])%s*%->%s*([^%s].*)")
  if not inp then
    error("invalid type signature: " .. tostring(types), 2)
  end
  -- "()" declares an empty list; splitting it would yield the single type "()".
  local ins = (inp == "()") and {} or split(inp, ",%s*")
  local outs = (outp == "()") and {} or split(outp, ",%s*")
  return function(...)
    -- table.pack records trailing nils in `n`, so the arity check is exact.
    local args = table.pack(...)
    if args.n ~= #ins then
      error("wrong number of inputs (given " .. args.n .. " - expects " .. types .. ")", 2)
    end
    for i = 1, #ins do
      typed_check(args[i], ins[i], "argument", i)
    end
    local rets = table.pack(fn(...))
    if outp == "()" then
      if rets.n ~= 0 then
        error("wrong number of outputs (given " .. rets.n .. " - expects " .. types .. ")", 2)
      end
    else
      if rets.n ~= #outs then
        error("wrong number of outputs (given " .. rets.n .. " - expects " .. types .. ")", 2)
      end
      if outs[1] ~= "*" then
        for i = 1, #outs do
          typed_check(rets[i], outs[i], "return", i)
        end
      end
    end
    return table.unpack(rets, 1, rets.n)
  end
end
-- Metatables that make the module table callable as typed(signature, fn).
-- With checking on, the call builds a checking wrapper around fn.
local typed_mt_on = {}
function typed_mt_on.__call(_self, types, fn)
  return typed_function(types, fn)
end

-- With checking off, the call returns fn unchanged.
local typed_mt_off = {}
function typed_mt_off.__call(_self, _types, fn)
  return fn
end
--- Enable type checking: `typed.check`, `typed.typed`, `typed.set_type` and
-- `typed.table` become the real implementations, and calling `typed(...)`
-- wraps functions with signature checks.
function typed.on()
  typed.check = typed_check
  typed.typed = typed_function
  typed.set_type = set_type
  typed.table = typed_table
  setmetatable(typed, typed_mt_on)
end

--- Disable type checking: every entry point becomes a pass-through.
-- `typed.check` still returns true, matching the real checker's success
-- result, so code like `assert(typed.check(v, t))` works in both modes.
function typed.off()
  typed.check = function() return true end
  typed.typed = function(_, fn) return fn end
  typed.set_type = function(t, _) return t end
  typed.table = function(_, t) return t end
  setmetatable(typed, typed_mt_off)
end

-- Checking starts disabled; callers opt in with typed.on().
typed.off()
return typed