1use candle_core::{Device, Tensor};
2use rlox_nn::TensorData;
3
4pub fn to_tensor_1d(data: &TensorData, device: &Device) -> candle_core::Result<Tensor> {
5 Tensor::from_vec(data.data.clone(), data.data.len(), device)
6}
7
8pub fn to_tensor_2d(data: &TensorData, device: &Device) -> candle_core::Result<Tensor> {
9 assert_eq!(data.shape.len(), 2);
10 Tensor::from_vec(data.data.clone(), (data.shape[0], data.shape[1]), device)
11}
12
13pub fn from_tensor_1d(tensor: &Tensor) -> candle_core::Result<TensorData> {
14 let data: Vec<f32> = tensor.to_vec1()?;
15 let len = data.len();
16 Ok(TensorData::new(data, vec![len]))
17}
18
19pub fn from_tensor_2d(tensor: &Tensor) -> candle_core::Result<TensorData> {
20 let dims = tensor.dims();
21 let data: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
22 Ok(TensorData::new(data, vec![dims[0], dims[1]]))
23}
24
25pub fn to_int_tensor_1d(data: &TensorData, device: &Device) -> candle_core::Result<Tensor> {
26 let ints: Vec<u32> = data.data.iter().map(|&x| x as u32).collect();
27 let len = ints.len();
28 Tensor::from_vec(ints, len, device)
29}
30
31fn candle_err(e: candle_core::Error) -> rlox_nn::NNError {
32 rlox_nn::NNError::Backend(e.to_string())
33}
34
35pub trait IntoNNError<T> {
36 fn nn_err(self) -> Result<T, rlox_nn::NNError>;
37}
38
39impl<T> IntoNNError<T> for candle_core::Result<T> {
40 fn nn_err(self) -> Result<T, rlox_nn::NNError> {
41 self.map_err(candle_err)
42 }
43}
44
#[cfg(test)]
mod tests {
    use super::*;

    /// A 1-D buffer survives a round trip through a candle tensor unchanged.
    #[test]
    fn test_roundtrip_1d() {
        let device = Device::Cpu;
        let input = TensorData::new(vec![1.0, 2.0, 3.0], vec![3]);

        let round_tripped = to_tensor_1d(&input, &device)
            .and_then(|t| from_tensor_1d(&t))
            .unwrap();

        assert_eq!(round_tripped.data, input.data);
    }

    /// A 2x2 buffer survives a round trip with both its values and its shape intact.
    #[test]
    fn test_roundtrip_2d() {
        let device = Device::Cpu;
        let input = TensorData::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

        let round_tripped = to_tensor_2d(&input, &device)
            .and_then(|t| from_tensor_2d(&t))
            .unwrap();

        assert_eq!(round_tripped.data, input.data);
        assert_eq!(round_tripped.shape, vec![2, 2]);
    }
}