From 9b9a12b1b77af4ab5782864bb9dd7fbd1b56122b Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Tue, 21 Jan 2025 09:38:29 +0100 Subject: [PATCH 1/2] double state layer forward fix --- src/generics.jl | 10 ++++++++++ src/wrappers/stackedrnn.jl | 4 +++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/generics.jl b/src/generics.jl index 01de690..ebb509b 100644 --- a/src/generics.jl +++ b/src/generics.jl @@ -30,11 +30,21 @@ end function (rlayer::AbstractRecurrentLayer{false})(inp::AbstractArray, state::Union{AbstractVecOrMat, Tuple{AbstractVecOrMat, AbstractVecOrMat}}) @assert ndims(inp) == 2 || ndims(inp) == 3 + @assert typeof(state) == typeof(initialstates(rlayer)) """\n + The layer $rlayer is calling states not supported by its + forward method. Check if this is a single or double return + recurrent layer, and adjust your inputs accordingly. + """ return first(scan(rlayer.cell, inp, state)) end function (rlayer::AbstractRecurrentLayer{true})(inp::AbstractArray, state::Union{AbstractVecOrMat, Tuple{AbstractVecOrMat, AbstractVecOrMat}}) @assert ndims(inp) == 2 || ndims(inp) == 3 + @assert typeof(state) == typeof(initialstates(rlayer)) """\n + The layer $rlayer is calling states not supported by its + forward method. Check if this is a single or double return + recurrent layer, and adjust your inputs accordingly. + """ return scan(rlayer.cell, inp, state) end diff --git a/src/wrappers/stackedrnn.jl b/src/wrappers/stackedrnn.jl index 4b81748..ac9d4bd 100644 --- a/src/wrappers/stackedrnn.jl +++ b/src/wrappers/stackedrnn.jl @@ -66,7 +66,9 @@ function StackedRNN(rlayer, (input_size, hidden_size)::Pair{<:Int, <:Int}, args. end function (stackedrnn::StackedRNN)(inp::AbstractArray) - @assert length(stackedrnn.layers)==length(stackedrnn.states) "Mismatch in layers vs. states length!" + @assert length(stackedrnn.layers)==length(stackedrnn.states) """\n + Mismatch in layers vs. states length! + """ @assert !isempty(stackedrnn.layers) "StackedRNN has no layers!" for idx in eachindex(stackedrnn.layers) inp = stackedrnn.layers[idx](inp, stackedrnn.states[idx]) From 0c5f9301c3bd712e82a3332ddc30bdac058b83a0 Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Tue, 21 Jan 2025 09:43:13 +0100 Subject: [PATCH 2/2] format --- src/generics.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/generics.jl b/src/generics.jl index ebb509b..baea127 100644 --- a/src/generics.jl +++ b/src/generics.jl @@ -30,21 +30,21 @@ end function (rlayer::AbstractRecurrentLayer{false})(inp::AbstractArray, state::Union{AbstractVecOrMat, Tuple{AbstractVecOrMat, AbstractVecOrMat}}) @assert ndims(inp) == 2 || ndims(inp) == 3 - @assert typeof(state) == typeof(initialstates(rlayer)) """\n - The layer $rlayer is calling states not supported by its - forward method. Check if this is a single or double return - recurrent layer, and adjust your inputs accordingly. - """ + @assert typeof(state)==typeof(initialstates(rlayer)) """\n + The layer $rlayer is calling states not supported by its + forward method. Check if this is a single or double return + recurrent layer, and adjust your inputs accordingly. + """ return first(scan(rlayer.cell, inp, state)) end function (rlayer::AbstractRecurrentLayer{true})(inp::AbstractArray, state::Union{AbstractVecOrMat, Tuple{AbstractVecOrMat, AbstractVecOrMat}}) @assert ndims(inp) == 2 || ndims(inp) == 3 - @assert typeof(state) == typeof(initialstates(rlayer)) """\n - The layer $rlayer is calling states not supported by its - forward method. Check if this is a single or double return - recurrent layer, and adjust your inputs accordingly. - """ + @assert typeof(state)==typeof(initialstates(rlayer)) """\n + The layer $rlayer is calling states not supported by its + forward method. Check if this is a single or double return + recurrent layer, and adjust your inputs accordingly. + """ return scan(rlayer.cell, inp, state) end