Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
swfsql committed Mar 3, 2024
1 parent 1175903 commit cf811fb
Show file tree
Hide file tree
Showing 37 changed files with 91 additions and 84 deletions.
4 changes: 2 additions & 2 deletions dfdx/examples/09-module-sequential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ use dfdx::prelude::*;
struct MlpConfig {
// Linear with compile time input size & runtime known output size
linear1: LinearConfig<Const<784>, usize>,
act1: ReLU,
act1: ops::ReLU,
// Linear with runtime input & output size
linear2: LinearConfig<usize, usize>,
act2: Tanh,
act2: ops::Tanh,
// Linear with runtime input & compile time output size.
linear3: LinearConfig<usize, Const<10>>,
}
Expand Down
4 changes: 2 additions & 2 deletions dfdx/examples/10-module-gradients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use dfdx::prelude::*;
#[derive(Clone, Default, Debug, Sequential)]
struct MlpConfig<const I: usize, const O: usize> {
linear1: LinearConstConfig<I, 64>,
act1: ReLU,
act1: ops::ReLU,
linear2: LinearConstConfig<64, 64>,
act2: ReLU,
act2: ops::ReLU,
linear3: LinearConstConfig<64, O>,
}

Expand Down
6 changes: 3 additions & 3 deletions dfdx/examples/11-module-optim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ use dfdx::prelude::*;
#[built(Mlp)]
struct MlpConfig {
l1: LinearConstConfig<5, 32>,
act1: ReLU,
act1: ops::ReLU,
l2: LinearConstConfig<32, 32>,
act2: ReLU,
act2: ops::ReLU,
l3: LinearConstConfig<32, 2>,
act3: Tanh,
act3: ops::Tanh,
}

