// rlox_nn/tensor_data.rs

/// Backend-agnostic tensor data container.
///
/// This is the "lingua franca" across the trait boundary between
/// rlox-core (raw buffers/slices) and NN backends. Data is stored
/// as flat `f32` in row-major order.
///
/// The intended invariant is `data.len() == shape.iter().product()`.
/// The constructors below establish it. Both fields are public, so code
/// that mutates them directly is responsible for keeping them consistent.
/// An empty `shape` describes a rank-0 (scalar) tensor holding exactly
/// one element, because the product of no dimensions is 1.
#[derive(Debug, Clone, PartialEq)]
pub struct TensorData {
    /// Flat element buffer in row-major order.
    pub data: Vec<f32>,
    /// Extent of each dimension, outermost first.
    pub shape: Vec<usize>,
}

impl TensorData {
    /// Wraps an existing row-major buffer with the given shape.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `data.len()` differs from the product of
    /// `shape`, or if that product overflows `usize`. Release builds skip the
    /// check entirely, so the caller must guarantee consistency there.
    pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
        // Debug-only validation (same contract as `debug_assert_eq!`); the
        // expected length is computed once and reused in the message.
        if cfg!(debug_assertions) {
            let expected = Self::checked_numel(&shape);
            assert_eq!(
                data.len(),
                expected,
                "data length {} must match shape product {:?} = {}",
                data.len(),
                shape,
                expected
            );
        }
        Self { data, shape }
    }

    /// Creates a tensor of the given shape with every element set to `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if the product of `shape` overflows `usize`. This check runs in
    /// all build profiles.
    pub fn zeros(shape: Vec<usize>) -> Self {
        Self::filled(shape, 0.0)
    }

    /// Creates a tensor of the given shape with every element set to `1.0`.
    ///
    /// # Panics
    ///
    /// Panics if the product of `shape` overflows `usize`. This check runs in
    /// all build profiles.
    pub fn ones(shape: Vec<usize>) -> Self {
        Self::filled(shape, 1.0)
    }

    /// Converts an `f64` buffer to `f32` element-wise and wraps it with `shape`.
    ///
    /// The conversion uses an `as` cast. It rounds to the nearest `f32`, maps
    /// magnitudes beyond the `f32` range to infinity of the same sign, and
    /// keeps NaN as NaN.
    ///
    /// # Panics
    ///
    /// Same debug-only length check as [`TensorData::new`].
    pub fn from_f64(data: &[f64], shape: Vec<usize>) -> Self {
        Self::new(data.iter().map(|&x| x as f32).collect(), shape)
    }

    /// Number of elements stored in `data`. This reads the buffer, not the shape.
    pub fn numel(&self) -> usize {
        self.data.len()
    }

    /// Number of dimensions (rank); 0 for a scalar.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Returns `true` when `data` holds no elements. For tensors built by the
    /// constructors, this happens when some dimension of the shape is zero.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Builds a tensor of `shape` with every element set to `value`.
    fn filled(shape: Vec<usize>, value: f32) -> Self {
        let len = Self::checked_numel(&shape);
        Self {
            data: vec![value; len],
            shape,
        }
    }

    /// Product of `shape`, panicking instead of silently wrapping on overflow.
    fn checked_numel(shape: &[usize]) -> usize {
        // A zero-sized dimension empties the tensor however large the other
        // dimensions are. Short-circuit so that case can never overflow.
        if shape.contains(&0) {
            return 0;
        }
        shape
            .iter()
            .try_fold(1_usize, |acc, &dim| acc.checked_mul(dim))
            .unwrap_or_else(|| panic!("shape {:?} overflows usize element count", shape))
    }
}
57
#[cfg(test)]
mod tests {
    //! Unit tests for `TensorData` constructors and accessors.

    use super::*;

    #[test]
    fn test_zeros() {
        let t = TensorData::zeros(vec![2, 3]);
        assert_eq!(t.data.len(), 6);
        assert!(t.data.iter().all(|&x| x == 0.0));
        assert_eq!(t.shape, vec![2, 3]);
    }

    #[test]
    fn test_ones() {
        let t = TensorData::ones(vec![4]);
        assert_eq!(t.data.len(), 4);
        assert!(t.data.iter().all(|&x| x == 1.0));
        assert_eq!(t.shape, vec![4]);
    }

    #[test]
    fn test_from_f64() {
        let vals = vec![1.0_f64, 2.0, 3.0];
        let t = TensorData::from_f64(&vals, vec![3]);
        assert_eq!(t.data, vec![1.0_f32, 2.0, 3.0]);
        assert_eq!(t.shape, vec![3]);
    }

    #[test]
    fn test_new_preserves_data_and_shape() {
        let t = TensorData::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        assert_eq!(t.data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(t.shape, vec![2, 3]);
        assert_eq!(t.numel(), 6);
    }

    #[test]
    fn test_numel_ndim() {
        let t = TensorData::zeros(vec![2, 3, 4]);
        assert_eq!(t.numel(), 24);
        assert_eq!(t.ndim(), 3);
        assert!(!t.is_empty());
    }

    #[test]
    fn test_empty() {
        let t = TensorData::zeros(vec![0]);
        assert!(t.is_empty());
        assert_eq!(t.ndim(), 1);
    }

    #[test]
    fn test_zero_dim_among_others_is_empty() {
        let t = TensorData::zeros(vec![3, 0, 5]);
        assert!(t.is_empty());
        assert_eq!(t.ndim(), 3);
    }

    #[test]
    fn test_scalar_shape_has_one_element() {
        // The product over an empty shape is 1, so a rank-0 tensor holds one value.
        let t = TensorData::ones(vec![]);
        assert_eq!(t.data, vec![1.0]);
        assert_eq!(t.ndim(), 0);
        assert!(!t.is_empty());
    }

    // `TensorData::new` validates the length only under `debug_assertions`,
    // so this test would fail under `cargo test --release`. Compile it only
    // where the check actually exists.
    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "data length")]
    fn test_shape_mismatch_panics_in_debug() {
        let _ = TensorData::new(vec![1.0, 2.0], vec![3]);
    }
}