1pub mod columnar;
23pub mod concurrent;
24pub mod episode;
25pub mod extra_columns;
26#[cfg(feature = "gpu")]
27pub mod flat;
28pub mod her;
29pub mod mixed;
30pub mod mmap;
31pub mod offline;
32pub mod priority;
33pub mod provenance;
34pub mod ringbuf;
35pub mod sequence;
36pub mod varlen;
37
/// One experience transition, as handed to the buffers declared in this module.
///
/// The struct is plain data and enforces no lengths itself. The tests in this
/// file show `ExperienceTable::push` returning an error when `action.len()`
/// differs from the table's `act_dim`, so dimension checks happen at the
/// container level.
///
/// `PartialEq` compares field by field; because the payload is `f32`, a record
/// containing `NaN` is never equal to anything, including itself.
#[derive(Debug, Clone, PartialEq)]
pub struct ExperienceRecord {
    /// Observation values for this step (the tests fill it with `obs_dim` entries).
    pub obs: Vec<f32>,
    /// Observation values for the following step; same length as `obs` in the tests.
    pub next_obs: Vec<f32>,
    /// Action values; a `Vec` so that multi-dimensional actions can be stored.
    pub action: Vec<f32>,
    /// Scalar reward attached to this transition.
    pub reward: f32,
    /// Episode-ended flag.
    /// NOTE(review): presumably "environment reached a terminal state" — confirm.
    pub terminated: bool,
    /// Episode-cut-short flag.
    /// NOTE(review): presumably "stopped by a time/step limit" — confirm.
    pub truncated: bool,
}
49
/// Test fixture: a non-terminal, non-truncated record with a one-element action.
///
/// `obs` holds `obs_dim` copies of `1.0`, `next_obs` holds `obs_dim` copies of
/// `2.0`, the action is `[0.0]` and the reward is `1.0`.
#[cfg(test)]
pub(crate) fn sample_record(obs_dim: usize) -> ExperienceRecord {
    let obs = vec![1.0; obs_dim];
    let next_obs = vec![2.0; obs_dim];
    let action = vec![0.0];

    ExperienceRecord {
        obs,
        next_obs,
        action,
        reward: 1.0,
        terminated: false,
        truncated: false,
    }
}
61
/// Test fixture: like a basic sample record, but with a caller-chosen action width.
///
/// `obs` / `next_obs` hold `obs_dim` copies of `1.0` / `2.0`, the action holds
/// `act_dim` zeros, the reward is `1.0`, and both episode flags are `false`.
#[cfg(test)]
pub(crate) fn sample_record_multidim(obs_dim: usize, act_dim: usize) -> ExperienceRecord {
    // Small local shorthand for "a vector of `len` copies of `value`".
    let filled = |value: f32, len: usize| -> Vec<f32> { vec![value; len] };

    ExperienceRecord {
        obs: filled(1.0, obs_dim),
        next_obs: filled(2.0, obs_dim),
        action: filled(0.0, act_dim),
        reward: 1.0,
        terminated: false,
        truncated: false,
    }
}
73
// Regression tests for vector-valued actions: they check that
// `ExperienceRecord`, `ExperienceTable` and `ReplayBuffer` carry an action of
// arbitrary width, and that a width mismatch is reported as an error.
#[cfg(test)]
mod fix_verification_tests {
    use super::*;
    use crate::buffer::columnar::ExperienceTable;
    use crate::buffer::ringbuf::ReplayBuffer;

    /// Builds a non-terminal, non-truncated record whose `obs` / `next_obs`
    /// are `obs_dim` copies of the given fill values.
    fn transition(
        obs_dim: usize,
        obs_fill: f32,
        next_obs_fill: f32,
        action: Vec<f32>,
        reward: f32,
    ) -> ExperienceRecord {
        ExperienceRecord {
            obs: vec![obs_fill; obs_dim],
            next_obs: vec![next_obs_fill; obs_dim],
            action,
            reward,
            terminated: false,
            truncated: false,
        }
    }

    /// The record type itself accepts an action with more than one component.
    #[test]
    fn experience_record_action_is_vec() {
        let record = transition(17, 0.0, 0.0, vec![0.1, -0.2, 0.3, -0.4, 0.5, -0.6], 1.0);
        assert_eq!(record.action.len(), 6);
        assert_eq!(record.obs.len(), 17);
    }

    /// After one push, the table's raw action storage holds exactly the pushed action.
    #[test]
    fn experience_table_stores_multi_dim_action() {
        let (obs_dim, act_dim) = (17, 6);
        let expected = [0.1f32, -0.2, 0.3, -0.4, 0.5, -0.6];

        let mut table = ExperienceTable::new(obs_dim, act_dim);
        table
            .push(transition(obs_dim, 1.0, 2.0, expected.to_vec(), 5.0))
            .unwrap();

        assert_eq!(table.actions_raw().len(), act_dim);
        assert_eq!(&table.actions_raw()[..act_dim], &expected[..]);
    }

    /// A single pushed record comes back unchanged when sampled from the ring buffer.
    #[test]
    fn replay_buffer_multi_dim_action_roundtrip() {
        let (obs_dim, act_dim) = (4, 3);
        let expected = [0.5f32, -0.5, 1.0];

        let mut buf = ReplayBuffer::new(100, obs_dim, act_dim);
        buf.push(transition(obs_dim, 0.1, 0.2, expected.to_vec(), 1.0))
            .unwrap();

        let batch = buf.sample(1, 42).unwrap();
        assert_eq!(batch.act_dim, act_dim);
        assert_eq!(batch.actions.len(), act_dim);
        assert_eq!(&batch.actions[..act_dim], &expected[..]);
    }

    /// Pushing a 3-component action into a table built with `act_dim == 2`
    /// must fail, and the error text must name `act_dim`.
    #[test]
    fn experience_table_action_dim_mismatch_returns_error() {
        let mut table = ExperienceTable::new(4, 2);
        let result = table.push(transition(4, 1.0, 2.0, vec![0.1, 0.2, 0.3], 1.0));

        assert!(result.is_err(), "action dim mismatch must return Err");
        let err_str = result.unwrap_err().to_string();
        assert!(
            err_str.contains("act_dim"),
            "error must mention act_dim, got: {err_str}"
        );
    }

    /// The scalar-action case (`act_dim == 1`) still works.
    #[test]
    fn experience_table_scalar_action_dim_one() {
        let mut table = ExperienceTable::new(4, 1);
        table
            .push(transition(4, 1.0, 2.0, vec![0.0], 1.0))
            .unwrap();

        assert_eq!(table.len(), 1);
        assert_eq!(table.actions_raw().len(), 1);
    }
}