fn main() {
Expand Down
6 changes: 3 additions & 3 deletions dfdx/examples/12-mnist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ impl ExactSizeDataset for MnistTrainSet {

// our network structure
type Mlp = (
(LinearConstConfig<784, 512>, ReLU),
(LinearConstConfig<512, 128>, ReLU),
(LinearConstConfig<128, 32>, ReLU),
(LinearConstConfig<784, 512>, ops::ReLU),
(LinearConstConfig<512, 128>, ops::ReLU),
(LinearConstConfig<128, 32>, ops::ReLU),
LinearConstConfig<32, 10>,
);

Expand Down
4 changes: 2 additions & 2 deletions dfdx/examples/advanced-gradient-accum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ fn main() {

type Model = (
LinearConstConfig<2, 5>,
ReLU,
ops::ReLU,
LinearConstConfig<5, 10>,
Tanh,
ops::Tanh,
LinearConstConfig<10, 20>,
);
let model = dev.build_module::<f32>(Model::default());
Expand Down
16 changes: 8 additions & 8 deletions dfdx/examples/advanced-resnet18.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fn main() {
pub struct BasicBlockInternal<const C: usize> {
conv1: Conv2DConstConfig<C, C, 3, 1, 1>,
bn1: BatchNorm2DConstConfig<C>,
relu: ReLU,
relu: ops::ReLU,
conv2: Conv2DConstConfig<C, C, 3, 1, 1>,
bn2: BatchNorm2DConstConfig<C>,
}
Expand All @@ -24,7 +24,7 @@ fn main() {
pub struct DownsampleA<const C: usize, const D: usize> {
conv1: Conv2DConstConfig<C, D, 3, 2, 1>,
bn1: BatchNorm2DConstConfig<D>,
relu: ReLU,
relu: ops::ReLU,
conv2: Conv2DConstConfig<D, D, 3, 1, 1>,
bn2: BatchNorm2DConstConfig<D>,
}
Expand All @@ -44,18 +44,18 @@ fn main() {
pub struct Head {
conv: Conv2DConstConfig<3, 64, 7, 2, 3>,
bn: BatchNorm2DConstConfig<64>,
relu: ReLU,
pool: MaxPool2DConst<3, 2, 1>,
relu: ops::ReLU,
pool: ops::MaxPool2DConst<3, 2, 1>,
}

#[derive(Default, Clone, Sequential)]
#[built(Resnet18)]
pub struct Resnet18Config<const NUM_CLASSES: usize> {
head: Head,
l1: (BasicBlock<64>, ReLU, BasicBlock<64>, ReLU),
l2: (Downsample<64, 128>, ReLU, BasicBlock<128>, ReLU),
l3: (Downsample<128, 256>, ReLU, BasicBlock<256>, ReLU),
l4: (Downsample<256, 512>, ReLU, BasicBlock<512>, ReLU),
l1: (BasicBlock<64>, ops::ReLU, BasicBlock<64>, ops::ReLU),
l2: (Downsample<64, 128>, ops::ReLU, BasicBlock<128>, ops::ReLU),
l3: (Downsample<128, 256>, ops::ReLU, BasicBlock<256>, ops::ReLU),
l4: (Downsample<256, 512>, ops::ReLU, BasicBlock<512>, ops::ReLU),
l5: (AvgPoolGlobal, LinearConstConfig<512, NUM_CLASSES>),
}

Expand Down
4 changes: 2 additions & 2 deletions dfdx/examples/advanced-rl-dqn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ const ACTION: usize = 2;

// our simple 2 layer feedforward network with ReLU activations
type QNetwork = (
(LinearConstConfig<STATE, 32>, ReLU),
(LinearConstConfig<32, 32>, ReLU),
(LinearConstConfig<STATE, 32>, ops::ReLU),
(LinearConstConfig<32, 32>, ops::ReLU),
LinearConstConfig<32, ACTION>,
);

Expand Down
4 changes: 2 additions & 2 deletions dfdx/examples/advanced-rl-ppo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ const STATE: usize = 4;
const ACTION: usize = 2;

type PolicyNetwork = (
(LinearConstConfig<STATE, 32>, ReLU),
(LinearConstConfig<32, 32>, ReLU),
(LinearConstConfig<STATE, 32>, ops::ReLU),
(LinearConstConfig<32, 32>, ops::ReLU),
LinearConstConfig<32, ACTION>,
);

Expand Down
8 changes: 4 additions & 4 deletions dfdx/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,10 @@
//! struct MlpConfig {
//! // Linear with compile time input size & runtime known output size
//! linear1: LinearConfig<Const<784>, usize>,
//! act1: ReLU,
//! act1: ops::ReLU,
//! // Linear with runtime input & output size
//! linear2: LinearConfig<usize, usize>,
//! act2: Tanh,
//! act2: ops::Tanh,
//! // Linear with runtime input & compile time output size.
//! linear3: LinearConfig<usize, Const<10>>,
//! }
Expand All @@ -208,7 +208,7 @@
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
//! type Arch = (LinearConstConfig<3, 5>, ReLU, LinearConstConfig<5, 10>);
//! type Arch = (LinearConstConfig<3, 5>, ops::ReLU, LinearConstConfig<5, 10>);
//! let mut model = dev.build_module::<f32>(Arch::default());
//! let x: Tensor<(usize, Const<3>), f32, _> = dev.sample_uniform_like(&(100, Const));
//! let y = model.forward_mut(x);
Expand All @@ -233,7 +233,7 @@
//! ```rust
//! # use dfdx::prelude::*;
//! # let dev: Cpu = Default::default();
//! type Arch = (LinearConstConfig<3, 5>, ReLU, LinearConstConfig<5, 10>);
//! type Arch = (LinearConstConfig<3, 5>, ops::ReLU, LinearConstConfig<5, 10>);
//! let arch = Arch::default();
//! let mut model = dev.build_module::<f32>(arch);
//! // 1. allocate gradients for the model
Expand Down
2 changes: 1 addition & 1 deletion dfdx/src/nn/layers/add_into.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ mod tests {
// check if it works in a longer neural net
type Model = (
AddInto<(LinearConstConfig<5, 3>, LinearConstConfig<5, 3>)>,
ReLU,
ops::ReLU,
LinearConstConfig<3, 1>,
);
let mut model = dev.build_module::<TestDtype>(Model::default());
Expand Down
2 changes: 1 addition & 1 deletion dfdx/src/nn/layers/generalized_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::prelude::*;
/// # use dfdx::prelude::*;
/// # use dfdx::*;
/// # let dev: Cpu = Default::default();
/// type Model = GeneralizedAdd<ReLU, Square>;
/// type Model = GeneralizedAdd<ops::ReLU, ops::Square>;
/// let model = dev.build_module::<f32>(Model::default());
/// let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);
/// let y = model.forward(x);
Expand Down
2 changes: 1 addition & 1 deletion dfdx/src/nn/layers/generalized_mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::prelude::*;
/// # use dfdx::prelude::*;
/// # use dfdx::*;
/// # let dev: Cpu = Default::default();
/// type Model = GeneralizedMul<ReLU, Square>;
/// type Model = GeneralizedMul<ops::ReLU, ops::Square>;
/// let model = dev.build_module::<f32>(Model::default());
/// let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);
/// let y = model.forward(x);
Expand Down
47 changes: 1 addition & 46 deletions dfdx/src/nn/layers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
mod abs;
mod add_into;
mod batch_norm1d;
mod batch_norm2d;
Expand All @@ -10,48 +9,26 @@ mod conv1d;
mod conv2d;
#[cfg(feature = "nightly")]
mod conv_trans2d;
mod cos;
mod dropout;
mod embedding;
mod exp;
#[cfg(feature = "nightly")]
mod flatten2d;
mod gelu;
mod generalized_add;
mod generalized_mul;
mod layer_norm1d;
mod leaky_relu;
mod linear;
mod ln;
mod log_softmax;
mod matmul;
mod multi_head_attention;
#[cfg(feature = "nightly")]
mod pool_2d_avg;
#[cfg(feature = "nightly")]
mod pool_2d_max;
#[cfg(feature = "nightly")]
mod pool_2d_min;
pub mod ops;
mod pool_global_avg;
mod pool_global_max;
mod pool_global_min;
mod prelu;
mod prelu1d;
mod relu;
mod reshape;
mod residual_add;
mod residual_mul;
mod sigmoid;
mod sin;
mod softmax;
mod split_into;
mod sqrt;
mod square;
mod tanh;
mod transformer;
mod upscale2d;

pub use abs::Abs;
pub use add_into::AddInto;
pub use batch_norm1d::{BatchNorm1D, BatchNorm1DConfig, BatchNorm1DConstConfig};
pub use batch_norm2d::{BatchNorm2D, BatchNorm2DConfig, BatchNorm2DConstConfig};
Expand All @@ -63,44 +40,22 @@ pub use conv1d::{Conv1D, Conv1DConfig, Conv1DConstConfig};
pub use conv2d::{Conv2D, Conv2DConfig, Conv2DConstConfig};
#[cfg(feature = "nightly")]
pub use conv_trans2d::{ConvTrans2D, ConvTrans2DConfig, ConvTrans2DConstConfig};
pub use cos::Cos;
pub use dropout::{Dropout, DropoutOneIn};
pub use embedding::{Embedding, EmbeddingConfig, EmbeddingConstConfig};
pub use exp::Exp;
#[cfg(feature = "nightly")]
pub use flatten2d::Flatten2D;
pub use gelu::{AccurateGeLU, FastGeLU};
pub use generalized_add::GeneralizedAdd;
pub use generalized_mul::GeneralizedMul;
pub use layer_norm1d::{LayerNorm1D, LayerNorm1DConfig, LayerNorm1DConstConfig};
pub use leaky_relu::LeakyReLU;
pub use linear::{Linear, LinearConfig, LinearConstConfig};
pub use ln::Ln;
pub use log_softmax::LogSoftmax;
pub use matmul::{MatMul, MatMulConfig, MatMulConstConfig};
pub use multi_head_attention::{MultiHeadAttention, MultiHeadAttentionConfig};
#[cfg(feature = "nightly")]
pub use pool_2d_avg::{AvgPool2D, AvgPool2DConst};
#[cfg(feature = "nightly")]
pub use pool_2d_max::{MaxPool2D, MaxPool2DConst};
#[cfg(feature = "nightly")]
pub use pool_2d_min::{MinPool2D, MinPool2DConst};
pub use pool_global_avg::AvgPoolGlobal;
pub use pool_global_max::MaxPoolGlobal;
pub use pool_global_min::MinPoolGlobal;
pub use prelu::{PReLU, PReLUConfig};
pub use prelu1d::{PReLU1D, PReLU1DConfig};
pub use relu::ReLU;
pub use reshape::Reshape;
pub use residual_add::ResidualAdd;
pub use residual_mul::ResidualMul;
pub use sigmoid::Sigmoid;
pub use sin::Sin;
pub use softmax::Softmax;
pub use split_into::SplitInto;
pub use sqrt::Sqrt;
pub use square::Square;
pub use tanh::Tanh;
pub use transformer::{
DecoderBlock, DecoderBlockConfig, EncoderBlock, EncoderBlockConfig, Transformer,
TransformerConfig,
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::prelude::*;
/// # use dfdx::prelude::*;
/// # use dfdx::*;
/// # let dev: Cpu = Default::default();
/// let mut dropout: DropoutOneIn<2> = Default::default();
/// let mut dropout: ops::DropoutOneIn<2> = Default::default();
/// let grads = dropout.alloc_grads();
/// let x: Tensor<Rank2<2, 5>, f32, _> = dev.ones();
/// let r = dropout.forward_mut(x.trace(grads));
Expand Down Expand Up @@ -49,7 +49,7 @@ impl<const N: usize, S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Ten
/// ```rust
/// # use dfdx::prelude::*;
/// # let dev: Cpu = Default::default();
/// let mut dropout = Dropout { p: 0.5 };
/// let mut dropout = ops::Dropout { p: 0.5 };
/// let grads = dropout.alloc_grads();
/// let x: Tensor<Rank2<2, 5>, f32, _> = dev.ones();
/// let r = dropout.forward_mut(x.trace(grads));
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
52 changes: 52 additions & 0 deletions dfdx/src/nn/layers/ops/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//! Layers that mostly wraps the functionality of a [tensor_ops].
#[allow(unused_imports)]
use crate::tensor_ops;

pub mod abs;
pub mod cos;
pub mod dropout;
pub mod exp;
pub mod gelu;
pub mod leaky_relu;
pub mod ln;
pub mod log_softmax;
#[cfg(feature = "nightly")]
pub mod pool_2d_avg;
#[cfg(feature = "nightly")]
pub mod pool_2d_max;
#[cfg(feature = "nightly")]
pub mod pool_2d_min;
pub mod prelu;
pub mod prelu1d;
pub mod relu;
pub mod sigmoid;
pub mod sin;
pub mod softmax;
pub mod sqrt;
pub mod square;
pub mod tanh;

pub use abs::Abs;
pub use cos::Cos;
pub use dropout::{Dropout, DropoutOneIn};
pub use exp::Exp;
pub use gelu::{AccurateGeLU, FastGeLU};
pub use leaky_relu::LeakyReLU;
pub use ln::Ln;
pub use log_softmax::LogSoftmax;
#[cfg(feature = "nightly")]
pub use pool_2d_avg::{AvgPool2D, AvgPool2DConst};
#[cfg(feature = "nightly")]
pub use pool_2d_max::{MaxPool2D, MaxPool2DConst};
pub use pool_2d_min::{MinPool2D, MinPool2DConst};

Check failure on line 42 in dfdx/src/nn/layers/ops/mod.rs

View workflow job for this annotation

GitHub Actions / cargo-check

unresolved import `pool_2d_min`
#[cfg(feature = "nightly")]
pub use prelu::{PReLU, PReLUConfig};
pub use prelu1d::{PReLU1D, PReLU1DConfig};
pub use relu::ReLU;
pub use sigmoid::Sigmoid;
pub use sin::Sin;
pub use softmax::Softmax;
pub use sqrt::Sqrt;
pub use square::Square;
pub use tanh::Tanh;
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion dfdx/src/nn/layers/residual_add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::prelude::*;
/// # use dfdx::prelude::*;
/// # use dfdx::*;
/// # let dev: Cpu = Default::default();
/// type Model = ResidualAdd<ReLU>;
/// type Model = ResidualAdd<ops::ReLU>;
/// let model = dev.build_module::<f32>(Model::default());
/// let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);
/// let y = model.forward(x);
Expand Down
2 changes: 1 addition & 1 deletion dfdx/src/nn/layers/residual_mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::prelude::*;
/// # use dfdx::prelude::*;
/// # use dfdx::*;
/// # let dev: Cpu = Default::default();
/// type Model = ResidualMul<ReLU>;
/// type Model = ResidualMul<ops::ReLU>;
/// let model = dev.build_module::<f32>(Model::default());
/// let x = dev.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]);
/// let y = model.forward(x);
Expand Down
Loading

0 comments on commit cf811fb

Please sign in to comment.