Skip to main content

phytrace_sdk/reliability/
retry.rs

1//! Retry logic with exponential backoff and jitter.
2
3use rand::Rng;
4use std::time::Duration;
5
6use crate::core::config::RetryConfig;
7use crate::error::{PhyTraceError, PhyTraceResult};
8use crate::transport::traits::SendResult;
9
/// Retry handler with exponential backoff.
///
/// Holds its own copy of the [`RetryConfig`] that controls how many retries are
/// attempted and how the delay between them grows.
#[derive(Debug, Clone)]
pub struct RetryHandler {
    /// Retry policy: retry limit, initial/max backoff, multiplier and jitter flag.
    config: RetryConfig,
}
15
16impl RetryHandler {
17    /// Create a new retry handler from configuration.
18    pub fn new(config: &RetryConfig) -> Self {
19        Self {
20            config: config.clone(),
21        }
22    }
23
24    /// Create with default configuration.
25    pub fn default_handler() -> Self {
26        Self::new(&RetryConfig::default())
27    }
28
29    /// Execute an async operation with retries.
30    pub async fn execute<F, Fut, T>(&self, mut operation: F) -> PhyTraceResult<T>
31    where
32        F: FnMut() -> Fut,
33        Fut: std::future::Future<Output = PhyTraceResult<T>>,
34    {
35        let mut attempt = 0;
36        let mut last_error = None;
37
38        while attempt <= self.config.max_retries {
39            match operation().await {
40                Ok(result) => return Ok(result),
41                Err(e) => {
42                    if !self.is_retryable(&e) {
43                        return Err(e);
44                    }
45
46                    last_error = Some(e);
47                    attempt += 1;
48
49                    if attempt <= self.config.max_retries {
50                        let delay = self.calculate_delay(attempt);
51                        tokio::time::sleep(delay).await;
52                    }
53                }
54            }
55        }
56
57        Err(last_error
58            .unwrap_or_else(|| PhyTraceError::Transport("Max retries exceeded".to_string())))
59    }
60
61    /// Execute with retries and return attempt count.
62    pub async fn execute_with_stats<F, Fut, T>(
63        &self,
64        mut operation: F,
65    ) -> (PhyTraceResult<T>, RetryStats)
66    where
67        F: FnMut() -> Fut,
68        Fut: std::future::Future<Output = PhyTraceResult<T>>,
69    {
70        let mut stats = RetryStats::default();
71        let mut last_error = None;
72
73        while stats.attempts <= self.config.max_retries {
74            stats.attempts += 1;
75
76            match operation().await {
77                Ok(result) => {
78                    stats.succeeded = true;
79                    return (Ok(result), stats);
80                }
81                Err(e) => {
82                    if !self.is_retryable(&e) {
83                        return (Err(e), stats);
84                    }
85
86                    last_error = Some(e);
87                    stats.retries += 1;
88
89                    if stats.attempts <= self.config.max_retries {
90                        let delay = self.calculate_delay(stats.attempts);
91                        stats.total_delay_ms += delay.as_millis() as u64;
92                        tokio::time::sleep(delay).await;
93                    }
94                }
95            }
96        }
97
98        (
99            Err(last_error
100                .unwrap_or_else(|| PhyTraceError::Transport("Max retries exceeded".to_string()))),
101            stats,
102        )
103    }
104
105    /// Check if a send result indicates a retryable failure.
106    pub fn should_retry(&self, result: &SendResult) -> bool {
107        !result.success && result.is_retryable()
108    }
109
110    /// Calculate delay for a given attempt number.
111    pub fn calculate_delay(&self, attempt: u32) -> Duration {
112        let base_delay = self.config.initial_backoff_ms as f64;
113        let multiplier = self.config.backoff_multiplier;
114
115        // Exponential backoff
116        let delay = base_delay * multiplier.powi(attempt.saturating_sub(1) as i32);
117
118        // Cap at max backoff
119        let delay = delay.min(self.config.max_backoff_ms as f64);
120
121        // Add jitter if enabled
122        let delay = if self.config.jitter {
123            let jitter = rand::thread_rng().gen_range(0.0..=0.3);
124            delay * (1.0 + jitter)
125        } else {
126            delay
127        };
128
129        #[expect(
130            clippy::cast_sign_loss,
131            clippy::cast_possible_truncation,
132            reason = "delay is computed from non-negative values; saturating cast is acceptable for retry timing"
133        )]
134        let delay_ms = delay.max(0.0) as u64;
135        Duration::from_millis(delay_ms)
136    }
137
138    /// Check if an error is retryable.
139    fn is_retryable(&self, error: &PhyTraceError) -> bool {
140        matches!(
141            error,
142            PhyTraceError::Transport(_) | PhyTraceError::Timeout(_)
143        )
144    }
145
146    /// Get the maximum number of retries.
147    pub fn max_retries(&self) -> u32 {
148        self.config.max_retries
149    }
150
151    /// Get the initial backoff duration.
152    pub fn initial_backoff(&self) -> Duration {
153        Duration::from_millis(self.config.initial_backoff_ms)
154    }
155}
156
/// Statistics from a retry operation.
///
/// `Default` yields an all-zero, not-succeeded value, which is the starting
/// state before the first attempt. Equality compares every field, which makes
/// the stats easy to check in assertions.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RetryStats {
    /// Number of attempts made (invocations of the operation).
    pub attempts: u32,
    /// Number of retries; for a run that ends in success this is `attempts - 1`.
    pub retries: u32,
    /// Total delay time in milliseconds.
    pub total_delay_ms: u64,
    /// Whether the operation eventually succeeded.
    pub succeeded: bool,
}
169
170impl RetryStats {
171    /// Check if any retries were needed.
172    pub fn had_retries(&self) -> bool {
173        self.retries > 0
174    }
175}
176
#[cfg(test)]
mod tests {
    //! Unit tests for the retry handler: delay calculation (with and without
    //! jitter), the retry loop, and the collected statistics.

    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[test]
    fn test_calculate_delay() {
        let config = RetryConfig {
            initial_backoff_ms: 100,
            max_backoff_ms: 10000,
            backoff_multiplier: 2.0,
            jitter: false,
            max_retries: 5,
        };
        let handler = RetryHandler::new(&config);

        // Without jitter, delays should be predictable
        assert_eq!(handler.calculate_delay(1), Duration::from_millis(100));
        assert_eq!(handler.calculate_delay(2), Duration::from_millis(200));
        assert_eq!(handler.calculate_delay(3), Duration::from_millis(400));
        assert_eq!(handler.calculate_delay(4), Duration::from_millis(800));
    }

