diff --git a/src/generics.jl b/src/generics.jl index 01de690..baea127 100644 --- a/src/generics.jl +++ b/src/generics.jl @@ -30,11 +30,21 @@ end function (rlayer::AbstractRecurrentLayer{false})(inp::AbstractArray, state::Union{AbstractVecOrMat, Tuple{AbstractVecOrMat, AbstractVecOrMat}}) @assert ndims(inp) == 2 || ndims(inp) == 3 + @assert typeof(state)==typeof(initialstates(rlayer)) """\n + The layer $rlayer is calling states not supported by its + forward method. Check if this is a single or double return + recurrent layer, and adjust your inputs accordingly. + """ return first(scan(rlayer.cell, inp, state)) end function (rlayer::AbstractRecurrentLayer{true})(inp::AbstractArray, state::Union{AbstractVecOrMat, Tuple{AbstractVecOrMat, AbstractVecOrMat}}) @assert ndims(inp) == 2 || ndims(inp) == 3 + @assert typeof(state)==typeof(initialstates(rlayer)) """\n + The layer $rlayer is calling states not supported by its + forward method. Check if this is a single or double return + recurrent layer, and adjust your inputs accordingly. + """ return scan(rlayer.cell, inp, state) end diff --git a/src/wrappers/stackedrnn.jl b/src/wrappers/stackedrnn.jl index 4b81748..ac9d4bd 100644 --- a/src/wrappers/stackedrnn.jl +++ b/src/wrappers/stackedrnn.jl @@ -66,7 +66,9 @@ function StackedRNN(rlayer, (input_size, hidden_size)::Pair{<:Int, <:Int}, args. end function (stackedrnn::StackedRNN)(inp::AbstractArray) - @assert length(stackedrnn.layers)==length(stackedrnn.states) "Mismatch in layers vs. states length!" + @assert length(stackedrnn.layers)==length(stackedrnn.states) """\n + Mismatch in layers vs. states length! + """ @assert !isempty(stackedrnn.layers) "StackedRNN has no layers!" for idx in eachindex(stackedrnn.layers) inp = stackedrnn.layers[idx](inp, stackedrnn.states[idx])