
Commit 8fb7ce2

fix: fou vae feed forward dim errors, multihead attention still not working

dancixx committed Nov 3, 2024
1 parent ac9a3cc commit 8fb7ce2
Showing 3 changed files with 24 additions and 15 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
@@ -56,6 +56,7 @@ time = { version = "0.3.36", features = [
 ], optional = true }
 tokio-test = "0.4.4"
 tracing = "0.1.40"
+tracing-test = "0.2.5"
 yahoo_finance_api = { version = "2.3.0", optional = true }
 
 [dev-dependencies]
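
Note: tracing-test supplies the #[traced_test] attribute picked up by the test change below; it installs a tracing subscriber for the duration of a single test so log output is captured and can be asserted on. A minimal usage sketch (the test name here is hypothetical, not from this commit):

use tracing_test::traced_test;

#[traced_test]
#[test]
fn logs_are_captured() {
  tracing::info!("hello from the test");
  assert!(logs_contain("hello from the test"));
}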
2 changes: 1 addition & 1 deletion src/ai/fou.rs
@@ -1,4 +1,4 @@
 pub mod fou_lstm_datasets;
 pub mod fou_lstm_model_1_d;
 pub mod fou_lstm_model_2_d;
-// pub mod fou_vae;
+pub mod fou_vae;
36 changes: 22 additions & 14 deletions src/ai/fou/fou_vae.rs
@@ -1,11 +1,10 @@
-use std::{borrow::BorrowMut, cell::RefCell, f64::consts::PI};
+use std::cell::RefCell;
 
-use candle_core::{DType, Device, IndexOp, Result, Shape, Tensor};
+use candle_core::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{
-  layer_norm, linear, linear_no_bias, ops::dropout, seq, Activation, Dropout, LayerNorm,
-  LayerNormConfig, Linear, Module, Sequential, VarBuilder,
+  layer_norm, linear, linear_no_bias, seq, Activation, Dropout, LayerNorm, LayerNormConfig, Linear,
+  Module, Sequential, VarBuilder,
 };
-use candle_transformers::models::mimi::transformer::StreamingMultiheadAttention;
 
 pub struct Time2Vec {
   seq_len: usize,
@@ -18,10 +17,10 @@ pub struct Time2Vec {
 
 impl Time2Vec {
   pub fn new(seq_len: usize, embed_dim: usize, device: &Device) -> Result<Self> {
-    let wb = Tensor::zeros((embed_dim,), DType::F32, device)?;
-    let bb = Tensor::zeros((embed_dim,), DType::F32, device)?;
-    let wa = Tensor::zeros((embed_dim,), DType::F32, device)?;
-    let ba = Tensor::zeros((embed_dim,), DType::F32, device)?;
+    let wb = Tensor::zeros((embed_dim,), DType::F64, device)?;
+    let bb = Tensor::zeros((embed_dim,), DType::F64, device)?;
+    let wa = Tensor::zeros((embed_dim,), DType::F64, device)?;
+    let ba = Tensor::zeros((embed_dim,), DType::F64, device)?;
 
     Ok(Self {
       seq_len,
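
Note: the F32 to F64 switch keeps the Time2Vec parameters consistent with the f64 time index built in forward below; candle's binary and broadcast ops require both operands to share a dtype and return an error otherwise. A minimal sketch of that rule (shapes and names illustrative, not from this commit):

use candle_core::{DType, Device, Result, Tensor};

fn dtype_rule() -> Result<()> {
  let device = Device::Cpu;
  // F64 parameter against an F64 index: broadcast works.
  let w = Tensor::zeros((4,), DType::F64, &device)?;
  let t = Tensor::arange(0f64, 3f64, &device)?.unsqueeze(1)?;
  let _v = w.broadcast_mul(&t)?;
  // An F32 parameter against the same F64 index is a dtype mismatch.
  let w32 = Tensor::zeros((4,), DType::F32, &device)?;
  assert!(w32.broadcast_mul(&t).is_err());
  Ok(())
}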
@@ -38,12 +37,20 @@ impl Module for Time2Vec {
   fn forward(&self, xs: &Tensor) -> Result<Tensor> {
     // x shape: (batch_size, seq_len, input_dim)
     let batch_size = xs.shape().dims()[0];
-    let tt = Tensor::arange(0f32, self.seq_len as f32, xs.device())?.unsqueeze(1)?;
+    let tt = Tensor::arange(0f64, self.seq_len as f64, xs.device())?.unsqueeze(1)?;
     let tt = tt.repeat(&[1, self.embed_dim])?;
-    let v = (&self.wb * &tt + &self.bb + (&self.wa * &tt + &self.ba)?.sin()?)?;
+
+    // TODO: maybe this is slow
+    // https://github.com/huggingface/candle/issues/2499
+    let v = &self
+      .wb
+      .broadcast_mul(&tt)?
+      .broadcast_add(&self.bb)?
+      .broadcast_add(&self.wa.broadcast_mul(&tt)?.broadcast_add(&self.ba)?.sin()?)?;
     let v = v
       .unsqueeze(0)?
       .expand(&[batch_size, self.seq_len, self.embed_dim])?;
+
     Ok(v)
   }
 }
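
Note: the rewritten forward replaces the overloaded operators with explicit broadcast_* calls. In candle, &a * &b and &a + &b require identical shapes, while wb, bb, wa, ba have shape (embed_dim,) and tt has shape (seq_len, embed_dim); broadcast_mul and broadcast_add expand the rank-1 parameters across the sequence dimension. A small sketch of the difference (dims illustrative):

use candle_core::{DType, Device, Result, Tensor};

fn broadcast_vs_mul() -> Result<()> {
  let device = Device::Cpu;
  let wb = Tensor::ones((4,), DType::F64, &device)?;               // (embed_dim,)
  let tt = Tensor::arange(0f64, 12f64, &device)?.reshape((3, 4))?; // (seq_len, embed_dim)
  // Plain `*` checks for identical shapes and fails here...
  assert!((&wb * &tt).is_err());
  // ...while broadcast_mul expands (4,) against (3, 4).
  let v = wb.broadcast_mul(&tt)?;
  assert_eq!(v.dims(), &[3, 4]);
  Ok(())
}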
@@ -56,13 +63,12 @@ pub struct FeedForward {
 impl FeedForward {
   pub fn new(n_embd: usize, dropout_rate: f32, vs: VarBuilder) -> Result<Self> {
     let linear1 = linear(n_embd, 4 * n_embd, vs.pp("feedforward_linear1"))?;
-    let new_gelu = Activation::NewGelu;
     let linear2 = linear(4 * n_embd, n_embd, vs.pp("feedforward_linear2"))?;
     let dropout = Dropout::new(dropout_rate);
 
     let net = seq()
       .add(linear1)
-      .add(new_gelu)
+      .add(Activation::NewGelu)
       .add(linear2)
       .add_fn(move |xs| Ok(dropout.forward(&xs, true).unwrap()));
 
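Note: FeedForward follows the usual transformer FFN shape flow, n_embd to 4 * n_embd and back, with a GELU in between; the "feed forward dim errors" named in the commit title presumably live in truncated parts of this diff, since the layer dims here appear only as context. Dropping the intermediate new_gelu binding is cosmetic, as Sequential::add takes the activation by value either way. A hedged sketch of the shape flow using candle_nn directly (variable names illustrative):

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{linear, Activation, Module, VarBuilder, VarMap};

fn ffn_shape_flow() -> Result<()> {
  let device = Device::Cpu;
  let varmap = VarMap::new();
  let vs = VarBuilder::from_varmap(&varmap, DType::F64, &device);
  let n_embd = 8;
  let linear1 = linear(n_embd, 4 * n_embd, vs.pp("l1"))?;       // expand
  let linear2 = linear(4 * n_embd, n_embd, vs.pp("l2"))?;       // project back
  let xs = Tensor::randn(0f64, 1f64, (2, 5, n_embd), &device)?; // (batch, seq, n_embd)
  let hidden = Activation::NewGelu.forward(&linear1.forward(&xs)?)?;
  assert_eq!(hidden.dims(), &[2, 5, 4 * n_embd]);
  let out = linear2.forward(&hidden)?;
  assert_eq!(out.dims(), &[2, 5, n_embd]);
  Ok(())
}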
@@ -204,6 +210,7 @@ impl Block {
     let ffwd = FeedForward::new(n_embd, dropout_rate, vs.pp("ffwd"))?;
     let ln1 = layer_norm(n_embd, LayerNormConfig::default(), vs.pp("ln1"))?;
     let ln2 = layer_norm(n_embd, LayerNormConfig::default(), vs.pp("ln2"))?;
+
     Ok(Self {
       sa: RefCell::new(sa),
       ffwd,
@@ -292,7 +299,6 @@ impl TransformerEncoder {
     let mu = self.fc_mu.forward(&xs_pooled)?;
     let log_var = self.fc_log_var.forward(&xs_pooled)?;
     let sigma_estimated = self.fc_volatility.forward(&xs_pooled)?;
-
     Ok((mu, log_var, sigma_estimated))
   }
 }
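
Note: the encoder head returns (mu, log_var, sigma_estimated). In a VAE the latent sample is usually drawn with the reparameterization trick, z = mu + eps * exp(0.5 * log_var), so gradients flow through mu and log_var; the sampling code itself sits outside this hunk. A hypothetical helper in that spirit, not taken from this diff:

use candle_core::{Result, Tensor};

// z = mu + eps * exp(0.5 * log_var), with eps ~ N(0, 1).
fn reparameterize(mu: &Tensor, log_var: &Tensor) -> Result<Tensor> {
  let std = (log_var * 0.5)?.exp()?;
  let eps = mu.randn_like(0.0, 1.0)?;
  mu + (eps * std)?
}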
@@ -428,7 +434,9 @@ mod tests {
   use super::*;
   use candle_core::{DType, Device, Result, Tensor};
   use candle_nn::VarMap;
+  use tracing_test::traced_test;
 
+  #[traced_test]
   #[test]
   fn test_transformer_vae_forward() -> Result<()> {
     let seq_len = 10;
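
Note: the test body is truncated in this view. A forward-pass test around candle modules typically builds variables through a VarMap and checks output shapes; a hedged sketch of such a body (assuming a VarBuilder import alongside VarMap, and a hypothetical constructor, none of which is the commit's actual code):

let device = Device::Cpu;
let varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&varmap, DType::F64, &device);
let xs = Tensor::randn(0f64, 1f64, (4, seq_len, 1), &device)?; // (batch, seq_len, input_dim)
// let model = TransformerVae::new(/* dims */, vs)?;          // hypothetical constructor
// let (recon, mu, log_var, sigma) = model.forward(&xs)?;     // then assert shapes on each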
