General fixes #42

Merged · 1 commit · Jan 14, 2025
6 changes: 3 additions & 3 deletions .JuliaFormatter.toml
@@ -1,9 +1,9 @@
style = "blue"
style = "sciml"
format_markdown = false
whitespace_in_kwargs = false
always_use_return = true
margin = 92
indent = 4
format_docstrings = true
separate_kwargs_with_semicolon = true
always_for_in = true
annotate_untyped_fields_with_any = false
annotate_untyped_fields_with_any = false
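
For context (not part of this diff): these settings are picked up automatically when the formatter is run from the repository root. A minimal sketch, assuming JuliaFormatter.jl is installed in the active environment:

```julia
using JuliaFormatter

# Formats every .jl file under the current directory; options such as
# style = "sciml" and margin = 92 are read from .JuliaFormatter.toml.
format(".")
```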
122 changes: 3 additions & 119 deletions README.md
@@ -21,8 +21,8 @@
[julia-img]: https://img.shields.io/badge/julia-v1.10+-blue.svg
[julia-url]: https://julialang.org/

[style-img]: https://img.shields.io/badge/code%20style-blue-4495d1.svg
[style-url]: https://github.com/invenia/BlueStyle
[style-img]: https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826
[style-url]: https://github.com/SciML/SciMLStyle

[aqua-img]: https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg
[aqua-url]: https://github.com/JuliaTesting/Aqua.jl
@@ -71,124 +71,8 @@ pkg> add RecurrentLayers

## Getting started 🛠️

The workflow is identical to any recurrent Flux layer:
The workflow is identical to any recurrent Flux layer: just plug in a new recurrent layer in your workflow and test it out!

```julia
using RecurrentLayers

using Flux
using MLUtils: DataLoader
using Statistics
using Random

# Create dataset
function create_data(input_size, seq_length::Int, num_samples::Int)
data = randn(input_size, seq_length, num_samples) #(input_size, seq_length, num_samples)
labels = sum(data, dims=(1, 2)) .>= 0
labels = Int.(labels)
labels = dropdims(labels, dims=(1))
return data, labels
end

function create_dataset(input_size, seq_length, n_train::Int, n_test::Int, batch_size)
train_data, train_labels = create_data(input_size, seq_length, n_train)
train_loader = DataLoader((train_data, train_labels), batchsize=batch_size, shuffle=true)

test_data, test_labels = create_data(input_size, seq_length, n_test)
test_loader = DataLoader((test_data, test_labels), batchsize=batch_size, shuffle=false)
return train_loader, test_loader
end

struct RecurrentModel{H,C,D}
h0::H
rnn::C
dense::D
end

Flux.@layer RecurrentModel trainable=(rnn, dense)

function RecurrentModel(input_size::Int, hidden_size::Int)
return RecurrentModel(
zeros(Float32, hidden_size),
MGU(input_size => hidden_size),
Dense(hidden_size => 1, sigmoid))
end

function (model::RecurrentModel)(inp)
state = model.rnn(inp, model.h0)
state = state[:, end, :]
output = model.dense(state)
return output
end

function criterion(model, batch_data, batch_labels)
y_pred = model(batch_data)
loss = Flux.binarycrossentropy(y_pred, batch_labels)
return loss
end

function train_recurrent!(epoch, train_loader, opt, model, criterion)
total_loss = 0.0
for (batch_data, batch_labels) in train_loader
# Compute gradients and update parameters
grads = gradient(() -> criterion(model, batch_data, batch_labels), Flux.params(model))
Flux.Optimise.update!(opt, Flux.params(model), grads)

# Accumulate loss
total_loss += criterion(model, batch_data, batch_labels)
end
avg_loss = total_loss / length(train_loader)
println("Epoch $epoch/$num_epochs, Loss: $(round(avg_loss, digits=4))")
end

function test_recurrent(test_loader, model)
# Evaluation
correct = 0
total = 0
for (batch_data, batch_labels) in test_loader

# Forward pass
predicted = model(batch_data)

# Decode predictions: convert probabilities to class labels (0 or 1)
predicted_labels = vec(predicted .>= 0.5) # Threshold at 0.5 for binary classification

# Compare predicted labels to actual labels
correct += sum(predicted_labels .== vec(batch_labels))
total += length(batch_labels)
end
accuracy = correct / total
println("Accuracy: ", accuracy * 100, "%")
end

function main(;
input_size = 1, # Each element in the sequence is a scalar
hidden_size = 64, # Size of the hidden state
seq_length = 10, # Length of each sequence
batch_size = 16, # Batch size
num_epochs = 50, # Number of epochs for training
n_train = 1000, # Number of samples in train dataset
n_test = 200 # Number of samples in test dataset)
)
model = RecurrentModel(input_size, hidden_size)
# Generate test data
train_loader, test_loader = create_dataset(input_size, seq_length, n_train, n_test, batch_size)
# Define the optimizer
opt = Adam(0.001)

for epoch in 1:num_epochs
train_recurrent!(epoch, train_loader, opt, model, criterion)
end

test_recurrent(test_loader, model)

end

main()



```
## License 📜

This project is licensed under the MIT License, except for `nas_cell.jl`, which is licensed under the Apache License, Version 2.0.
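
The long walkthrough above is what the PR removes from the README; the replacement text simply points readers at the usual Flux workflow. For reference (not part of the diff), a minimal sketch of dropping an MGU layer into a Flux `Chain` — the dimensions are made up, and it assumes the layer creates its own initial state when none is passed, as Flux's recurrent layers do:

```julia
using Flux, RecurrentLayers

input_size, hidden_size, seq_length, batch_size = 4, 8, 10, 16

model = Chain(
    MGU(input_size => hidden_size),  # drop-in replacement for Flux.GRU
    x -> x[:, end, :],               # keep only the last hidden state
    Dense(hidden_size => 1, sigmoid)
)

x = rand(Float32, input_size, seq_length, batch_size)
y = model(x)  # (1, batch_size): one probability per sequence
```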
55 changes: 27 additions & 28 deletions benchmarks/adding_problem/main.jl
@@ -1,9 +1,9 @@
using Flux, RecurrentLayers, MLUtils, StatsBase, Comonicon, Printf, CUDA

function generate_adding_data(
sequence_length::Int,
n_samples::Int;
kwargs...
sequence_length::Int,
n_samples::Int;
kwargs...
)
random_sequence = rand(Float32, 1, sequence_length, n_samples)
mask_sequence = zeros(Float32, 1, sequence_length, n_samples)
@@ -15,7 +15,7 @@ function generate_adding_data(
targets[i] = sum(Float32, random_sequence[1, idxs, i])
end

inputs = cat(random_sequence, mask_sequence, dims=1)
inputs = cat(random_sequence, mask_sequence; dims=1)
@assert size(inputs, 3) == size(targets, 1)

dataloader = DataLoader(
@@ -26,17 +26,16 @@
end

function generate_dataloaders(
sequence_length::Int,
n_train::Int,
n_test::Int;
kwargs...)
sequence_length::Int,
n_train::Int,
n_test::Int;
kwargs...)
train_loader = generate_adding_data(sequence_length, n_train; kwargs...)
test_loader = generate_adding_data(sequence_length, n_test; kwargs...)
return train_loader, test_loader
end


struct RecurrentModel{H,C,D}
struct RecurrentModel{H, C, D}
h0::H
rnn::C
dense::D
@@ -46,9 +45,9 @@ Flux.@layer RecurrentModel trainable=(rnn, dense)

function RecurrentModel(rnn_wrapper, input_size::Int, hidden_size::Int)
return RecurrentModel(
zeros(Float32, hidden_size),
rnn_wrapper(input_size => hidden_size),
Dense(hidden_size => 1, sigmoid))
zeros(Float32, hidden_size),
rnn_wrapper(input_size => hidden_size),
Dense(hidden_size => 1, sigmoid))
end

function (model::RecurrentModel)(inp)
@@ -82,24 +81,25 @@ function test_recurrent(epoch, test_loader, model, criterion)
end

Comonicon.@main function main(rnn_wrapper;
epochs::Int = 50,
shuffle::Bool = true,
batchsize::Int = 64,
sequence_length::Int = 1000,
n_train::Int = 500,
n_test::Int = 200,
hidden_size::Int = 20,
learning_rate::Float64 = 0.01)

epochs::Int=50,
shuffle::Bool=true,
batchsize::Int=64,
sequence_length::Int=1000,
n_train::Int=500,
n_test::Int=200,
hidden_size::Int=20,
learning_rate::Float64=0.01)
train_loader, test_loader = generate_dataloaders(
sequence_length, n_train, n_test; batchsize = batchsize, shuffle = shuffle
sequence_length, n_train, n_test; batchsize=batchsize, shuffle=shuffle
)

input_size = 2
model = RecurrentModel(rnn_wrapper, input_size, hidden_size)
criterion(input_data, target_data) = Flux.mse(
model(input_data), reshape(target_data, 1, :)
)
function criterion(input_data, target_data)
Flux.mse(
model(input_data), reshape(target_data, 1, :)
)
end
model = Flux.gpu(model)
opt = Flux.Adam(learning_rate)

Expand All @@ -111,6 +111,5 @@ Comonicon.@main function main(rnn_wrapper;

@printf "Epoch %2d: Train Loss: %.4f, Test Loss: %.4f, \
Time: %.2fs\n" epoch train_loss test_loss total_time

end
end
end
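
The benchmark's entry point takes the cell wrapper as its positional argument, with the keywords shown in the signature above. A hedged sketch of driving it from a REPL — the `include` path, the reduced sizes, and the availability of the benchmark's dependencies (including CUDA.jl) are assumptions for illustration:

```julia
include("benchmarks/adding_problem/main.jl")

# Small run of the adding-problem benchmark with the MGU wrapper;
# keyword names match the Comonicon.@main signature above.
main(MGU; epochs=5, sequence_length=200, n_train=100, n_test=50)
```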
27 changes: 14 additions & 13 deletions docs/make.jl
@@ -2,29 +2,30 @@ using RecurrentLayers
using Documenter, DocumenterInterLinks
include("pages.jl")

DocMeta.setdocmeta!(RecurrentLayers, :DocTestSetup, :(using RecurrentLayers); recursive=true)
DocMeta.setdocmeta!(
RecurrentLayers, :DocTestSetup, :(using RecurrentLayers); recursive=true)
mathengine = Documenter.MathJax()

links = InterLinks(
"Flux" => "https://fluxml.ai/Flux.jl/stable/",
)

makedocs(;
modules = [RecurrentLayers],
authors = "Francesco Martinuzzi",
sitename = "RecurrentLayers.jl",
format = Documenter.HTML(;
modules=[RecurrentLayers],
authors="Francesco Martinuzzi",
sitename="RecurrentLayers.jl",
format=Documenter.HTML(;
mathengine,
assets = ["assets/favicon.ico"],
canonical = "https://MartinuzziFrancesco.github.io/RecurrentLayers.jl",
edit_link = "main",
assets=["assets/favicon.ico"],
canonical="https://MartinuzziFrancesco.github.io/RecurrentLayers.jl",
edit_link="main"
),
pages = pages,
plugins = [links],
pages=pages,
plugins=[links]
)

deploydocs(;
repo = "github.com/MartinuzziFrancesco/RecurrentLayers.jl",
devbranch = "main",
push_preview = true,
repo="github.com/MartinuzziFrancesco/RecurrentLayers.jl",
devbranch="main",
push_preview=true
)
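
The `makedocs`/`deploydocs` call is only reformatted to the new keyword style. As a side note (an assumption, not something stated in the diff), documentation built with this script is typically generated locally through the docs environment:

```julia
# Hedged sketch of a local docs build; assumes docs/Project.toml lists
# Documenter, DocumenterInterLinks and a dev'd copy of RecurrentLayers.
using Pkg
Pkg.activate("docs")
Pkg.develop(path=".")    # make the local package available to the docs env
Pkg.instantiate()
include("docs/make.jl")  # runs the makedocs/deploydocs script shown above
```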
18 changes: 9 additions & 9 deletions docs/pages.jl
@@ -1,9 +1,9 @@
pages=[
"Home" => "index.md",
"API Documentation" => [
"Cells" => "api/cells.md",
"Layers" => "api/layers.md",
"Wrappers" => "api/wrappers.md",
],
"Roadmap" => "roadmap.md"
]
pages = [
"Home" => "index.md",
"API Documentation" => [
"Cells" => "api/cells.md",
"Layers" => "api/layers.md",
"Wrappers" => "api/wrappers.md"
],
"Roadmap" => "roadmap.md"
]
18 changes: 8 additions & 10 deletions src/RecurrentLayers.jl
@@ -1,19 +1,18 @@
module RecurrentLayers

using Compat: @compat
using Flux: _size_check, _match_eltype, chunk, create_bias,
zeros_like, glorot_uniform, scan, @layer,
default_rng, Chain, Dropout
using Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like, glorot_uniform,
scan, @layer, default_rng, Chain, Dropout
import Flux: initialstates
import Functors: functor
#to remove
using NNlib: fast_act, sigmoid_fast, tanh_fast, relu

export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell,
RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell,
FastRNNCell, FastGRNNCell
RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell,
FastRNNCell, FastGRNNCell
export MGU, LiGRU, IndRNN, RAN, LightRU, NAS, RHN, MUT1, MUT2, MUT3,
SCRN, PeepholeLSTM, FastRNN, FastGRNN
SCRN, PeepholeLSTM, FastRNN, FastGRNN
export StackedRNN

@compat(public, (initialstates))
@@ -34,7 +33,6 @@ include("cells/fastrnn_cell.jl")

include("wrappers/stackedrnn.jl")


### fallbacks for functors ###
rlayers = (:FastRNN, :FastGRNN, :IndRNN, :LightRU, :LiGRU, :MGU, :MUT1,
:MUT2, :MUT3, :NAS, :PeepholeLSTM, :RAN, :SCRN)
@@ -43,9 +41,9 @@ rcells = (:FastRNNCell, :FastGRNNCell, :IndRNNCell, :LightRUCell, :LiGRUCell,
:MGUCell, :MUT1Cell, :MUT2Cell, :MUT3Cell, :NASCell, :PeepholeLSTMCell,
:RANCell, :SCRNCell)

for (rlayer,rcell) in zip(rlayers, rcells)
for (rlayer, rcell) in zip(rlayers, rcells)
@eval begin
function ($rlayer)(rc::$rcell; return_state::Bool = false)
function ($rlayer)(rc::$rcell; return_state::Bool=false)
return $rlayer{return_state, typeof(rc)}(rc)
end

@@ -58,4 +56,4 @@ for (rlayer,rcell) in zip(rlayers, rcells)
end
end

end #module
end #module
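
The `@eval` loop at the end of the module generates one wrapper constructor per `(layer, cell)` pair. To make the metaprogramming concrete (an illustration, not code added by this PR), the `MGU`/`MGUCell` iteration expands to roughly the following; the remaining generated methods are collapsed in the diff above:

```julia
# Produced by the @eval loop for (rlayer, rcell) == (:MGU, :MGUCell);
# every other layer/cell pair gets the same constructor.
function MGU(rc::MGUCell; return_state::Bool=false)
    return MGU{return_state, typeof(rc)}(rc)
end
```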