rlox_candle/
convert.rs

1use candle_core::{Device, Tensor};
2use rlox_nn::TensorData;
3
4pub fn to_tensor_1d(data: &TensorData, device: &Device) -> candle_core::Result<Tensor> {
5    Tensor::from_vec(data.data.clone(), data.data.len(), device)
6}
7
8pub fn to_tensor_2d(data: &TensorData, device: &Device) -> candle_core::Result<Tensor> {
9    assert_eq!(data.shape.len(), 2);
10    Tensor::from_vec(data.data.clone(), (data.shape[0], data.shape[1]), device)
11}
12
13pub fn from_tensor_1d(tensor: &Tensor) -> candle_core::Result<TensorData> {
14    let data: Vec<f32> = tensor.to_vec1()?;
15    let len = data.len();
16    Ok(TensorData::new(data, vec![len]))
17}
18
19pub fn from_tensor_2d(tensor: &Tensor) -> candle_core::Result<TensorData> {
20    let dims = tensor.dims();
21    let data: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
22    Ok(TensorData::new(data, vec![dims[0], dims[1]]))
23}
24
25pub fn to_int_tensor_1d(data: &TensorData, device: &Device) -> candle_core::Result<Tensor> {
26    let ints: Vec<u32> = data.data.iter().map(|&x| x as u32).collect();
27    let len = ints.len();
28    Tensor::from_vec(ints, len, device)
29}
30
31fn candle_err(e: candle_core::Error) -> rlox_nn::NNError {
32    rlox_nn::NNError::Backend(e.to_string())
33}
34
35pub trait IntoNNError<T> {
36    fn nn_err(self) -> Result<T, rlox_nn::NNError>;
37}
38
39impl<T> IntoNNError<T> for candle_core::Result<T> {
40    fn nn_err(self) -> Result<T, rlox_nn::NNError> {
41        self.map_err(candle_err)
42    }
43}
44
#[cfg(test)]
mod tests {
    use super::*;

    /// A 1-D roundtrip must preserve both the values and the `[len]` shape.
    #[test]
    fn test_roundtrip_1d() {
        let original = TensorData::new(vec![1.0, 2.0, 3.0], vec![3]);
        let tensor = to_tensor_1d(&original, &Device::Cpu).unwrap();
        let result = from_tensor_1d(&tensor).unwrap();
        assert_eq!(result.data, original.data);
        assert_eq!(result.shape, vec![3]);
    }

    /// A square 2-D roundtrip preserves values and shape.
    #[test]
    fn test_roundtrip_2d() {
        let original = TensorData::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        let tensor = to_tensor_2d(&original, &Device::Cpu).unwrap();
        let result = from_tensor_2d(&tensor).unwrap();
        assert_eq!(result.data, original.data);
        assert_eq!(result.shape, vec![2, 2]);
    }

    /// A non-square matrix catches rows/cols being swapped, which a square one cannot.
    #[test]
    fn test_roundtrip_2d_non_square() {
        let original = TensorData::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let tensor = to_tensor_2d(&original, &Device::Cpu).unwrap();
        assert_eq!(tensor.dims(), &[2, 3]);
        let result = from_tensor_2d(&tensor).unwrap();
        assert_eq!(result.data, original.data);
        assert_eq!(result.shape, vec![2, 3]);
    }

    /// Whole-number float values convert to the same `u32` values in a `U32` tensor.
    #[test]
    fn test_int_tensor_1d() {
        let original = TensorData::new(vec![0.0, 1.0, 7.0], vec![3]);
        let tensor = to_int_tensor_1d(&original, &Device::Cpu).unwrap();
        assert_eq!(tensor.dtype(), candle_core::DType::U32);
        assert_eq!(tensor.to_vec1::<u32>().unwrap(), vec![0, 1, 7]);
    }

    /// `nn_err` maps a candle error to `NNError::Backend` carrying its message.
    #[test]
    fn test_nn_err_maps_to_backend() {
        let failed: candle_core::Result<()> = Err(candle_core::Error::Msg("boom".to_string()));
        if let Err(rlox_nn::NNError::Backend(msg)) = failed.nn_err() {
            assert!(msg.contains("boom"));
        } else {
            panic!("expected NNError::Backend");
        }
    }
}