-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathGraph.lua
39 lines (34 loc) · 890 Bytes
/
Graph.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
--- Graph utilities operating on torch tensors of directed edges.
-- Edge tensors are read row by row; only columns 1 and 2 of each row are
-- used, as (source, target). NOTE(review): presumably an N x 2 LongTensor —
-- confirm with callers.
local graph = {}

--- Convert a tensor of directed edges into an adjacency list.
-- @param edges tensor whose rows are (source, target) pairs. A tensor with
--   no elements is accepted and yields an empty adjacency list.
-- @treturn table mapping each source node to an array of its targets, in
--   the order the edges appear (duplicate edges give duplicate targets).
--   Nodes without outgoing edges have no entry.
function graph.tensorToAdj(edges)
  local adj = {}
  -- An empty torch tensor is 0-dimensional and size(1) raises on it,
  -- so return the empty adjacency list before touching its dimensions.
  if edges:nElement() == 0 then
    return adj
  end
  for i = 1, edges:size(1) do
    local edge = edges[i]
    local s, e = edge[1], edge[2]
    local targets = adj[s]
    if not targets then
      targets = {}
      adj[s] = targets
    end
    targets[#targets + 1] = e
  end
  return adj
end
--- Compute the transitive closure of a set of directed edges.
-- For every node s that has at least one outgoing edge, emits one (s, t)
-- pair for each distinct node t reachable from s by a path of one or more
-- edges. Each pair appears at most once and cyclic inputs terminate.
-- Self-pairs (s, s) are never emitted, even when s lies on a cycle.
-- @param edges tensor of (source, target) rows, as accepted by
--   graph.tensorToAdj.
-- @return torch.LongTensor of (source, target) rows; row order is
--   unspecified because start nodes are enumerated with pairs().
function graph.transitiveClosure(edges)
  local adj = graph.tensorToAdj(edges)
  local closure = {}
  for source in pairs(adj) do
    -- Iterative DFS with a per-source visited set: a node is expanded at
    -- most once per source, so cycles cannot recurse forever and nodes
    -- reachable along several paths (e.g. DAG diamonds) yield one pair.
    -- Marking the source visited up front suppresses (source, source).
    local visited = { [source] = true }
    local stack = { source }
    while #stack > 0 do
      local node = stack[#stack]
      stack[#stack] = nil
      local targets = adj[node]
      if targets then
        for _, target in ipairs(targets) do
          if not visited[target] then
            visited[target] = true
            closure[#closure + 1] = { source, target }
            stack[#stack + 1] = target
          end
        end
      end
    end
  end
  return torch.LongTensor(closure)
end

return graph