1use candle_core::{Result, Tensor};
2use candle_nn::{linear, Linear, Module, VarBuilder};
3use rlox_nn::{Activation, MLPConfig};
4
/// A multi-layer perceptron: a stack of [`Linear`] layers with configurable
/// hidden and output activations, constructed from an [`MLPConfig`].
pub struct MLP {
    // Linear layers in forward order: hidden layers first, output layer last.
    layers: Vec<Linear>,
    // Activation applied after every non-final (hidden) layer.
    activation: Activation,
    // Optional activation applied after the final layer; `None` leaves raw logits.
    output_activation: Option<Activation>,
}
10
11fn apply_activation(x: &Tensor, act: Activation) -> Result<Tensor> {
12 match act {
13 Activation::ReLU => x.relu(),
14 Activation::Tanh => x.tanh(),
15 }
16}
17
18impl MLP {
19 pub fn new(config: &MLPConfig, vb: VarBuilder) -> Result<Self> {
20 let mut dims = Vec::new();
21 dims.push(config.input_dim);
22 dims.extend_from_slice(&config.hidden_dims);
23 dims.push(config.output_dim);
24
25 let layers: Vec<Linear> = dims
26 .windows(2)
27 .enumerate()
28 .map(|(i, w)| linear(w[0], w[1], vb.pp(format!("layer_{i}"))))
29 .collect::<Result<_>>()?;
30
31 Ok(Self {
32 layers,
33 activation: config.activation,
34 output_activation: config.output_activation,
35 })
36 }
37
38 pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
39 let n = self.layers.len();
40 let mut x = input.clone();
41 for (i, layer) in self.layers.iter().enumerate() {
42 x = layer.forward(&x)?;
43 if i < n - 1 {
44 x = apply_activation(&x, self.activation)?;
45 } else if let Some(out_act) = self.output_activation {
46 x = apply_activation(&x, out_act)?;
47 }
48 }
49 Ok(x)
50 }
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;
    use candle_nn::VarMap;

    // Two hidden layers: output shape must be (batch, output_dim).
    #[test]
    fn test_mlp_forward_shape() {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &Device::Cpu);
        let config = MLPConfig::new(4, 2).with_hidden(vec![64, 64]);
        let mlp = MLP::new(&config, vb).unwrap();

        let input = Tensor::zeros((8, 4), candle_core::DType::F32, &Device::Cpu).unwrap();
        let output = mlp.forward(&input).unwrap();
        assert_eq!(output.dims(), &[8, 2]);
    }

    // Single hidden layer still produces the configured output width.
    #[test]
    fn test_mlp_single_hidden() {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &Device::Cpu);
        let config = MLPConfig::new(3, 1).with_hidden(vec![16]);
        let mlp = MLP::new(&config, vb).unwrap();

        let input = Tensor::zeros((4, 3), candle_core::DType::F32, &Device::Cpu).unwrap();
        let output = mlp.forward(&input).unwrap();
        assert_eq!(output.dims(), &[4, 1]);
    }

    // A tanh output activation must bound every output in [-1, 1], even for
    // large-magnitude inputs that would otherwise saturate the linear layers.
    #[test]
    fn test_mlp_tanh_output() {
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, candle_core::DType::F32, &Device::Cpu);
        let config = MLPConfig::new(4, 2)
            .with_hidden(vec![32])
            .with_output_activation(Activation::Tanh);
        let mlp = MLP::new(&config, vb).unwrap();

        let input = Tensor::ones((4, 4), candle_core::DType::F32, &Device::Cpu).unwrap();
        let input = (&input * 100.0).unwrap();
        let output = mlp.forward(&input).unwrap();
        let data: Vec<f32> = output.flatten_all().unwrap().to_vec1().unwrap();
        for &v in &data {
            // Range-contains is the idiomatic form (clippy::manual_range_contains).
            assert!((-1.0..=1.0).contains(&v), "tanh output out of range: {v}");
        }
    }
}