// rlox_candle/mlp.rs

use candle_core::{Result, Tensor};
use candle_nn::{linear, Linear, Module, VarBuilder};
use rlox_nn::{Activation, MLPConfig};

/// A multi-layer perceptron: a chain of linear layers with a fixed
/// hidden-layer activation and an optional output activation.
pub struct MLP {
    // One linear layer per consecutive dimension pair
    // (input -> hidden_0 -> ... -> output); built in `MLP::new`.
    layers: Vec<Linear>,
    // Activation applied after every layer except the last.
    activation: Activation,
    // Activation applied after the final layer; `None` leaves the
    // output linear (raw logits/values).
    output_activation: Option<Activation>,
}

11fn apply_activation(x: &Tensor, act: Activation) -> Result<Tensor> {
12    match act {
13        Activation::ReLU => x.relu(),
14        Activation::Tanh => x.tanh(),
15    }
16}
17
18impl MLP {
19    pub fn new(config: &MLPConfig, vb: VarBuilder) -> Result<Self> {
20        let mut dims = Vec::new();
21        dims.push(config.input_dim);
22        dims.extend_from_slice(&config.hidden_dims);
23        dims.push(config.output_dim);
24
25        let layers: Vec<Linear> = dims
26            .windows(2)
27            .enumerate()
28            .map(|(i, w)| linear(w[0], w[1], vb.pp(format!("layer_{i}"))))
29            .collect::<Result<_>>()?;
30
31        Ok(Self {
32            layers,
33            activation: config.activation,
34            output_activation: config.output_activation,
35        })
36    }
37
38    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
39        let n = self.layers.len();
40        let mut x = input.clone();
41        for (i, layer) in self.layers.iter().enumerate() {
42            x = layer.forward(&x)?;
43            if i < n - 1 {
44                x = apply_activation(&x, self.activation)?;
45            } else if let Some(out_act) = self.output_activation {
46                x = apply_activation(&x, out_act)?;
47            }
48        }
49        Ok(x)
50    }
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;
    use candle_nn::VarMap;

    #[test]
    fn test_mlp_forward_shape() {
        // Two hidden layers; batch of 8 rows of 4 features -> 2 outputs.
        let varmap = VarMap::new();
        let builder = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &Device::Cpu);
        let cfg = MLPConfig::new(4, 2).with_hidden(vec![64, 64]);
        let net = MLP::new(&cfg, builder).unwrap();

        let batch = Tensor::zeros((8, 4), candle_core::DType::F32, &Device::Cpu).unwrap();
        let out = net.forward(&batch).unwrap();
        assert_eq!(out.dims(), &[8, 2]);
    }

    #[test]
    fn test_mlp_single_hidden() {
        // Minimal topology: one hidden layer, scalar output per row.
        let varmap = VarMap::new();
        let builder = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &Device::Cpu);
        let cfg = MLPConfig::new(3, 1).with_hidden(vec![16]);
        let net = MLP::new(&cfg, builder).unwrap();

        let batch = Tensor::zeros((4, 3), candle_core::DType::F32, &Device::Cpu).unwrap();
        let out = net.forward(&batch).unwrap();
        assert_eq!(out.dims(), &[4, 1]);
    }

    #[test]
    fn test_mlp_tanh_output() {
        // Large-magnitude inputs should still land in [-1, 1] when the
        // output activation is tanh.
        let varmap = VarMap::new();
        let builder = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &Device::Cpu);
        let cfg = MLPConfig::new(4, 2)
            .with_hidden(vec![32])
            .with_output_activation(Activation::Tanh);
        let net = MLP::new(&cfg, builder).unwrap();

        let ones = Tensor::ones((4, 4), candle_core::DType::F32, &Device::Cpu).unwrap();
        let scaled = (&ones * 100.0).unwrap();
        let out = net.forward(&scaled).unwrap();
        let values: Vec<f32> = out.flatten_all().unwrap().to_vec1().unwrap();
        for &v in &values {
            assert!(v >= -1.0 && v <= 1.0, "tanh output out of range: {v}");
        }
    }
}