Merge pull request #42 from MartinuzziFrancesco/fm/fixes
General fixes
MartinuzziFrancesco authored Jan 14, 2025
2 parents d8e92ea + 1dbafad commit c1e466c
Showing 24 changed files with 307 additions and 461 deletions.
6 changes: 3 additions & 3 deletions .JuliaFormatter.toml
@@ -1,9 +1,9 @@
style = "blue"
style = "sciml"
format_markdown = false
whitespace_in_kwargs = false
always_use_return = true
margin = 92
indent = 4
format_docstrings = true
separate_kwargs_with_semicolon = true
always_for_in = true
annotate_untyped_fields_with_any = false
annotate_untyped_fields_with_any = false
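The formatter configuration above switches the repository from Blue style to SciML style. As a quick sketch (assuming JuliaFormatter.jl is available in the active environment), applying the updated settings locally could look like this:

```julia
# Sketch: apply the repository's .JuliaFormatter.toml locally.
# JuliaFormatter reads the TOML file from the formatted directory,
# so the `style = "sciml"` setting above is picked up automatically.
using JuliaFormatter

format(".")   # format the whole repository in place
```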
122 changes: 3 additions & 119 deletions README.md
@@ -21,8 +21,8 @@
[julia-img]: https://img.shields.io/badge/julia-v1.10+-blue.svg
[julia-url]: https://julialang.org/

[style-img]: https://img.shields.io/badge/code%20style-blue-4495d1.svg
[style-url]: https://github.com/invenia/BlueStyle
[style-img]: https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826
[style-url]: https://github.com/SciML/SciMLStyle

[aqua-img]: https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg
[aqua-url]: https://github.com/JuliaTesting/Aqua.jl
@@ -71,124 +71,8 @@ pkg> add RecurrentLayers

## Getting started 🛠️

The workflow is identical to any recurrent Flux layer:
The workflow is identical to any recurrent Flux layer: just plug in a new recurrent layer in your workflow and test it out!

```julia
using RecurrentLayers

using Flux
using MLUtils: DataLoader
using Statistics
using Random

# Create dataset
function create_data(input_size, seq_length::Int, num_samples::Int)
data = randn(input_size, seq_length, num_samples) #(input_size, seq_length, num_samples)
labels = sum(data, dims=(1, 2)) .>= 0
labels = Int.(labels)
labels = dropdims(labels, dims=(1))
return data, labels
end

function create_dataset(input_size, seq_length, n_train::Int, n_test::Int, batch_size)
train_data, train_labels = create_data(input_size, seq_length, n_train)
train_loader = DataLoader((train_data, train_labels), batchsize=batch_size, shuffle=true)

test_data, test_labels = create_data(input_size, seq_length, n_test)
test_loader = DataLoader((test_data, test_labels), batchsize=batch_size, shuffle=false)
return train_loader, test_loader
end

struct RecurrentModel{H,C,D}
h0::H
rnn::C
dense::D
end

Flux.@layer RecurrentModel trainable=(rnn, dense)

function RecurrentModel(input_size::Int, hidden_size::Int)
return RecurrentModel(
zeros(Float32, hidden_size),
MGU(input_size => hidden_size),
Dense(hidden_size => 1, sigmoid))
end

function (model::RecurrentModel)(inp)
state = model.rnn(inp, model.h0)
state = state[:, end, :]
output = model.dense(state)
return output
end

function criterion(model, batch_data, batch_labels)
y_pred = model(batch_data)
loss = Flux.binarycrossentropy(y_pred, batch_labels)
return loss
end

function train_recurrent!(epoch, train_loader, opt, model, criterion)
total_loss = 0.0
for (batch_data, batch_labels) in train_loader
# Compute gradients and update parameters
grads = gradient(() -> criterion(model, batch_data, batch_labels), Flux.params(model))
Flux.Optimise.update!(opt, Flux.params(model), grads)

# Accumulate loss
total_loss += criterion(model, batch_data, batch_labels)
end
avg_loss = total_loss / length(train_loader)
println("Epoch $epoch/$num_epochs, Loss: $(round(avg_loss, digits=4))")
end

function test_recurrent(test_loader, model)
# Evaluation
correct = 0
total = 0
for (batch_data, batch_labels) in test_loader

# Forward pass
predicted = model(batch_data)

# Decode predictions: convert probabilities to class labels (0 or 1)
predicted_labels = vec(predicted .>= 0.5) # Threshold at 0.5 for binary classification

# Compare predicted labels to actual labels
correct += sum(predicted_labels .== vec(batch_labels))
total += length(batch_labels)
end
accuracy = correct / total
println("Accuracy: ", accuracy * 100, "%")
end

function main(;
input_size = 1, # Each element in the sequence is a scalar
hidden_size = 64, # Size of the hidden state
seq_length = 10, # Length of each sequence
batch_size = 16, # Batch size
num_epochs = 50, # Number of epochs for training
n_train = 1000, # Number of samples in train dataset
n_test = 200 # Number of samples in test dataset)
)
model = RecurrentModel(input_size, hidden_size)
# Generate test data
train_loader, test_loader = create_dataset(input_size, seq_length, n_train, n_test, batch_size)
# Define the optimizer
opt = Adam(0.001)

for epoch in 1:num_epochs
train_recurrent!(epoch, train_loader, opt, model, criterion)
end

test_recurrent(test_loader, model)

end

main()



```
## License 📜

This project is licensed under the MIT License, except for `nas_cell.jl`, which is licensed under the Apache License, Version 2.0.
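The README change above replaces the long end-to-end training example with the one-line pointer under "Getting started". For orientation, here is a minimal usage sketch of what "plug in a new recurrent layer" means in practice; the `MGU` layer, the `(features, timesteps, batch)` input layout, and the returned state shape are taken from the removed example, not guaranteed by the diff itself:

```julia
# Minimal sketch based on the removed README example:
# MGU follows the usual Flux recurrent-layer interface.
using Flux, RecurrentLayers

rnn = MGU(2 => 8)                 # input size 2, hidden size 8
x = rand(Float32, 2, 10, 16)      # (features, timesteps, batch)
h0 = zeros(Float32, 8)            # initial hidden state
states = rnn(x, h0)               # hidden states over the sequence
last = states[:, end, :]          # last time step, as in the removed example
```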
55 changes: 27 additions & 28 deletions benchmarks/adding_problem/main.jl
@@ -1,9 +1,9 @@
using Flux, RecurrentLayers, MLUtils, StatsBase, Comonicon, Printf, CUDA

function generate_adding_data(
sequence_length::Int,
n_samples::Int;
kwargs...
sequence_length::Int,
n_samples::Int;
kwargs...
)
random_sequence = rand(Float32, 1, sequence_length, n_samples)
mask_sequence = zeros(Float32, 1, sequence_length, n_samples)
@@ -15,7 +15,7 @@ function generate_adding_data(
targets[i] = sum(Float32, random_sequence[1, idxs, i])
end

inputs = cat(random_sequence, mask_sequence, dims=1)
inputs = cat(random_sequence, mask_sequence; dims=1)
@assert size(inputs, 3) == size(targets, 1)

dataloader = DataLoader(
@@ -26,17 +26,16 @@ end
end

function generate_dataloaders(
sequence_length::Int,
n_train::Int,
n_test::Int;
kwargs...)
sequence_length::Int,
n_train::Int,
n_test::Int;
kwargs...)
train_loader = generate_adding_data(sequence_length, n_train; kwargs...)
test_loader = generate_adding_data(sequence_length, n_test; kwargs...)
return train_loader, test_loader
end


struct RecurrentModel{H,C,D}
struct RecurrentModel{H, C, D}
h0::H
rnn::C
dense::D
@@ -46,9 +45,9 @@ Flux.@layer RecurrentModel trainable=(rnn, dense)

function RecurrentModel(rnn_wrapper, input_size::Int, hidden_size::Int)
return RecurrentModel(
zeros(Float32, hidden_size),
rnn_wrapper(input_size => hidden_size),
Dense(hidden_size => 1, sigmoid))
zeros(Float32, hidden_size),
rnn_wrapper(input_size => hidden_size),
Dense(hidden_size => 1, sigmoid))
end

function (model::RecurrentModel)(inp)
@@ -82,24 +81,25 @@ function test_recurrent(epoch, test_loader, model, criterion)
end

Comonicon.@main function main(rnn_wrapper;
epochs::Int = 50,
shuffle::Bool = true,
batchsize::Int = 64,
sequence_length::Int = 1000,
n_train::Int = 500,
n_test::Int = 200,
hidden_size::Int = 20,
learning_rate::Float64 = 0.01)

epochs::Int=50,
shuffle::Bool=true,
batchsize::Int=64,
sequence_length::Int=1000,
n_train::Int=500,
n_test::Int=200,
hidden_size::Int=20,
learning_rate::Float64=0.01)
train_loader, test_loader = generate_dataloaders(
sequence_length, n_train, n_test; batchsize = batchsize, shuffle = shuffle
sequence_length, n_train, n_test; batchsize=batchsize, shuffle=shuffle
)

input_size = 2
model = RecurrentModel(rnn_wrapper, input_size, hidden_size)
criterion(input_data, target_data) = Flux.mse(
model(input_data), reshape(target_data, 1, :)
)
function criterion(input_data, target_data)
Flux.mse(
model(input_data), reshape(target_data, 1, :)
)
end
model = Flux.gpu(model)
opt = Flux.Adam(learning_rate)

@@ -111,6 +111,5 @@ Comonicon.@main function main(rnn_wrapper;

@printf "Epoch %2d: Train Loss: %.4f, Test Loss: %.4f, \
Time: %.2fs\n" epoch train_loss test_loss total_time

end
end
end
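Since the benchmark entry point is declared with `Comonicon.@main`, the script exposes both a command-line interface and a plain `main` function. A hedged sketch of driving it from a Julia session follows; the keyword values are illustrative, and it assumes the packages imported at the top of the script (including CUDA) are available:

```julia
# Sketch: run the adding-problem benchmark interactively.
include("benchmarks/adding_problem/main.jl")

# `rnn_wrapper` is called as rnn_wrapper(input_size => hidden_size),
# so any exported layer constructor such as MGU fits.
main(MGU; epochs=5, sequence_length=200, hidden_size=20, learning_rate=0.01)
```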
27 changes: 14 additions & 13 deletions docs/make.jl
@@ -2,29 +2,30 @@ using RecurrentLayers
using Documenter, DocumenterInterLinks
include("pages.jl")

DocMeta.setdocmeta!(RecurrentLayers, :DocTestSetup, :(using RecurrentLayers); recursive=true)
DocMeta.setdocmeta!(
RecurrentLayers, :DocTestSetup, :(using RecurrentLayers); recursive=true)
mathengine = Documenter.MathJax()

links = InterLinks(
"Flux" => "https://fluxml.ai/Flux.jl/stable/",
)

makedocs(;
modules = [RecurrentLayers],
authors = "Francesco Martinuzzi",
sitename = "RecurrentLayers.jl",
format = Documenter.HTML(;
modules=[RecurrentLayers],
authors="Francesco Martinuzzi",
sitename="RecurrentLayers.jl",
format=Documenter.HTML(;
mathengine,
assets = ["assets/favicon.ico"],
canonical = "https://MartinuzziFrancesco.github.io/RecurrentLayers.jl",
edit_link = "main",
assets=["assets/favicon.ico"],
canonical="https://MartinuzziFrancesco.github.io/RecurrentLayers.jl",
edit_link="main"
),
pages = pages,
plugins = [links],
pages=pages,
plugins=[links]
)

deploydocs(;
repo = "github.com/MartinuzziFrancesco/RecurrentLayers.jl",
devbranch = "main",
push_preview = true,
repo="github.com/MartinuzziFrancesco/RecurrentLayers.jl",
devbranch="main",
push_preview=true
)
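The docs/make.jl changes are purely stylistic (SciML-style keyword formatting). For completeness, a common way to build the documentation locally is sketched below; the docs/ project layout is an assumption, not shown in this diff:

```julia
# Sketch: build the documentation locally from the repository root.
# Assumes docs/Project.toml declares Documenter and DocumenterInterLinks.
using Pkg
Pkg.activate("docs")
Pkg.develop(PackageSpec(; path="."))   # use the local RecurrentLayers checkout
Pkg.instantiate()
include("docs/make.jl")
```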
18 changes: 9 additions & 9 deletions docs/pages.jl
@@ -1,9 +1,9 @@
pages=[
"Home" => "index.md",
"API Documentation" => [
"Cells" => "api/cells.md",
"Layers" => "api/layers.md",
"Wrappers" => "api/wrappers.md",
],
"Roadmap" => "roadmap.md"
]
pages = [
"Home" => "index.md",
"API Documentation" => [
"Cells" => "api/cells.md",
"Layers" => "api/layers.md",
"Wrappers" => "api/wrappers.md"
],
"Roadmap" => "roadmap.md"
]
18 changes: 8 additions & 10 deletions src/RecurrentLayers.jl
@@ -1,19 +1,18 @@
module RecurrentLayers

using Compat: @compat
using Flux: _size_check, _match_eltype, chunk, create_bias,
zeros_like, glorot_uniform, scan, @layer,
default_rng, Chain, Dropout
using Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like, glorot_uniform,
scan, @layer, default_rng, Chain, Dropout
import Flux: initialstates
import Functors: functor
#to remove
using NNlib: fast_act, sigmoid_fast, tanh_fast, relu

export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell,
RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell,
FastRNNCell, FastGRNNCell
RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell,
FastRNNCell, FastGRNNCell
export MGU, LiGRU, IndRNN, RAN, LightRU, NAS, RHN, MUT1, MUT2, MUT3,
SCRN, PeepholeLSTM, FastRNN, FastGRNN
SCRN, PeepholeLSTM, FastRNN, FastGRNN
export StackedRNN

@compat(public, (initialstates))
@@ -34,7 +33,6 @@ include("cells/fastrnn_cell.jl")

include("wrappers/stackedrnn.jl")


### fallbacks for functors ###
rlayers = (:FastRNN, :FastGRNN, :IndRNN, :LightRU, :LiGRU, :MGU, :MUT1,
:MUT2, :MUT3, :NAS, :PeepholeLSTM, :RAN, :SCRN)
@@ -43,9 +41,9 @@ rcells = (:FastRNNCell, :FastGRNNCell, :IndRNNCell, :LightRUCell, :LiGRUCell,
:MGUCell, :MUT1Cell, :MUT2Cell, :MUT3Cell, :NASCell, :PeepholeLSTMCell,
:RANCell, :SCRNCell)

for (rlayer,rcell) in zip(rlayers, rcells)
for (rlayer, rcell) in zip(rlayers, rcells)
@eval begin
function ($rlayer)(rc::$rcell; return_state::Bool = false)
function ($rlayer)(rc::$rcell; return_state::Bool=false)
return $rlayer{return_state, typeof(rc)}(rc)
end

@@ -58,4 +56,4 @@ end
end
end

end #module
end #module
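The `@eval` loop above generates one wrapper constructor per (layer, cell) pair. For illustration, the method it produces for the `(:MGU, :MGUCell)` pair is roughly:

```julia
# What the @eval loop expands to for MGU/MGUCell (illustrative):
function MGU(rc::MGUCell; return_state::Bool=false)
    return MGU{return_state, typeof(rc)}(rc)
end
```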
