Skip to content

Commit

Permalink
double state layer forward fix
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Jan 21, 2025
1 parent e8b403e commit 9b9a12b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
10 changes: 10 additions & 0 deletions src/generics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,21 @@ end
function (rlayer::AbstractRecurrentLayer{false})(inp::AbstractArray,
state::Union{AbstractVecOrMat, Tuple{AbstractVecOrMat, AbstractVecOrMat}})
@assert ndims(inp) == 2 || ndims(inp) == 3
@assert typeof(state) == typeof(initialstates(rlayer)) """\n
The layer $rlayer is calling states not supported by its
forward method. Check if this is a single or double return
recurrent layer, and adjust your inputs accordingly.
"""
return first(scan(rlayer.cell, inp, state))
end

function (rlayer::AbstractRecurrentLayer{true})(inp::AbstractArray,
state::Union{AbstractVecOrMat, Tuple{AbstractVecOrMat, AbstractVecOrMat}})
@assert ndims(inp) == 2 || ndims(inp) == 3
@assert typeof(state) == typeof(initialstates(rlayer)) """\n
The layer $rlayer is calling states not supported by its
forward method. Check if this is a single or double return
recurrent layer, and adjust your inputs accordingly.
"""
return scan(rlayer.cell, inp, state)
end
4 changes: 3 additions & 1 deletion src/wrappers/stackedrnn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ function StackedRNN(rlayer, (input_size, hidden_size)::Pair{<:Int, <:Int}, args.
end

function (stackedrnn::StackedRNN)(inp::AbstractArray)
@assert length(stackedrnn.layers)==length(stackedrnn.states) "Mismatch in layers vs. states length!"
@assert length(stackedrnn.layers)==length(stackedrnn.states) """\n
Mismatch in layers vs. states length!
"""
@assert !isempty(stackedrnn.layers) "StackedRNN has no layers!"
for idx in eachindex(stackedrnn.layers)
inp = stackedrnn.layers[idx](inp, stackedrnn.states[idx])
Expand Down

0 comments on commit 9b9a12b

Please sign in to comment.