Skip to content

Commit

Permalink
Merge pull request #2 from MartinuzziFrancesco/fm/design
Browse files Browse the repository at this point in the history
Conform to novel Flux RNN design
  • Loading branch information
MartinuzziFrancesco authored Oct 25, 2024
2 parents 940f650 + 8558f10 commit b1f5090
Show file tree
Hide file tree
Showing 7 changed files with 350 additions and 70 deletions.
9 changes: 6 additions & 3 deletions src/RecurrentLayers.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
module RecurrentLayers

using Flux
# Internal Flux helpers used by the cell implementations.
import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like
import Flux: glorot_uniform

# Cells (single-step) and their sequence-level layer wrappers.
export MGUCell, LiGRUCell, IndRNNCell, RANCell, LRUCell
# Fix: `LRU` was missing here even though `lru_cell.jl` defines the
# `LRU` layer and its cell is exported above.
export MGU, LiGRU, IndRNN, RAN, LRU

include("mgu_cell.jl")
include("ligru_cell.jl")
include("indrnn_cell.jl")
include("ran_cell.jl")
include("lru_cell.jl")

end #module
60 changes: 60 additions & 0 deletions src/indrnn_cell.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#https://arxiv.org/pdf/1803.04831
"""
    IndRNNCell(σ, Wi, u, b)

Independently Recurrent Neural Network (IndRNN) cell; see the paper
linked above.

Fields:
- `σ`: activation applied elementwise to the new state.
- `Wi`: input weight matrix, size `(out, in)`.
- `u`: recurrent weight *vector* — one weight per hidden unit, so hidden
  units do not interact with each other (hence "independently" recurrent).
- `b`: bias term (whatever `create_bias` produced at construction time).
"""
struct IndRNNCell{F,I,H,V}
    σ::F
    Wi::I
    u::H
    b::V
end

# Register the cell with Flux so its arrays are tracked as trainable parameters.
Flux.@layer IndRNNCell

"""
    IndRNNCell((in, out)::Pair, σ=relu; init = glorot_uniform, bias = true)

Build an `IndRNNCell` mapping `in` input features to `out` hidden units,
with activation `σ`, weight initializer `init`, and an optional bias.
"""
function IndRNNCell((in, out)::Pair, σ=relu; init = glorot_uniform, bias = true)
    input_weights = init(out, in)
    recurrent_weights = init(out)
    cell_bias = create_bias(input_weights, bias, size(input_weights, 1))
    return IndRNNCell(σ, input_weights, recurrent_weights, cell_bias)
end

# Stateless call: start the recurrence from an all-zero hidden state whose
# length matches the number of hidden units (length of `u`).
function (indrnn::IndRNNCell)(x::AbstractVecOrMat)
    zero_state = zeros_like(x, size(indrnn.u, 1))
    return indrnn(x, zero_state)
end

# One IndRNN step: h' = σ.(Wi * x .+ u .* h .+ b).
function (indrnn::IndRNNCell)(inp::AbstractVecOrMat, state::AbstractVecOrMat)
    _size_check(indrnn, inp, 1 => size(indrnn.Wi, 2))
    # Substitute a faster variant of the activation when NNlib provides one.
    activation = NNlib.fast_act(indrnn.σ, inp)
    new_state = activation.(indrnn.Wi * inp .+ indrnn.u .* state .+ indrnn.b)
    return new_state
end

# Compact display, e.g. `IndRNNCell(3 => 5, relu)`.
# Fix: the parameter is `m`; the original body read `indrnn.Wi` / `indrnn.σ`,
# an undefined variable, so displaying a cell raised `UndefVarError`.
function Base.show(io::IO, m::IndRNNCell)
    print(io, "IndRNNCell(", size(m.Wi, 2), " => ", size(m.Wi, 1))
    print(io, ", ", m.σ)
    print(io, ")")
end

"""
    IndRNN(cell)

Sequence-level wrapper that runs an `IndRNNCell` over every time step
(time is dimension 2 of the input).
"""
struct IndRNN{M}
    cell::M
