Skip to content

Commit

Permalink
added ran, pt1
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Oct 23, 2024
1 parent a28c24f commit 0f2b1f9
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 5 deletions.
5 changes: 3 additions & 2 deletions src/RecurrentLayers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ using Flux
import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like
import Flux: glorot_uniform

export MGUCell, LiGRUCell, IndRNNCell
export MGU, LiGRU, IndRNN
export MGUCell, LiGRUCell, IndRNNCell, RANCell
export MGU, LiGRU, IndRNN, RAN

include("mgu_cell.jl")
include("ligru_cell.jl")
include("indrnn_cell.jl")
include("ran_cell.jl")

end #module
1 change: 1 addition & 0 deletions src/indrnn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ end
# Stateless entry point: run one IndRNN step starting from an all-zero state.
# The zero state is built with `zeros_like` from `x`, with length `size(indrnn.u, 1)`,
# and the call is forwarded to the two-argument (input, state) method.
function (indrnn::IndRNNCell)(x::AbstractVecOrMat)
    hidden_size = size(indrnn.u, 1)
    return indrnn(x, zeros_like(x, hidden_size))
end

function (indrnn::IndRNNCell)(inp::AbstractVecOrMat, state::AbstractVecOrMat)
_size_check(indrnn, inp, 1 => size(indrnn.Wi, 2))
Expand Down
2 changes: 1 addition & 1 deletion src/ligru_cell.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Parameter container for a LiGRU cell.
#
# Fields
#   Wi   – weight array, presumably applied to the input (TODO confirm against the
#          forward pass, which is not visible here)
#   Wh   – weight array, presumably applied to the hidden state (TODO confirm)
#   bias – bias term
#
# The bias field is named `bias` only; the pre-rename `b` field is not kept, so the
# struct has exactly three fields and matches a three-argument constructor call.
struct LiGRUCell{I, H, V}
    Wi::I
    Wh::H
    bias::V
end

