AppFlowy-Cloud/libs/indexer/src/thread_pool.rs

168 lines
5.0 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use rayon::{ThreadPool, ThreadPoolBuilder};
use thiserror::Error;
/// A Rayon thread pool wrapper that reports panics instead of aborting the process.
///
/// The wrapped pool is built with a custom panic handler (see
/// `ThreadPoolNoAbortBuilder::build`). Whenever Rayon hands a panic to that handler,
/// the handler raises `catched_panic`; `install` then reads and clears the flag and
/// turns it into a `CatchedPanic` error for the caller.
#[derive(Debug)]
pub struct ThreadPoolNoAbort {
    /// The underlying Rayon pool that actually runs the work.
    thread_pool: ThreadPool,
    /// Raised by the pool's panic handler, lowered again by `install`.
    /// Shared with the handler closure, hence the `Arc`.
    catched_panic: Arc<AtomicBool>,
}
impl ThreadPoolNoAbort {
    /// Runs `op` inside the thread pool and blocks until it has finished.
    ///
    /// Two kinds of panic are converted into an error instead of unwinding into
    /// the caller:
    ///
    /// * a panic raised while running `op` itself, which Rayon re-raises out of
    ///   `ThreadPool::install` and which is caught here;
    /// * a panic that Rayon passed to the pool's panic handler (recorded in the
    ///   `catched_panic` flag).
    ///
    /// The flag is shared by the whole pool, so a panic recorded by the handler is
    /// reported by whichever `install` call reads the flag next.
    ///
    /// # Arguments
    /// * `op` - The closure to execute within the thread pool.
    ///
    /// # Errors
    /// Returns `CatchedPanic` if a panic was caught or recorded as described above.
    /// The panic payload itself is discarded.
    pub fn install<OP, R>(&self, op: OP) -> Result<R, CatchedPanic>
    where
        OP: FnOnce() -> R + Send,
        R: Send,
    {
        // `AssertUnwindSafe`: `op` is consumed by this call and its result is thrown
        // away on panic, so this function never exposes a half-built `R`. State that
        // `op` mutated through captured references is the caller's responsibility
        // once it sees the error.
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            self.thread_pool.install(op)
        }));

        // Always clear the flag, even when `op` itself panicked, so a stale value
        // cannot turn the next successful call into an error.
        let handler_fired = self.catched_panic.swap(false, Ordering::SeqCst);

        match outcome {
            Ok(output) if !handler_fired => Ok(output),
            _ => Err(CatchedPanic),
        }
    }

    /// Returns the number of threads in the underlying Rayon pool, as reported by
    /// `ThreadPool::current_num_threads`.
    pub fn current_num_threads(&self) -> usize {
        self.thread_pool.current_num_threads()
    }
}
/// Error indicating that a panic occurred during thread pool execution.
///
/// This error is returned when a closure executed in the thread pool panics.
#[derive(Error, Debug)]
#[error("A panic occurred happened in the thread pool. Check the logs for more information")]
pub struct CatchedPanic;
/// Builder for [`ThreadPoolNoAbort`].
///
/// A newtype around Rayon's `ThreadPoolBuilder`. Configuration calls are forwarded
/// to the inner builder, and `build` additionally installs the panic handler that
/// `ThreadPoolNoAbort` relies on.
#[derive(Default)]
pub struct ThreadPoolNoAbortBuilder(ThreadPoolBuilder);
impl ThreadPoolNoAbortBuilder {
pub fn new() -> ThreadPoolNoAbortBuilder {
ThreadPoolNoAbortBuilder::default()
}
/// Sets a custom naming function for threads in the pool.
///
/// # Arguments
/// * `closure` - A function that takes a thread index and returns a thread name.
///
pub fn thread_name<F>(mut self, closure: F) -> Self
where
F: FnMut(usize) -> String + 'static,
{
self.0 = self.0.thread_name(closure);
self
}
/// Sets the number of threads for the thread pool.
///
/// # Arguments
/// * `num_threads` - The number of threads to create in the thread pool.
pub fn num_threads(mut self, num_threads: usize) -> ThreadPoolNoAbortBuilder {
self.0 = self.0.num_threads(num_threads);
self
}
/// Builds the `ThreadPoolNoAbort` instance.
///
/// This method creates a `ThreadPoolNoAbort` with the specified configurations,
/// including custom panic handling behavior.
///
/// # Returns
/// * `Ok(ThreadPoolNoAbort)` - The constructed thread pool.
/// * `Err(ThreadPoolBuildError)` - If the thread pool failed to build.
///
pub fn build(mut self) -> Result<ThreadPoolNoAbort, rayon::ThreadPoolBuildError> {
let catched_panic = Arc::new(AtomicBool::new(false));
self.0 = self.0.panic_handler({
let catched_panic = catched_panic.clone();
move |_result| catched_panic.store(true, Ordering::SeqCst)
});
Ok(ThreadPoolNoAbort {
thread_pool: self.0.build()?,
catched_panic,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// A closure that does not panic yields its return value wrapped in `Ok`.
    #[test]
    fn test_install_closure_success() {
        let builder = ThreadPoolNoAbortBuilder::new().num_threads(4);
        let pool = builder.build().expect("Failed to build thread pool");

        let outcome = pool.install(|| 42);

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), 42);
    }

    /// Every one of 100 closures submitted to an 8-thread pool runs exactly once.
    #[test]
    fn test_multiple_threads_execution() {
        let builder = ThreadPoolNoAbortBuilder::new().num_threads(8);
        let pool = builder.build().expect("Failed to build thread pool");

        // Each closure bumps this counter once.
        let counter = Arc::new(AtomicUsize::new(0));

        // `install` blocks until its closure is done, so these are finished
        // results rather than pending handles.
        let mut results = Vec::new();
        for _ in 0..100 {
            let task_counter = Arc::clone(&counter);
            results.push(pool.install(move || {
                task_counter.fetch_add(1, Ordering::SeqCst);
            }));
        }

        // No call may have reported a panic.
        assert!(results.iter().all(Result::is_ok));

        // The counter must match the number of submitted closures.
        assert_eq!(counter.load(Ordering::SeqCst), 100);
    }
}