end

# `:expand` makes Flux print the wrapped cell's own layout in model summaries.
Flux.@layer :expand IndRNN

"""
    IndRNN((in, out)::Pair, σ = tanh; bias = true, init = glorot_uniform)

Construct an `IndRNN` layer wrapping a freshly built `IndRNNCell`.
"""
function IndRNN((in, out)::Pair, σ = tanh; bias = true, init = glorot_uniform)
    return IndRNN(IndRNNCell(in => out, σ; bias = bias, init = init))
end

# Stateless call: run the whole sequence from an all-zero initial state.
function (indrnn::IndRNN)(inp)
    initial_state = zeros_like(inp, size(indrnn.cell.u, 1))
    return indrnn(inp, initial_state)
end

"""
    (indrnn::IndRNN)(inp, state)

Run the cell across a sequence `inp` (features × time, or
features × time × batch), threading `state` through each step.
Returns the per-step states stacked along dimension 2.
"""
function (indrnn::IndRNN)(inp, state)
    # Validate at the API boundary; `@assert` can be compiled out.
    ndims(inp) == 2 || ndims(inp) == 3 ||
        throw(ArgumentError("expected 2- or 3-dimensional input, got ndims(inp) = $(ndims(inp))"))
    states = []
    for inp_t in eachslice(inp, dims=2)
        state = indrnn.cell(inp_t, state)
        # `push!` is amortized O(1); the original `vcat` rebuilt the
        # accumulator every step (quadratic in sequence length).
        push!(states, state)
    end
    return stack(states, dims=2)
end
81 changes: 48 additions & 33 deletions src/ligru_cell.jl
Original file line number Diff line number Diff line change
@@ -1,49 +1,64 @@
#https://arxiv.org/pdf/1803.10225
"""
    LiGRUCell(Wi, Wh, bias)

Light Gated Recurrent Unit (Li-GRU) cell; see the paper linked above.

Fields:
- `Wi`: input weights, size `(2out, in)` — rows stack [forget; candidate].
- `Wh`: recurrent weights, size `(2out, out)`.
- `bias`: bias of length `2out` (whatever `create_bias` produced).
"""
struct LiGRUCell{I, H, V}
    Wi::I
    Wh::H
    bias::V
end

"""
    LiGRUCell((in, out)::Pair; init = glorot_uniform, bias = true)

Build a `LiGRUCell` mapping `in` input features to `out` hidden units.
Both gates' weights are stacked into single `2out`-row matrices.
"""
function LiGRUCell((in, out)::Pair;
    init = glorot_uniform,
    bias = true)

    Wi = init(out * 2, in)
    Wh = init(out * 2, out)
    b = create_bias(Wi, bias, size(Wi, 1))

    return LiGRUCell(Wi, Wh, b)
end

# Convenience: allow `LiGRUCell(in, out)` in place of `LiGRUCell(in => out)`.
LiGRUCell(in, out; kwargs...) = LiGRUCell(in => out; kwargs...)

"""
    (ligru::LiGRUCell)(inp, state)

One Li-GRU step: `z = σ(Wi₁x + Wh₁h + b₁)`, `h̃ = tanh(Wi₂x + Wh₂h + b₂)`,
`h' = z ⊙ h + (1 − z) ⊙ h̃`. Returns the new state.
"""
function (ligru::LiGRUCell)(inp::AbstractVecOrMat, state)
    _size_check(ligru, inp, 1 => size(ligru.Wi, 2))
    # Fix: the struct field is `bias` (the original read `ligru.b`), and the
    # previous state is `state` (the original update referenced an undefined
    # variable `hidden` left over from the old cell interface).
    Wi, Wh, b = ligru.Wi, ligru.Wh, ligru.bias
    #split
    gxs = chunk(Wi * inp, 2, dims=1)
    ghs = chunk(Wh * state, 2, dims=1)
    bs = chunk(b, 2, dims=1)
    #compute
    forget_gate = @. sigmoid_fast(gxs[1] + ghs[1] + bs[1])
    candidate_hidden = @. tanh_fast(gxs[2] + ghs[2] + bs[2])
    new_state = forget_gate .* state .+ (1 .- forget_gate) .* candidate_hidden
    return new_state
