1use std::time::{SystemTime, UNIX_EPOCH};
2
3use serde::{Deserialize, Serialize};
4
5use crate::error::RloxError;
6
7#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
13pub struct TransitionMeta {
14 pub env_id: u32,
15 pub episode_id: u64,
16 pub step_in_episode: u32,
17 pub global_step: u64,
18 pub policy_version: u64,
19 pub reward_model_version: u64,
20 pub timestamp_ns: u64,
21}
22
23const SERIALIZED_SIZE: usize = 48;
24
25impl TransitionMeta {
26 pub fn new(
29 env_id: u32,
30 episode_id: u64,
31 step_in_episode: u32,
32 global_step: u64,
33 policy_version: u64,
34 ) -> Self {
35 let timestamp_ns = SystemTime::now()
36 .duration_since(UNIX_EPOCH)
37 .expect("system clock before UNIX epoch")
38 .as_nanos() as u64;
39 Self {
40 env_id,
41 episode_id,
42 step_in_episode,
43 global_step,
44 policy_version,
45 reward_model_version: 0,
46 timestamp_ns,
47 }
48 }
49
50 pub fn serialize(&self) -> Vec<u8> {
52 let mut buf = Vec::with_capacity(SERIALIZED_SIZE);
53 buf.extend_from_slice(&self.env_id.to_le_bytes());
54 buf.extend_from_slice(&self.episode_id.to_le_bytes());
55 buf.extend_from_slice(&self.step_in_episode.to_le_bytes());
56 buf.extend_from_slice(&self.global_step.to_le_bytes());
57 buf.extend_from_slice(&self.policy_version.to_le_bytes());
58 buf.extend_from_slice(&self.reward_model_version.to_le_bytes());
59 buf.extend_from_slice(&self.timestamp_ns.to_le_bytes());
60 buf
61 }
62
63 pub fn deserialize(bytes: &[u8]) -> Result<Self, RloxError> {
65 if bytes.len() != SERIALIZED_SIZE {
66 return Err(RloxError::BufferError(format!(
67 "TransitionMeta requires exactly {SERIALIZED_SIZE} bytes, got {}",
68 bytes.len()
69 )));
70 }
71
72 let env_id = u32::from_le_bytes(bytes[0..4].try_into().unwrap());
73 let episode_id = u64::from_le_bytes(bytes[4..12].try_into().unwrap());
74 let step_in_episode = u32::from_le_bytes(bytes[12..16].try_into().unwrap());
75 let global_step = u64::from_le_bytes(bytes[16..24].try_into().unwrap());
76 let policy_version = u64::from_le_bytes(bytes[24..32].try_into().unwrap());
77 let reward_model_version = u64::from_le_bytes(bytes[32..40].try_into().unwrap());
78 let timestamp_ns = u64::from_le_bytes(bytes[40..48].try_into().unwrap());
79
80 Ok(Self {
81 env_id,
82 episode_id,
83 step_in_episode,
84 global_step,
85 policy_version,
86 reward_model_version,
87 timestamp_ns,
88 })
89 }
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// Metadata with a distinct non-zero value in every field.
    fn sample_meta() -> TransitionMeta {
        TransitionMeta {
            env_id: 42,
            episode_id: 1000,
            step_in_episode: 7,
            global_step: 50000,
            policy_version: 3,
            reward_model_version: 1,
            timestamp_ns: 1_700_000_000_000_000_000,
        }
    }

    /// Metadata whose little-endian encoding is exactly the bytes `1..=48`.
    ///
    /// Round-trip tests cannot detect a change to field order, width or
    /// endianness that is applied to both `serialize` and `deserialize`.
    /// Comparing against this fixed golden buffer can.
    fn golden_meta() -> TransitionMeta {
        TransitionMeta {
            env_id: 0x0403_0201,
            episode_id: 0x0c0b_0a09_0807_0605,
            step_in_episode: 0x100f_0e0d,
            global_step: 0x1817_1615_1413_1211,
            policy_version: 0x201f_1e1d_1c1b_1a19,
            reward_model_version: 0x2827_2625_2423_2221,
            timestamp_ns: 0x302f_2e2d_2c2b_2a29,
        }
    }

    /// The expected wire bytes for [`golden_meta`].
    fn golden_bytes() -> Vec<u8> {
        (1..=48).collect()
    }

    #[test]
    fn serialize_is_48_bytes() {
        let meta = sample_meta();
        let bytes = meta.serialize();
        assert_eq!(bytes.len(), 48);
    }

    #[test]
    fn roundtrip() {
        let meta = sample_meta();
        let bytes = meta.serialize();
        let restored = TransitionMeta::deserialize(&bytes).unwrap();
        assert_eq!(meta, restored);
    }

    #[test]
    fn serialize_matches_golden_layout() {
        assert_eq!(golden_meta().serialize(), golden_bytes());
    }

    #[test]
    fn deserialize_matches_golden_layout() {
        let restored = TransitionMeta::deserialize(&golden_bytes()).unwrap();
        assert_eq!(restored, golden_meta());
    }

    #[test]
    fn deserialize_wrong_length_errors() {
        let result = TransitionMeta::deserialize(&[0u8; 47]);
        assert!(result.is_err());
        let result = TransitionMeta::deserialize(&[0u8; 49]);
        assert!(result.is_err());
    }

    #[test]
    fn roundtrip_zeros() {
        let meta = TransitionMeta {
            env_id: 0,
            episode_id: 0,
            step_in_episode: 0,
            global_step: 0,
            policy_version: 0,
            reward_model_version: 0,
            timestamp_ns: 0,
        };
        let bytes = meta.serialize();
        let restored = TransitionMeta::deserialize(&bytes).unwrap();
        assert_eq!(meta, restored);
    }

    #[test]
    fn roundtrip_max_values() {
        let meta = TransitionMeta {
            env_id: u32::MAX,
            episode_id: u64::MAX,
            step_in_episode: u32::MAX,
            global_step: u64::MAX,
            policy_version: u64::MAX,
            reward_model_version: u64::MAX,
            timestamp_ns: u64::MAX,
        };
        let bytes = meta.serialize();
        let restored = TransitionMeta::deserialize(&bytes).unwrap();
        assert_eq!(meta, restored);
    }

    #[test]
    fn test_transition_meta_roundtrip() {
        let meta = TransitionMeta::new(7, 100, 5, 99999, 2);
        let bytes = meta.serialize();
        let restored = TransitionMeta::deserialize(&bytes).unwrap();
        assert_eq!(meta, restored);
    }

    #[test]
    fn test_transition_meta_new_sets_fields() {
        let meta = TransitionMeta::new(7, 100, 5, 99999, 2);
        assert_eq!(meta.env_id, 7);
        assert_eq!(meta.episode_id, 100);
        assert_eq!(meta.step_in_episode, 5);
        assert_eq!(meta.global_step, 99999);
        assert_eq!(meta.policy_version, 2);
        assert_eq!(meta.reward_model_version, 0);
    }

    #[test]
    fn test_transition_meta_timestamp_nonzero() {
        let meta = TransitionMeta::new(0, 0, 0, 0, 0);
        assert!(
            meta.timestamp_ns > 0,
            "auto-filled timestamp must be non-zero"
        );
    }

    #[test]
    fn test_transition_meta_serialize_size() {
        let meta = TransitionMeta::new(1, 2, 3, 4, 5);
        assert_eq!(meta.serialize().len(), 48);
    }

    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn roundtrip_arbitrary(
                env_id: u32,
                episode_id: u64,
                step_in_episode: u32,
                global_step: u64,
                policy_version: u64,
                reward_model_version: u64,
                timestamp_ns: u64,
            ) {
                let meta = TransitionMeta {
                    env_id,
                    episode_id,
                    step_in_episode,
                    global_step,
                    policy_version,
                    reward_model_version,
                    timestamp_ns,
                };
                let bytes = meta.serialize();
                prop_assert_eq!(bytes.len(), 48);
                let restored = TransitionMeta::deserialize(&bytes).unwrap();
                prop_assert_eq!(meta, restored);
            }

            /// `deserialize` must never panic, and must accept a buffer
            /// exactly when it is `SERIALIZED_SIZE` bytes long.
            #[test]
            fn deserialize_accepts_only_exact_length(
                bytes in prop::collection::vec(any::<u8>(), 0..=96)
            ) {
                let result = TransitionMeta::deserialize(&bytes);
                prop_assert_eq!(result.is_ok(), bytes.len() == SERIALIZED_SIZE);
            }
        }
    }
}