Skip to content

Commit

Permalink
added lru
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Oct 25, 2024
1 parent 0f2b1f9 commit 8558f10
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/RecurrentLayers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ using Flux
import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like
import Flux: glorot_uniform

export MGUCell, LiGRUCell, IndRNNCell, RANCell
export MGUCell, LiGRUCell, IndRNNCell, RANCell, LRUCell
export MGU, LiGRU, IndRNN, RAN

include("mgu_cell.jl")
include("ligru_cell.jl")
include("indrnn_cell.jl")
include("ran_cell.jl")
include("lru_cell.jl")

end #module
1 change: 1 addition & 0 deletions src/indrnn_cell.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#https://arxiv.org/pdf/1803.04831
struct IndRNNCell{F,I,H,V}
σ::F
Wi::I
Expand Down
1 change: 1 addition & 0 deletions src/ligru_cell.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#https://arxiv.org/pdf/1803.10225
struct LiGRUCell{I, H, V}
Wi::I
Wh::H
Expand Down
66 changes: 66 additions & 0 deletions src/lru_cell.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#https://www.mdpi.com/2079-9292/13/16/3204
struct LRUCell{I,H,V}
Wi::I
Wh::H
bias::V
end

function LRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true)
Wi = init(2 * out, in)
Wh = init(out, out)
b = create_bias(Wi, bias, size(Wh, 1))

return LRUCell(Wi, Wh, b)
end

LRUCell(in, out; kwargs...) = LRUCell(in => out; kwargs...)

function (lru::LRUCell)(inp::AbstractVecOrMat)
state = zeros_like(inp, size(lru.Wh, 2))
return lru(inp, state)
end

function (lru::LRUCell)(inp::AbstractVecOrMat, state)
_size_check(lru, inp, 1 => size(lru.Wi,2))
Wi, Wh, b = lru.Wi, lru.Wh, lru.bias

#split
gxs = chunk(Wi * inp, 2, dims=1)

#compute
candidate_state = @. tanh_fast(gxs[1])
forget_gate = sigmoid_fast(gxs[2] .+ Wh * state .+ b)
new_state = @. (1 - forget_gate) * state + forget_gate * candidate_state
return new_state
end

Base.show(io::IO, lru::LRUCell) =
print(io, "LRUCell(", size(lru.Wi, 2), " => ", size(lru.Wi, 1)÷2, ")")



struct LRU{M}
cell::M
end

Flux.@layer :expand LRU

function LRU((in, out)::Pair; init = glorot_uniform, bias = true)
cell = LRUCell(in => out; init, bias)
return LRU(cell)
end

function (lru::LRU)(inp)
state = zeros_like(inp, size(lru.cell.Wh, 2))
return lru(inp, state)
end

function (lru::LRU)(inp, state)
@assert ndims(inp) == 2 || ndims(inp) == 3
new_state = []
for inp_t in eachslice(inp, dims=2)
state = lru.cell(inp_t, state)
new_state = vcat(new_state, [state])
end
return stack(new_state, dims=2)
end
13 changes: 9 additions & 4 deletions src/mgu_cell.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Define the MGU cell in Flux.jl
#https://arxiv.org/pdf/1603.09420
struct MGUCell{I, H, V}
Wi::I
Wh::H
Expand All @@ -18,6 +18,11 @@ end

MGUCell(in, out; kwargs...) = MGUCell(in => out; kwargs...)

function (mgu::MGUCell)(inp::AbstractVecOrMat)
state = zeros_like(inp, size(mgu.Wh, 2))
return mgu(inp, state)
end

function (mgu::MGUCell)(inp::AbstractVecOrMat, state)
_size_check(mgu, inp, 1 => size(mgu.Wi,2))
Wi, Wh, b = mgu.Wi, mgu.Wh, mgu.bias
Expand All @@ -26,9 +31,9 @@ function (mgu::MGUCell)(inp::AbstractVecOrMat, state)
bs = chunk(b, 2, dims=1)
ghs = chunk(Wh, 2, dims=1)

forget_gate = @. sigmoid_fast(gxs[1] + ghs[1]*state + bs[1])
candidate_state = @. tanh_fast(gxs[2] + ghs[2]*(forget_gate*state) + bs[2])
new_state = @. forget_gate .* state .+ (1 .- forget_gate) .* candidate_state
forget_gate = sigmoid_fast.(gxs[1] .+ ghs[1]*state .+ bs[1])
candidate_state = tanh_fast.(gxs[2] .+ ghs[2]*(forget_gate.*state) .+ bs[2])
new_state = forget_gate .* state .+ (1 .- forget_gate) .* candidate_state
return new_state
end

Expand Down
3 changes: 1 addition & 2 deletions src/ran_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function (ran::RANCell)(inp::AbstractVecOrMat, (state, c_state))
end

Base.show(io::IO, ran::RANCell) =
print(io, "RANCell(", size(ran.Wi, 2), " => ", size(ran.Wi, 1)÷2, ")")
print(io, "RANCell(", size(ran.Wi, 2), " => ", size(ran.Wi, 1)÷3, ")")


struct RAN{M}
Expand Down Expand Up @@ -71,4 +71,3 @@ function (ran::RAN)(inp, (state, c_state))
return stack(new_state, dims=2), stack(new_cstate, dims=2)
end


45 changes: 45 additions & 0 deletions src/sru_cell.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#https://arxiv.org/pdf/1709.02755
struct SRUCell{I,H,B,V}
Wi::I
Wh::H
v::B
bias::V
end

function SRUCell((in, out)::Pair, σ=tanh; init = glorot_uniform, bias = true)
Wi = init(2 * out, in)
Wh = init(2 * out, out)
v = init(2 * out)
b = create_bias(Wi, bias, size(Wh, 1))

return SRUCell(Wi, Wh, v, b)
end

SRUCell(in, out; kwargs...) = SRUCell(in => out; kwargs...)

function (sru::SRUCell)(inp::AbstractVecOrMat)
state = zeros_like(inp, size(sru.Wh, 2))
c_state = zeros_like(state)
return sru(inp, (state, c_state))
end

function (sru::SRUCell)(inp::AbstractVecOrMat, (state, c_state))
_size_check(sru, inp, 1 => size(sru.Wi,2))
Wi, Wh, v, b = sru.Wi, sru.Wh, sru.v, sru.bias

#split
gxs = chunk(Wi * inp, 3, dims=1)
ghs = chunk(Wh * state, 2, dims=1)
bs = chunk(b, 2, dims=1)
vs = chunk(v, 2, dims=1)

#compute
input_gate = @. sigmoid_fast(gxs[2] + ghs[1] + bs[1])
forget_gate = @. sigmoid_fast(gxs[3] + ghs[2] + bs[2])
candidate_state = @. input_gate * gxs[1] + forget_gate * c_state
new_state = tanh_fast(candidate_state)
return new_state, candidate_state
end

Base.show(io::IO, sru::SRUCell) =
print(io, "SRUCell(", size(sru.Wi, 2), " => ", size(sru.Wi, 1)÷2, ")")

0 comments on commit 8558f10

Please sign in to comment.