end

# Register the cell's arrays as trainable parameters.
Flux.@layer LiGRUCell

"""
    LiGRU(cell)

Sequence-level wrapper around a `LiGRUCell` (time on dimension 2).
"""
struct LiGRU{M}
    cell::M
end

# `:expand` makes Flux print the inner cell when showing the model.
Flux.@layer :expand LiGRU

"""
    LiGRU((in, out)::Pair; init = glorot_uniform, bias = true)

Construct a `LiGRU` layer wrapping a fresh `LiGRUCell`.
(The old `Flux.Recur`-based constructor is gone: `Recur` is not part of
the new Flux recurrent design this file conforms to.)
"""
function LiGRU((in, out)::Pair; init = glorot_uniform, bias = true)
    cell = LiGRUCell(in => out; init, bias)
    return LiGRU(cell)
end

# Stateless call: start the recurrence from zeros sized by `Wh`'s columns.
function (ligru::LiGRU)(inp)
    initial_state = zeros_like(inp, size(ligru.cell.Wh, 2))
    return ligru(inp, initial_state)
end

"""
    (ligru::LiGRU)(inp, state)

Run the cell across the whole sequence, returning the per-step states
stacked along dimension 2.
"""
function (ligru::LiGRU)(inp, state)
    # Validate at the API boundary; `@assert` can be compiled out.
    ndims(inp) == 2 || ndims(inp) == 3 ||
        throw(ArgumentError("expected 2- or 3-dimensional input, got ndims(inp) = $(ndims(inp))"))
    states = []
    for inp_t in eachslice(inp, dims=2)
        state = ligru.cell(inp_t, state)
        push!(states, state)  # amortized O(1); the original `vcat` was quadratic
    end
    return stack(states, dims=2)
end

# Compact display, e.g. `LiGRUCell(3 => 5)` — `Wi` stacks two gates, hence ÷ 2.
# The stale `Flux.Recur(::LiGRUCell)` overload is removed: `Recur` does not
# exist in the new Flux recurrent design, and the cell no longer has `state0`.
Base.show(io::IO, ligru::LiGRUCell) =
    print(io, "LiGRUCell(", size(ligru.Wi, 2), " => ", size(ligru.Wi, 1) ÷ 2, ")")
66 changes: 66 additions & 0 deletions src/lru_cell.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#https://www.mdpi.com/2079-9292/13/16/3204
"""
    LRUCell(Wi, Wh, bias)

Light Recurrent Unit (LRU) cell; see the paper linked above.

Fields:
- `Wi`: input weights, size `(2out, in)` — rows stack [candidate; forget].
- `Wh`: recurrent weights used only by the forget gate, size `(out, out)`.
- `bias`: forget-gate bias of length `out`.
"""
struct LRUCell{I,H,V}
    Wi::I
    Wh::H
    bias::V
end

"""
    LRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true)

Build an `LRUCell` mapping `in` input features to `out` hidden units.

NOTE(review): the `σ` argument is accepted but never stored or used — the
forward pass hardcodes `tanh_fast`/`sigmoid_fast`. Confirm whether it
should be kept on the cell or dropped from the signature.
"""
function LRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true)
    Wi = init(2 * out, in)
    Wh = init(out, out)
    # Bias sized to the forget gate only (`out`, the row count of `Wh`).
    b = create_bias(Wi, bias, size(Wh, 1))

    return LRUCell(Wi, Wh, b)
end

# Convenience: allow `LRUCell(in, out)` in place of `LRUCell(in => out)`.
LRUCell(in, out; kwargs...) = LRUCell(in => out; kwargs...)