    #[test]
    fn test_delay_capped() {
        let config = RetryConfig {
            initial_backoff_ms: 1000,
            max_backoff_ms: 5000,
            backoff_multiplier: 2.0,
            jitter: false,
            max_retries: 10,
        };
        let handler = RetryHandler::new(&config);

        // Should be capped at max
        assert_eq!(handler.calculate_delay(5), Duration::from_millis(5000));
        assert_eq!(handler.calculate_delay(6), Duration::from_millis(5000));
    }

    #[test]
    fn test_delay_with_jitter_stays_within_bounds() {
        let config = RetryConfig {
            initial_backoff_ms: 100,
            max_backoff_ms: 10000,
            backoff_multiplier: 2.0,
            jitter: true,
            max_retries: 5,
        };
        let handler = RetryHandler::new(&config);

        // Jitter adds between 0% and 30% on top of the un-jittered delay, so
        // the first attempt must land in [100ms, 130ms]. Sample repeatedly
        // because the jitter is random.
        for _ in 0..100 {
            let delay = handler.calculate_delay(1);
            assert!(delay >= Duration::from_millis(100));
            assert!(delay <= Duration::from_millis(130));
        }
    }

    #[tokio::test]
    async fn test_execute_success_first_try() {
        let handler = RetryHandler::default_handler();

        let result = handler
            .execute(|| async { Ok::<_, PhyTraceError>(42) })
            .await;

        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_execute_success_after_retry() {
        let handler = RetryHandler::new(&RetryConfig {
            max_retries: 3,
            initial_backoff_ms: 1, // Minimal delay for testing
            ..Default::default()
        });

        let counter = AtomicU32::new(0);

        let result = handler
            .execute(|| async {
                let count = counter.fetch_add(1, Ordering::Relaxed);
                if count < 2 {
                    Err(PhyTraceError::Transport("temporary failure".to_string()))
                } else {
                    Ok(42)
                }
            })
            .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn test_execute_all_retries_fail() {
        let handler = RetryHandler::new(&RetryConfig {
            max_retries: 2,
            initial_backoff_ms: 1,
            ..Default::default()
        });

        let counter = AtomicU32::new(0);

        let result = handler
            .execute(|| async {
                counter.fetch_add(1, Ordering::Relaxed);
                Err::<i32, _>(PhyTraceError::Transport("persistent failure".to_string()))
            })
            .await;

        result.unwrap_err();
        assert_eq!(counter.load(Ordering::Relaxed), 3); // Initial + 2 retries
    }

    #[tokio::test]
    async fn test_execute_with_stats() {
        let handler = RetryHandler::new(&RetryConfig {
            max_retries: 3,
            initial_backoff_ms: 1,
            ..Default::default()
        });

        let counter = AtomicU32::new(0);

        let (result, stats) = handler
            .execute_with_stats(|| async {
                let count = counter.fetch_add(1, Ordering::Relaxed);
                if count < 2 {
                    Err(PhyTraceError::Transport("temporary".to_string()))
                } else {
                    Ok(42)
                }
            })
            .await;

        result.unwrap();
        assert_eq!(stats.attempts, 3);
        assert_eq!(stats.retries, 2);
        assert!(stats.succeeded);
    }

    #[tokio::test]
    async fn test_execute_with_stats_all_retries_fail() {
        let handler = RetryHandler::new(&RetryConfig {
            max_retries: 2,
            initial_backoff_ms: 1,
            ..Default::default()
        });

        let (result, stats) = handler
            .execute_with_stats(|| async {
                Err::<i32, _>(PhyTraceError::Transport("persistent failure".to_string()))
            })
            .await;

        result.unwrap_err();
        assert_eq!(stats.attempts, 3); // Initial + 2 retries
        assert!(stats.had_retries());
        assert!(!stats.succeeded);
    }

    #[test]
    fn test_should_retry() {
        let handler = RetryHandler::default_handler();

        let retryable = SendResult::failure(Some(500), "Server error");
        assert!(handler.should_retry(&retryable));

        let not_retryable = SendResult::failure(Some(400), "Bad request");
        assert!(!handler.should_retry(&not_retryable));

        let success = SendResult::success(50);
        assert!(!handler.should_retry(&success));
    }
}