readme cleanup, start of rhn end
MartinuzziFrancesco committed Oct 31, 2024
1 parent 9047e2c commit aec45e4
Showing 2 changed files with 109 additions and 59 deletions.
138 changes: 79 additions & 59 deletions README.md
@@ -41,91 +41,111 @@
```julia
using MLUtils: DataLoader
using Statistics
using Random

# Create dataset: random sequences labeled by whether their sum is non-negative
function create_data(input_size, seq_length::Int, num_samples::Int)
    data = randn(input_size, seq_length, num_samples)  # (input_size, seq_length, num_samples)
    labels = sum(data, dims=(1, 2)) .>= 0
    labels = Int.(labels)
    labels = dropdims(labels, dims=(1))
    return data, labels
end

# Build train and test loaders
function create_dataset(input_size, seq_length, n_train::Int, n_test::Int, batch_size)
    train_data, train_labels = create_data(input_size, seq_length, n_train)
    train_loader = DataLoader((train_data, train_labels), batchsize=batch_size, shuffle=true)

    test_data, test_labels = create_data(input_size, seq_length, n_test)
    test_loader = DataLoader((test_data, test_labels), batchsize=batch_size, shuffle=false)
    return train_loader, test_loader
end
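
# Shape check (assumed from the code above, not stated in the original): with
# the defaults used in main(), first(train_loader) yields data of size
# (1, 10, 16) and labels of size (1, 16); DataLoader batches along the last
# dimension.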

struct RecurrentModel{H,C,D}
    h0::H
    rnn::C
    dense::D
end

Flux.@layer RecurrentModel trainable=(rnn, dense)

function RecurrentModel(input_size::Int, hidden_size::Int)
    return RecurrentModel(
        zeros(Float32, hidden_size),
        MGU(input_size => hidden_size),
        Dense(hidden_size => 1, sigmoid))
end

function (model::RecurrentModel)(inp)
    state = model.rnn(inp, model.h0)
    state = state[:, end, :]
    output = model.dense(state)
    return output
end
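
# Note (assumed semantics, inferred from the indexing above): the recurrent
# layer returns a hidden state for every time step, so [:, end, :] keeps only
# the final state of each sample before the dense classifier.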

function criterion(model, batch_data, batch_labels)
    y_pred = model(batch_data)
    loss = Flux.binarycrossentropy(y_pred, batch_labels)
    return loss
end

function train_recurrent!(epoch, train_loader, opt, model, criterion)
    total_loss = 0.0
    for (batch_data, batch_labels) in train_loader
        # Compute gradients and update parameters
        grads = gradient(() -> criterion(model, batch_data, batch_labels), Flux.params(model))
        Flux.Optimise.update!(opt, Flux.params(model), grads)

        # Accumulate loss
        total_loss += criterion(model, batch_data, batch_labels)
    end
    avg_loss = total_loss / length(train_loader)
    println("Epoch $epoch, Loss: $(round(avg_loss, digits=4))")
end

function test_recurrent(test_loader, model)
    # Evaluation
    correct = 0
    total = 0
    for (batch_data, batch_labels) in test_loader
        # Forward pass
        predicted = model(batch_data)

        # Decode predictions: convert probabilities to class labels (0 or 1)
        predicted_labels = vec(predicted .>= 0.5)  # Threshold at 0.5 for binary classification

        # Compare predicted labels to actual labels
        correct += sum(predicted_labels .== vec(batch_labels))
        total += length(batch_labels)
    end
    accuracy = correct / total
    println("Accuracy: ", accuracy * 100, "%")
end

function main(;
    input_size = 1,    # Each element in the sequence is a scalar
    hidden_size = 64,  # Size of the hidden state
    seq_length = 10,   # Length of each sequence
    batch_size = 16,   # Batch size
    num_epochs = 50,   # Number of epochs for training
    n_train = 1000,    # Number of samples in the train dataset
    n_test = 200,      # Number of samples in the test dataset
)
    model = RecurrentModel(input_size, hidden_size)
    # Generate train and test data
    train_loader, test_loader = create_dataset(input_size, seq_length, n_train, n_test, batch_size)
    # Define the optimizer
    opt = Adam(0.001)

    for epoch in 1:num_epochs
        train_recurrent!(epoch, train_loader, opt, model, criterion)
    end

    test_recurrent(test_loader, model)
end

main()
```
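
The example above trains an MGU, but any of the library's other cells with the same `in => out` constructor should drop in unchanged. A minimal sketch, assuming `RAN` (used in an earlier revision of this example) shares that interface:

```julia
# Hypothetical variant: swap the MGU for a RAN cell. Everything else in the
# pipeline (loss, training loop, evaluation) stays the same.
function RecurrentModelRAN(input_size::Int, hidden_size::Int)
    return RecurrentModel(
        zeros(Float32, hidden_size),
        RAN(input_size => hidden_size),
        Dense(hidden_size => 1, sigmoid))
end
```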
30 changes: 30 additions & 0 deletions src/rhn_cell.jl
@@ -98,3 +98,33 @@ function (rhn::RHNCell)(inp, state=nothing)
```julia
    # ... (earlier body of the RHNCell forward pass not shown in this hunk)
    return current_state
end

# TODO: fix implementation here
struct RHN{M}
    cell::M
end

Flux.@layer :expand RHN

"""
RHN((in, out)::Pair depth=3; kwargs...)
"""
function RHN((in, out)::Pair, depth=3; kwargs...)
cell = RHNCell(in => out, depth; kwargs...)
return RHN(cell)
end

function (rhn::RHN)(inp)
    # Default to a zero initial state sized to the cell's hidden dimension
    state = zeros_like(inp, size(rhn.cell.layers[2].weights, 2))
    return rhn(inp, state)
end

function (rhn::RHN)(inp, state)
    @assert ndims(inp) == 2 || ndims(inp) == 3
    new_state = []
    # Step through time (dims=2), threading the state through the cell
    for inp_t in eachslice(inp, dims=2)
        state = rhn.cell(inp_t, state)
        push!(new_state, state)
    end
    return stack(new_state, dims=2)
end
```
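
A rough usage sketch of the wrapper, hedged because the TODO above flags the implementation as unfinished; the shapes and the `cell.layers[2]` weight layout are assumptions inferred from the zero-state construction:

```julia
# Hypothetical usage: run the RHN wrapper over one sequence.
rhn = RHN(4 => 8, 3)       # 4 input features, 8 hidden units, recurrence depth 3
x = rand(Float32, 4, 10)   # (features, seq_length): one sequence of length 10
states = rhn(x)            # expected: per-step hidden states stacked along dims=2
```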
