-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheigenvi.jl
94 lines (82 loc) · 2.87 KB
/
eigenvi.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
include("bam.jl")
include("polynomials.jl")
using Arpack
"""
    solve_eigenvalue_problem(A; B=nothing, arpack=false)

Return one eigenpair of `A`, or of the generalized problem `A α = λ B α` when `B` is
given.

- Dense path (`arpack=false`): all eigenpairs are computed with `eigen`, and the pair
  whose eigenvalue has the smallest real part is selected. Non-finite eigenvalues
  (which a singular `B` can produce) are skipped unless every eigenvalue is non-finite.
- ARPACK path (`arpack=true`): `Arpack.eigs` with `nev=1, which=:SM` returns the
  eigenvalue of smallest magnitude. The generalized call additionally uses
  `maxiter=800, tol=1e-6`.

# Returns
A tuple `(α, A, λ)` containing:
- the selected eigenvector, rescaled to unit 2-norm;
- the input matrix `A` itself (unchanged);
- the selected eigenvalue.
"""
function solve_eigenvalue_problem(A; B=nothing, arpack=false)
    if arpack
        if B === nothing
            obj, α = Arpack.eigs(A; nev=1, which=:SM)
        else
            obj, α = Arpack.eigs(A, B; nev=1, which=:SM, maxiter=800, tol=1e-6)
        end
        v = vec(α)
        # Rescale in both branches so the returned vector always has unit 2-norm.
        return v / norm(v), A, obj[1]
    end

    evals, evecs = B === nothing ? eigen(A) : eigen(A, B)
    # A singular B yields ±Inf or NaN (0/0) generalized eigenvalues. `argmin` treats NaN
    # as smaller than every number (and -Inf would also win), so search only the finite
    # eigenvalues whenever there are any.
    candidates = findall(isfinite, evals)
    if isempty(candidates)
        candidates = collect(eachindex(evals))
    end
    min_ind = candidates[argmin(real.(evals[candidates]))]
    α = evecs[:, min_ind]
    return α / norm(α), A, evals[min_ind]
end
"""
    eigenVI_2D(K, X, dlogP, basis_fn, d_basis_fn; grads=false, denom=false, arpack=false, logPi=nothing)

Assemble the `K^2 × K^2` matrix `W` for a 2-D tensor-product basis from the samples in
`X`. Return the result of `solve_eigenvalue_problem(W; B=V, arpack=arpack)`, where
`V === nothing` unless `denom` is true.

For each sample `x = X[:, n]` and each coordinate `d ∈ (1, 2)`, a `K × K` feature array
is built, flattened column-major into `w`, and accumulated as `W += weight * w * w'`:

    d = 1:  w[k1, k2] = 2 * d_basis_fn(x[1], k1-1) * basis_fn(x[2], k2-1) - g[1] * Φ[k1, k2]
    d = 2:  w[k1, k2] = 2 * basis_fn(x[1], k1-1) * d_basis_fn(x[2], k2-1) - g[2] * Φ[k1, k2]

Here `Φ[k1, k2] = basis_fn(x[1], k1-1) * basis_fn(x[2], k2-1)` and `g` is the value
taken from `dlogP` for this sample.

# Arguments
- `K`: number of 1-D basis functions per coordinate (indices `0:K-1` are passed to the
  basis functions).
- `X`: `2 × N` matrix of samples, one per column.
- `dlogP`: if `grads` is true, a matrix indexed `dlogP[d, n]`; otherwise a function whose
  result is indexed `dlogP(x)[d]`. Presumably the score ∇log p — confirm with callers.
- `basis_fn(t, k)`, `d_basis_fn(t, k)`: 1-D basis evaluators. The names suggest
  `d_basis_fn` is the derivative of `basis_fn` — verify against polynomials.jl.

# Keywords
- `grads`: selects how `dlogP` is interpreted (see above).
- `denom`: if true, also accumulate `V` from `Φ` with the same weights and solve the
  generalized problem `W α = λ V α`.
- `arpack`: forwarded to `solve_eigenvalue_problem`.
- `logPi`: if given, each sample gets weight `exp(-logPi(x))`; otherwise `1/N`.

# Returns
Whatever `solve_eigenvalue_problem` returns for the assembled matrices.
"""
function eigenVI_2D(K, X, dlogP, basis_fn, d_basis_fn; grads=false, denom=false, arpack=false, logPi=nothing)
    D, N = size(X)
    @assert D == 2 "eigenVI_2D requires 2-D samples (X must have 2 rows)"
    M = K^D
    W = zeros(M, M)
    V = denom ? zeros(M, M) : nothing
    degrees = 0:(K - 1)

    # Scratch buffers reused for every sample; `vec` shares memory with the matrix.
    w_nd = zeros(K, K)
    phi = zeros(K, K)
    w = vec(w_nd)
    phiv = vec(phi)

    for n in 1:N
        x = @view(X[:, n])
        # importance weight; by default uniform (1/N)
        logweight = logPi !== nothing ? -logPi(x) : -log(N)
        weight = exp(logweight)

        # Evaluate the score once per sample (previously once per basis pair).
        g = grads ? @view(dlogP[:, n]) : dlogP(x)

        # 1-D basis values at each coordinate, evaluated once per sample.
        # The second coordinate is kept as a row so broadcasting forms outer products.
        b1 = basis_fn.(x[1], degrees)
        db1 = d_basis_fn.(x[1], degrees)
        b2r = transpose(basis_fn.(x[2], degrees))
        db2r = transpose(d_basis_fn.(x[2], degrees))

        if denom
            phi .= b1 .* b2r
        end

        for d in 1:D
            if d == 1
                w_nd .= 2 .* db1 .* b2r .- g[1] .* b1 .* b2r
            else
                w_nd .= 2 .* b1 .* db2r .- g[2] .* b1 .* b2r
            end
            # In-place weighted outer-product update; avoids allocating an M×M temporary.
            W .+= w .* w' .* weight
            if denom
                # NOTE(review): phi does not depend on d, so V receives this term D times
                # per sample. This is kept for compatibility: it scales the generalized
                # eigenvalue by 1/D but leaves the (normalized) eigenvector unchanged.
                V .+= phiv .* phiv' .* weight
            end
        end
    end
    return solve_eigenvalue_problem(W; B=V, arpack=arpack)
end