# Stateless call: start from an all-zero state sized by `Wh`'s columns.
function (lru::LRUCell)(inp::AbstractVecOrMat)
    zero_state = zeros_like(inp, size(lru.Wh, 2))
    return lru(inp, zero_state)
end

"""
    (lru::LRUCell)(inp, state)

One LRU step: `h̃ = tanh(Wi₁x)`, `f = σ(Wi₂x + Wh·h + b)`,
`h' = (1 − f) ⊙ h + f ⊙ h̃`. Returns the new state.
"""
function (lru::LRUCell)(inp::AbstractVecOrMat, state)
    _size_check(lru, inp, 1 => size(lru.Wi,2))
    Wi, Wh, b = lru.Wi, lru.Wh, lru.bias

    #split
    gxs = chunk(Wi * inp, 2, dims=1)

    #compute
    candidate_state = @. tanh_fast(gxs[1])
    # Fix: broadcast the activation — NNlib activation functions must be
    # applied elementwise; calling them on a whole array is an error.
    forget_gate = sigmoid_fast.(gxs[2] .+ Wh * state .+ b)
    new_state = @. (1 - forget_gate) * state + forget_gate * candidate_state
    return new_state
end

# Compact display, e.g. `LRUCell(3 => 5)` — `Wi` stacks two gates, hence ÷ 2.
function Base.show(io::IO, lru::LRUCell)
    print(io, "LRUCell(", size(lru.Wi, 2), " => ", size(lru.Wi, 1) ÷ 2, ")")
end



"""
    LRU(cell)

Sequence-level wrapper around an `LRUCell` (time on dimension 2).
"""
struct LRU{M}
    cell::M
end

# `:expand` makes Flux print the inner cell when showing the model.
Flux.@layer :expand LRU

"""
    LRU((in, out)::Pair; init = glorot_uniform, bias = true)

Construct an `LRU` layer wrapping a fresh `LRUCell`.
"""
function LRU((in, out)::Pair; init = glorot_uniform, bias = true)
    return LRU(LRUCell(in => out; init, bias))
end

# Stateless call: run the sequence from an all-zero initial state.
function (lru::LRU)(inp)
    initial_state = zeros_like(inp, size(lru.cell.Wh, 2))
    return lru(inp, initial_state)
end

"""
    (lru::LRU)(inp, state)

Run the cell across the whole sequence, returning the per-step states
stacked along dimension 2.
"""
function (lru::LRU)(inp, state)
    # Validate at the API boundary; `@assert` can be compiled out.
    ndims(inp) == 2 || ndims(inp) == 3 ||
        throw(ArgumentError("expected 2- or 3-dimensional input, got ndims(inp) = $(ndims(inp))"))
    states = []
    for inp_t in eachslice(inp, dims=2)
        state = lru.cell(inp_t, state)
        push!(states, state)  # amortized O(1); the original `vcat` was quadratic
    end
    return stack(states, dims=2)
end
86 changes: 52 additions & 34 deletions src/mgu_cell.jl
Original file line number Diff line number Diff line change
@@ -1,50 +1,68 @@
#https://arxiv.org/pdf/1603.09420
"""
    MGUCell(Wi, Wh, bias)

Minimal Gated Unit (MGU) cell; see the paper linked above.

Fields:
- `Wi`: input weights, size `(2out, in)` — rows stack [forget; candidate].
- `Wh`: recurrent weights, size `(2out, out)`.
- `bias`: bias of length `2out` (whatever `create_bias` produced).
"""
struct MGUCell{I, H, V}
    Wi::I
    Wh::H
    bias::V
end

"""
    MGUCell((in, out)::Pair; init = glorot_uniform, bias = true)

Build an `MGUCell` mapping `in` input features to `out` hidden units.
Both gates' weights are stacked into single `2out`-row matrices.
"""
function MGUCell((in, out)::Pair;
    init = glorot_uniform,
    bias = true)

    Wi = init(out * 2, in)
    Wh = init(out * 2, out)
    b = create_bias(Wi, bias, size(Wi, 1))

    return MGUCell(Wi, Wh, b)