function LiGRUCell((in, out)::Pair;
Expand Down
4 changes: 2 additions & 2 deletions src/mgu_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Parameter container for an MGU cell.
#
# Fields
#   Wi   – weight matrix multiplied with the input in the forward pass
#   Wh   – weight array, presumably applied to the hidden state (TODO confirm)
#   bias – bias term, read as `mgu.bias` by the forward pass
#
# The bias field is named `bias` only; the pre-rename `b` field is not kept, so the
# struct has exactly three fields and matches a three-argument constructor call.
struct MGUCell{I, H, V}
    Wi::I
    Wh::H
    bias::V
end

function MGUCell((in, out)::Pair;
Expand All @@ -20,7 +20,7 @@ MGUCell(in, out; kwargs...) = MGUCell(in => out; kwargs...)

function (mgu::MGUCell)(inp::AbstractVecOrMat, state)
_size_check(mgu, inp, 1 => size(mgu.Wi,2))
Wi, Wh, b = mgu.Wi, mgu.Wh, mgu.b
Wi, Wh, b = mgu.Wi, mgu.Wh, mgu.bias
#split
gxs = chunk(Wi * inp, 2, dims=1)
bs = chunk(b, 2, dims=1)
Expand Down
74 changes: 74 additions & 0 deletions src/ran_cell.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Recurrent Additive Network (RAN) cell, after https://arxiv.org/pdf/1705.07393
# Parameter container for a single RAN cell.
#
# Fields
#   Wi   – input weights; multiplied with the input and split into three row blocks
#   Wh   – hidden weights; multiplied with the hidden state and split into two row blocks
#   bias – bias term for the two gates
struct RANCell{I, H, V}
    Wi::I
    Wh::H
    bias::V
end

# Build a RAN cell mapping `in` input features to `out` hidden units.
#
# Arguments
#   in => out – input size and hidden size
#   σ         – accepted positionally but not used anywhere in this constructor
#
# Keywords
#   init – weight initialiser, called as `init(rows, cols)`
#   bias – forwarded to `create_bias`
#
# Wi is created with `3 * out` rows and Wh with `2 * out` rows; the bias is sized to
# match Wh's row count.
function RANCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true)
    input_weights = init(3 * out, in)
    hidden_weights = init(2 * out, out)
    gate_bias = create_bias(input_weights, bias, size(hidden_weights, 1))
    return RANCell(input_weights, hidden_weights, gate_bias)
end

# Convenience form: `RANCell(in, out; ...)` is the same as `RANCell(in => out; ...)`.
function RANCell(in, out; kwargs...)
    return RANCell(Pair(in, out); kwargs...)
end

# Stateless entry point: run one RAN step from all-zero hidden and memory states.
# The hidden state has length `size(ran.Wh, 2)`; the memory state mirrors it.
function (ran::RANCell)(inp::AbstractVecOrMat)
    hidden0 = zeros_like(inp, size(ran.Wh, 2))
    memory0 = zeros_like(hidden0)
    return ran(inp, (hidden0, memory0))
end

# Single RAN step: advance the (hidden, memory) state pair by one input.
#
# Arguments
#   inp              – input vector or matrix; its first dimension is checked against
#                      `size(ran.Wi, 2)`
#   (state, c_state) – previous hidden state and previous memory state
#
# Returns `(new_state, candidate_state)`: the tanh-squashed hidden state and the
# un-squashed memory state. The second value is what the sequence wrapper feeds back
# in as `c_state` on the next step.
function (ran::RANCell)(inp::AbstractVecOrMat, (state, c_state))
    _size_check(ran, inp, 1 => size(ran.Wi, 2))
    Wi, Wh, b = ran.Wi, ran.Wh, ran.bias

    # Wi stacks three row blocks: content, input-gate and forget-gate projections.
    gxs = chunk(Wi * inp, 3, dims=1)
    # Wh stacks two row blocks (input gate, forget gate). The bias is added before
    # splitting rather than chunked on its own, so the step also works when the bias
    # is the scalar `false` placeholder (what Flux's `create_bias` is documented to
    # return for `bias = false` — confirm against the Flux docs).
    ghs = chunk(Wh * state .+ b, 2, dims=1)

    input_gate = @. sigmoid_fast(gxs[2] + ghs[1])
    forget_gate = @. sigmoid_fast(gxs[3] + ghs[2])
    candidate_state = @. input_gate * gxs[1] + forget_gate * c_state
    # Activations are scalar functions and must be broadcast over the state array.
    new_state = tanh_fast.(candidate_state)
    return new_state, candidate_state
end

# Compact printing as `RANCell(in => out)`.
# The hidden size is read from Wh's column count (the same quantity used to size the
# initial state). Wi has `3 * out` rows, so deriving it as `size(Wi, 1) ÷ 2` would
# print the wrong number.
Base.show(io::IO, ran::RANCell) =
    print(io, "RANCell(", size(ran.Wi, 2), " => ", size(ran.Wh, 2), ")")


# Sequence-level wrapper: holds one cell and applies it step by step along the
# second dimension of the input.
struct RAN{M}
    cell::M
end

# Register RAN as a Flux layer. Presumably `:expand` asks Flux to print the wrapped
# cell as a nested child, as it does for container layers — confirm against the
# Flux.@layer documentation.
Flux.@layer :expand RAN

# Build a RAN layer around a freshly constructed RANCell.
#
# Keywords
#   init – weight initialiser, forwarded to the cell constructor
#   bias – bias switch, forwarded to the cell constructor
function RAN((in, out)::Pair; init = glorot_uniform, bias = true)
    return RAN(RANCell(in => out; init = init, bias = bias))
end

# Stateless entry point: process a whole sequence from all-zero hidden and memory
# states. The hidden state has length `size(ran.cell.Wh, 2)`.
function (ran::RAN)(inp)
    hidden0 = zeros_like(inp, size(ran.cell.Wh, 2))
    return ran(inp, (hidden0, zeros_like(hidden0)))
end

# Run the wrapped cell over every slice of `inp` along dimension 2, threading the
# (hidden, memory) state pair from one step to the next.
#
# Arguments
#   inp              – 2- or 3-dimensional input; dimension 2 is the one iterated over
#   (state, c_state) – initial hidden and memory states
#
# Returns two arrays, the per-step hidden states and the per-step memory states, each
# stacked along a new dimension 2.
function (ran::RAN)(inp, (state, c_state))
    @assert ndims(inp) == 2 || ndims(inp) == 3
    # Accumulate with non-mutating `vcat`, rebinding the local on every step.
    hidden_steps = []
    memory_steps = []
    for step_input in eachslice(inp; dims=2)
        state, c_state = ran.cell(step_input, (state, c_state))
        hidden_steps = vcat(hidden_steps, [state])
        memory_steps = vcat(memory_steps, [c_state])
    end
    return stack(hidden_steps; dims=2), stack(memory_steps; dims=2)
end


0 comments on commit 0f2b1f9

Please sign in to comment.