From 8fb7ce2d855c53eb9ad3856bffb75911a471c067 Mon Sep 17 00:00:00 2001
From: Daniel Boros
Date: Sun, 3 Nov 2024 13:47:11 +0100
Subject: [PATCH] fix: fou vae feed forward dim errors, multihead attention
 still not working

---
 Cargo.toml            |  1 +
 src/ai/fou.rs         |  2 +-
 src/ai/fou/fou_vae.rs | 36 ++++++++++++++++++++++--------------
 3 files changed, 24 insertions(+), 15 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index dceb16d..0576c26 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -56,6 +56,7 @@ time = { version = "0.3.36", features = [
 ], optional = true }
 tokio-test = "0.4.4"
 tracing = "0.1.40"
+tracing-test = "0.2.5"
 yahoo_finance_api = { version = "2.3.0", optional = true }
 
 [dev-dependencies]
diff --git a/src/ai/fou.rs b/src/ai/fou.rs
index 1aab25a..64822b1 100644
--- a/src/ai/fou.rs
+++ b/src/ai/fou.rs
@@ -1,4 +1,4 @@
 pub mod fou_lstm_datasets;
 pub mod fou_lstm_model_1_d;
 pub mod fou_lstm_model_2_d;
-// pub mod fou_vae;
+pub mod fou_vae;
diff --git a/src/ai/fou/fou_vae.rs b/src/ai/fou/fou_vae.rs
index 90a4832..82ec972 100644
--- a/src/ai/fou/fou_vae.rs
+++ b/src/ai/fou/fou_vae.rs
@@ -1,11 +1,10 @@
-use std::{borrow::BorrowMut, cell::RefCell, f64::consts::PI};
+use std::cell::RefCell;
 
-use candle_core::{DType, Device, IndexOp, Result, Shape, Tensor};
+use candle_core::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{
-  layer_norm, linear, linear_no_bias, ops::dropout, seq, Activation, Dropout, LayerNorm,
-  LayerNormConfig, Linear, Module, Sequential, VarBuilder,
+  layer_norm, linear, linear_no_bias, seq, Activation, Dropout, LayerNorm, LayerNormConfig, Linear,
+  Module, Sequential, VarBuilder,
 };
-use candle_transformers::models::mimi::transformer::StreamingMultiheadAttention;
 
 pub struct Time2Vec {
   seq_len: usize,
@@ -18,10 +17,10 @@
 
 impl Time2Vec {
   pub fn new(seq_len: usize, embed_dim: usize, device: &Device) -> Result<Self> {
-    let wb = Tensor::zeros((embed_dim,), DType::F32, device)?;
-    let bb = Tensor::zeros((embed_dim,), DType::F32, device)?;
-    let wa = Tensor::zeros((embed_dim,), DType::F32, device)?;
-    let ba = Tensor::zeros((embed_dim,), DType::F32, device)?;
+    let wb = Tensor::zeros((embed_dim,), DType::F64, device)?;
+    let bb = Tensor::zeros((embed_dim,), DType::F64, device)?;
+    let wa = Tensor::zeros((embed_dim,), DType::F64, device)?;
+    let ba = Tensor::zeros((embed_dim,), DType::F64, device)?;
 
     Ok(Self {
       seq_len,
@@ -38,12 +37,20 @@ impl Module for Time2Vec {
   fn forward(&self, xs: &Tensor) -> Result<Tensor> {
     // x shape: (batch_size, seq_len, input_dim)
     let batch_size = xs.shape().dims()[0];
-    let tt = Tensor::arange(0f32, self.seq_len as f32, xs.device())?.unsqueeze(1)?;
+    let tt = Tensor::arange(0f64, self.seq_len as f64, xs.device())?.unsqueeze(1)?;
     let tt = tt.repeat(&[1, self.embed_dim])?;
-    let v = (&self.wb * &tt + &self.bb + (&self.wa * &tt + &self.ba)?.sin()?)?;
+
+    // TODO: maybe this is slow
+    // https://github.com/huggingface/candle/issues/2499
+    let v = &self
+      .wb
+      .broadcast_mul(&tt)?
+      .broadcast_add(&self.bb)?
+      .broadcast_add(&self.wa.broadcast_mul(&tt)?.broadcast_add(&self.ba)?.sin()?)?;
     let v = v
       .unsqueeze(0)?
       .expand(&[batch_size, self.seq_len, self.embed_dim])?;
+
     Ok(v)
   }
 }
@@ -56,13 +63,12 @@ pub struct FeedForward {
 impl FeedForward {
   pub fn new(n_embd: usize, dropout_rate: f32, vs: VarBuilder) -> Result<Self> {
     let linear1 = linear(n_embd, 4 * n_embd, vs.pp("feedforward_linear1"))?;
-    let new_gelu = Activation::NewGelu;
     let linear2 = linear(4 * n_embd, n_embd, vs.pp("feedforward_linear2"))?;
     let dropout = Dropout::new(dropout_rate);
 
     let net = seq()
       .add(linear1)
-      .add(new_gelu)
+      .add(Activation::NewGelu)
       .add(linear2)
       .add_fn(move |xs| Ok(dropout.forward(&xs, true).unwrap()));
 
@@ -204,6 +210,7 @@ impl Block {
     let ffwd = FeedForward::new(n_embd, dropout_rate, vs.pp("ffwd"))?;
     let ln1 = layer_norm(n_embd, LayerNormConfig::default(), vs.pp("ln1"))?;
     let ln2 = layer_norm(n_embd, LayerNormConfig::default(), vs.pp("ln2"))?;
+
     Ok(Self {
       sa: RefCell::new(sa),
       ffwd,
@@ -292,7 +299,6 @@ impl TransformerEncoder {
     let mu = self.fc_mu.forward(&xs_pooled)?;
     let log_var = self.fc_log_var.forward(&xs_pooled)?;
     let sigma_estimated = self.fc_volatility.forward(&xs_pooled)?;
-
     Ok((mu, log_var, sigma_estimated))
   }
 }
@@ -428,7 +434,9 @@ mod tests {
   use super::*;
   use candle_core::{DType, Device, Result, Tensor};
   use candle_nn::VarMap;
+  use tracing_test::traced_test;
 
+  #[traced_test]
   #[test]
   fn test_transformer_vae_forward() -> Result<()> {
     let seq_len = 10;
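
Note on the Time2Vec hunk: candle's plain `*`/`+` operators require operands of
identical shape, so combining the rank-1 parameters wb/bb/wa/ba (shape
(embed_dim,)) with the rank-2 grid tt (shape (seq_len, embed_dim)) is the
likely source of the dim errors named in the subject line; the broadcast_*
calls align the shapes explicitly. Below is a minimal standalone sketch of the
patched computation; the function name and the concrete sizes are made up for
illustration, only the parameter names and the call chain come from the diff:

use candle_core::{DType, Device, Result, Tensor};

// Standalone sketch of the patched Time2Vec forward pass. seq_len and
// embed_dim are illustrative values; the parameters mirror Time2Vec::new.
fn time2vec_sketch() -> Result<Tensor> {
  let device = Device::Cpu;
  let (seq_len, embed_dim) = (4usize, 3usize);

  // Zero-initialised parameters of shape (embed_dim,), as in Time2Vec::new.
  let wb = Tensor::zeros((embed_dim,), DType::F64, &device)?;
  let bb = Tensor::zeros((embed_dim,), DType::F64, &device)?;
  let wa = Tensor::zeros((embed_dim,), DType::F64, &device)?;
  let ba = Tensor::zeros((embed_dim,), DType::F64, &device)?;

  // Time indices as a (seq_len, 1) column, repeated to (seq_len, embed_dim).
  let tt = Tensor::arange(0f64, seq_len as f64, &device)?.unsqueeze(1)?;
  let tt = tt.repeat(&[1, embed_dim])?;

  // broadcast_mul/broadcast_add stretch the (embed_dim,) parameters over the
  // (seq_len, embed_dim) grid; plain `*`/`+` would fail on the shape mismatch.
  wb.broadcast_mul(&tt)?
    .broadcast_add(&bb)?
    .broadcast_add(&wa.broadcast_mul(&tt)?.broadcast_add(&ba)?.sin()?)
}

The TODO kept in the hunk, with its link to huggingface/candle issue 2499,
records the author's suspicion that the chained broadcast calls may be slow;
the sketch deliberately keeps the same chain.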
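Note on the FeedForward hunk: the change only inlines Activation::NewGelu into
the seq() chain instead of binding it to a temporary first, so behaviour is
unchanged. A hypothetical self-contained shape check of that wiring follows;
the VarMap/VarBuilder setup, n_embd, batch size, and dropout rate are
assumptions for the sketch, not repository code:

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{linear, seq, Activation, Dropout, Module, VarBuilder, VarMap};

// Hypothetical check of the FeedForward wiring from the diff:
// n_embd -> 4 * n_embd -> NewGelu -> n_embd -> dropout.
fn feedforward_sketch() -> Result<()> {
  let device = Device::Cpu;
  let varmap = VarMap::new();
  let vs = VarBuilder::from_varmap(&varmap, DType::F32, &device);

  let n_embd = 8;
  let dropout = Dropout::new(0.1);
  let net = seq()
    .add(linear(n_embd, 4 * n_embd, vs.pp("feedforward_linear1"))?)
    .add(Activation::NewGelu)
    .add(linear(4 * n_embd, n_embd, vs.pp("feedforward_linear2"))?)
    // Dropout::forward already returns Result, so the closure can pass it
    // through instead of unwrapping as the patch does.
    .add_fn(move |xs| dropout.forward(xs, true));

  // A (batch, n_embd) input should come back with the same shape.
  let xs = Tensor::zeros((2, n_embd), DType::F32, &device)?;
  let ys = net.forward(&xs)?;
  assert_eq!(ys.dims(), &[2, n_embd]);
  Ok(())
}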
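Note on the test hunk: the new tracing-test dev-dependency provides the
#[traced_test] attribute, which initialises a tracing subscriber for the
annotated test so tracing output is captured. A minimal usage sketch, with a
made-up test name and log line (logs_contain is the helper tracing-test
injects into the annotated test's scope):

#[cfg(test)]
mod tests {
  use tracing_test::traced_test;

  // #[traced_test] sets up log capture before #[test] runs the body.
  #[traced_test]
  #[test]
  fn logs_are_captured() {
    tracing::info!("visible in test output");
    assert!(logs_contain("visible in test output"));
  }
}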