end

# Convenience: allow `MGUCell(in, out)` in place of `MGUCell(in => out)`.
MGUCell(in, out; kwargs...) = MGUCell(in => out; kwargs...)

# Stateless call: start the recurrence from an all-zero hidden state sized
# by `Wh`'s column count. (The old `(hidden, inp)` method that depended on
# `state0` and `multigate` is removed with the new Flux cell interface.)
function (mgu::MGUCell)(inp::AbstractVecOrMat)
    state = zeros_like(inp, size(mgu.Wh, 2))
    return mgu(inp, state)
end

"""
    (mgu::MGUCell)(inp, state)

One MGU step: `f = σ(Wi₁x + Wh₁h + b₁)`,
`h̃ = tanh(Wi₂x + Wh₂(f ⊙ h) + b₂)`, `h' = f ⊙ h + (1 − f) ⊙ h̃`.
Returns the new state.
"""
function (mgu::MGUCell)(inp::AbstractVecOrMat, state)
    _size_check(mgu, inp, 1 => size(mgu.Wi,2))
    Wi, Wh, b = mgu.Wi, mgu.Wh, mgu.bias
    #split
    gxs = chunk(Wi * inp, 2, dims=1)
    bs = chunk(b, 2, dims=1)
    # `Wh` is split by rows (not multiplied first): the candidate's recurrent
    # term applies to the *gated* state, so the products differ per gate.
    ghs = chunk(Wh, 2, dims=1)

    forget_gate = sigmoid_fast.(gxs[1] .+ ghs[1]*state .+ bs[1])
    candidate_state = tanh_fast.(gxs[2] .+ ghs[2]*(forget_gate.*state) .+ bs[2])
    new_state = forget_gate .* state .+ (1 .- forget_gate) .* candidate_state
    return new_state
end

# Register the cell's arrays as trainable parameters.
Flux.@layer MGUCell

# Compact display, e.g. `MGUCell(3 => 5)` — `Wi` stacks two gates, hence ÷ 2.
Base.show(io::IO, mgu::MGUCell) =
    print(io, "MGUCell(", size(mgu.Wi, 2), " => ", size(mgu.Wi, 1) ÷ 2, ")")

"""
    MGU(cell)

Sequence-level wrapper around an `MGUCell` (time on dimension 2).
The old `Flux.Recur`-based constructor and `Flux.Recur(::MGUCell)`
overload are removed: `Recur` is not part of the new Flux design.
"""
struct MGU{M}
    cell::M
end

# `:expand` makes Flux print the inner cell when showing the model.
Flux.@layer :expand MGU

"""
    MGU((in, out)::Pair; init = glorot_uniform, bias = true)

Construct an `MGU` layer wrapping a fresh `MGUCell`.
"""
function MGU((in, out)::Pair; init = glorot_uniform, bias = true)
    cell = MGUCell(in => out; init, bias)
    return MGU(cell)
end

# Stateless call: run the sequence from an all-zero initial state.
function (mgu::MGU)(inp)
    initial_state = zeros_like(inp, size(mgu.cell.Wh, 2))
    return mgu(inp, initial_state)
end

"""
    (mgu::MGU)(inp, state)

Run the cell across the whole sequence, returning the per-step states
stacked along dimension 2.
"""
function (mgu::MGU)(inp, state)
    # Validate at the API boundary; `@assert` can be compiled out.
    ndims(inp) == 2 || ndims(inp) == 3 ||
        throw(ArgumentError("expected 2- or 3-dimensional input, got ndims(inp) = $(ndims(inp))"))
    states = []
    for inp_t in eachslice(inp, dims=2)
        state = mgu.cell(inp_t, state)
        push!(states, state)  # amortized O(1); the original `vcat` was quadratic
    end
    return stack(states, dims=2)
end
Loading

0 comments on commit b1f5090

Please sign in to comment.