/// An `f32` tensor stored as a flat element buffer together with its shape.
///
/// Invariant: `data.len()` equals the product of `shape`. An empty `shape`
/// therefore describes a single-element scalar. Both fields are public, so code
/// that mutates them directly is responsible for keeping the invariant intact.
#[derive(Debug, Clone, PartialEq)]
pub struct TensorData {
    /// Element values; one entry per position described by `shape`.
    pub data: Vec<f32>,
    /// Size of each dimension.
    pub shape: Vec<usize>,
}

impl TensorData {
    /// Creates a tensor from a flat data buffer and its shape.
    ///
    /// # Panics
    ///
    /// When debug assertions are enabled, panics if `data.len()` differs from
    /// the product of `shape`, or if that product overflows `usize`. The check
    /// is skipped otherwise, so release-mode callers must uphold the invariant.
    pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
        // Same gating as `debug_assert_eq!`, but the expected count is computed
        // once and reused in the failure message.
        if cfg!(debug_assertions) {
            let expected = Self::element_count(&shape);
            assert_eq!(
                data.len(),
                expected,
                "data length {} must match shape product {:?} = {}",
                data.len(),
                shape,
                expected
            );
        }
        Self { data, shape }
    }

    /// Creates a tensor of the given shape with every element set to `value`.
    ///
    /// # Panics
    ///
    /// Panics if the product of `shape` overflows `usize`.
    pub fn full(shape: Vec<usize>, value: f32) -> Self {
        let len = Self::element_count(&shape);
        Self {
            data: vec![value; len],
            shape,
        }
    }

    /// Creates a tensor of the given shape filled with `0.0`.
    ///
    /// # Panics
    ///
    /// Panics if the product of `shape` overflows `usize`.
    pub fn zeros(shape: Vec<usize>) -> Self {
        Self::full(shape, 0.0)
    }

    /// Creates a tensor of the given shape filled with `1.0`.
    ///
    /// # Panics
    ///
    /// Panics if the product of `shape` overflows `usize`.
    pub fn ones(shape: Vec<usize>) -> Self {
        Self::full(shape, 1.0)
    }

    /// Creates a tensor by narrowing each `f64` in `data` to `f32`.
    ///
    /// The narrowing is an `as` cast. Each value becomes the nearest
    /// representable `f32`, and magnitudes beyond the `f32` range become
    /// infinities of the same sign.
    ///
    /// # Panics
    ///
    /// Same conditions as [`TensorData::new`].
    pub fn from_f64(data: &[f64], shape: Vec<usize>) -> Self {
        Self::new(data.iter().map(|&x| x as f32).collect(), shape)
    }

    /// Returns the number of stored elements, i.e. `data.len()`.
    pub fn numel(&self) -> usize {
        self.data.len()
    }

    /// Returns the number of dimensions, i.e. `shape.len()` (0 for a scalar).
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Returns `true` if the tensor holds no elements. When the invariant
    /// holds, this means at least one dimension has size 0.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Returns the element count implied by `shape` (the product of its
    /// dimensions; 1 for an empty shape).
    ///
    /// # Panics
    ///
    /// Panics if the product overflows `usize`. An unchecked product would
    /// wrap in release builds and yield a buffer whose length disagrees with
    /// `shape`.
    fn element_count(shape: &[usize]) -> usize {
        // Any zero-sized dimension makes the tensor empty no matter how large
        // the others are, so answer before a multiplication can overflow.
        if shape.contains(&0) {
            return 0;
        }
        shape
            .iter()
            .try_fold(1_usize, |acc, &dim| acc.checked_mul(dim))
            .unwrap_or_else(|| panic!("shape {:?} has more elements than fit in usize", shape))
    }
}
57
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zeros() {
        let t = TensorData::zeros(vec![2, 3]);
        assert_eq!(t.data.len(), 6);
        assert!(t.data.iter().all(|&x| x == 0.0));
        assert_eq!(t.shape, vec![2, 3]);
    }

    #[test]
    fn test_ones() {
        let t = TensorData::ones(vec![4]);
        assert_eq!(t.data.len(), 4);
        assert!(t.data.iter().all(|&x| x == 1.0));
        assert_eq!(t.shape, vec![4]);
    }

    #[test]
    fn test_from_f64() {
        let vals = vec![1.0_f64, 2.0, 3.0];
        let t = TensorData::from_f64(&vals, vec![3]);
        assert_eq!(t.data, vec![1.0_f32, 2.0, 3.0]);
        assert_eq!(t.shape, vec![3]);
    }

    #[test]
    fn test_new_keeps_data_and_shape() {
        let t = TensorData::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        assert_eq!(t.data, vec![1.0_f32, 2.0, 3.0, 4.0]);
        assert_eq!(t.shape, vec![2, 2]);
    }

    #[test]
    fn test_numel_ndim() {
        let t = TensorData::zeros(vec![2, 3, 4]);
        assert_eq!(t.numel(), 24);
        assert_eq!(t.ndim(), 3);
        assert!(!t.is_empty());
    }

    #[test]
    fn test_scalar_shape() {
        // An empty shape has product 1, i.e. a single-element scalar.
        let t = TensorData::zeros(vec![]);
        assert_eq!(t.numel(), 1);
        assert_eq!(t.ndim(), 0);
        assert!(!t.is_empty());
    }

    #[test]
    fn test_empty() {
        let t = TensorData::zeros(vec![0]);
        assert!(t.is_empty());
        assert_eq!(t.numel(), 0);
        assert_eq!(t.ndim(), 1);
    }

    // The length/shape check in `new` only runs when debug assertions are
    // enabled; under `--release` no panic occurs, so compile this test out there.
    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "data length")]
    fn test_shape_mismatch_panics_in_debug() {
        TensorData::new(vec![1.0, 2.0], vec![3]);
    }
}