use crate::error::PERes;
use std::{
mem::replace,
sync::{Arc, Condvar, Mutex},
thread::{Builder, JoinHandle},
};
/// Shared state between the producers and the background thread: the queue of
/// operations not yet handed to the worker, plus a flag telling the worker to keep going.
struct PendingOps<T> {
    /// Operations pushed so far and not yet drained by the worker.
    ops: Vec<T>,
    /// `true` until `terminate` is called.
    running: bool,
}

impl<T> PendingOps<T> {
    /// Builds an empty queue in the running state.
    fn new() -> Self {
        let ops = Vec::new();
        Self { running: true, ops }
    }

    /// Appends `op` to the end of the queue.
    fn push(&mut self, op: T) {
        let PendingOps { ops, .. } = self;
        ops.push(op);
    }

    /// Clears the running flag; queued operations are left untouched.
    fn terminate(&mut self) {
        self.running = false;
    }
}
/// Handle to a background thread that processes queued operations.
///
/// Operations are queued with `add_pending`; `finish` asks the thread to stop and joins it.
pub(crate) struct BackgroundOps<T> {
    /// Queue plus running flag (behind a mutex), paired with the condition variable used
    /// to wake the worker; shared with the worker thread through the `Arc`.
    pending: Arc<(Mutex<PendingOps<T>>, Condvar)>,
    /// Join handle of the worker thread; `None` once `finish` has taken it.
    flush_thread: Option<JoinHandle<()>>,
}
fn sync_on_need<T, F, FO>(ops: &Mutex<PendingOps<T>>, cond: &Condvar, operation: FO, release_all: F) -> PERes<()>
where
F: Fn(Vec<T>) -> PERes<bool>,
FO: Fn() -> PERes<bool>,
{
let mut required_next = false;
loop {
let pending;
let running;
{
let mut lock = ops.lock().expect("lock not poisoned");
lock = cond
.wait_while(lock, |x| x.ops.is_empty() && x.running && !required_next)
.expect("lock not poisoned");
pending = replace(&mut lock.ops, Vec::new());
running = lock.running;
}
required_next = operation()?;
required_next |= release_all(pending)?;
if !running {
break Ok(());
}
}
}
impl<T: 'static + Send> BackgroundOps<T> {
    /// Spawns the worker thread (named "Disc sync") and returns a handle to it.
    ///
    /// The worker runs `sync_on_need`: whenever operations are queued, or a previous round
    /// asked to run again, it calls `operation` and then `release_all` with the drained
    /// batch of queued operations.
    ///
    /// # Panics
    /// Panics if the thread cannot be spawned. The worker thread itself panics if
    /// `operation` or `release_all` returns an error.
    pub fn new<F, FO>(operation: FO, release_all: F) -> PERes<Self>
    where
        F: Fn(Vec<T>) -> PERes<bool>,
        F: Send + 'static,
        FO: Fn() -> PERes<bool>,
        FO: Send + 'static,
    {
        let pending = Arc::new((Mutex::new(PendingOps::new()), Condvar::new()));
        let shared = Arc::clone(&pending);
        let worker = move || {
            let (queue, signal) = &*shared;
            sync_on_need(queue, signal, operation, release_all).unwrap();
        };
        let handle = Builder::new().name("Disc sync".into()).spawn(worker).unwrap();
        Ok(Self {
            flush_thread: Some(handle),
            pending,
        })
    }

    /// Queues `op` for the worker thread and wakes it.
    ///
    /// Currently always returns `Ok(())`.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned.
    pub fn add_pending(&self, op: T) -> PERes<()> {
        let (queue, signal) = &*self.pending;
        let mut state = queue.lock().expect("lock not poisoned");
        state.push(op);
        // Notify while still holding the guard; it is released when `state` drops.
        signal.notify_one();
        Ok(())
    }
}
impl<T> BackgroundOps<T> {
    /// Requests termination of the worker thread and blocks until it exits.
    ///
    /// Operations queued before this call are still handed to the worker's final round.
    /// Only the first call does anything: it takes the join handle, so later calls find
    /// `None` and return immediately.
    ///
    /// # Panics
    /// Panics if the mutex is poisoned or if the worker thread panicked.
    pub fn finish(&mut self) {
        match self.flush_thread.take() {
            None => {}
            Some(worker) => {
                let (queue, signal) = &*self.pending;
                {
                    // Clear the flag and notify while holding the lock.
                    let mut state = queue.lock().unwrap();
                    state.terminate();
                    signal.notify_one();
                }
                worker
                    .join()
                    .expect("no failure on background thread termination");
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::BackgroundOps;
    use std::{
        sync::{
            atomic::{AtomicU64, Ordering},
            Arc,
        },
        thread::sleep,
        time::{Duration, Instant},
    };

    /// Polls `counter` until it equals `expected` or `timeout` elapses; returns the last value read.
    fn wait_for(counter: &AtomicU64, expected: u64, timeout: Duration) -> u64 {
        let deadline = Instant::now() + timeout;
        loop {
            let value = counter.load(Ordering::SeqCst);
            if value == expected || Instant::now() >= deadline {
                return value;
            }
            sleep(Duration::from_millis(1));
        }
    }

    /// Queued operations are released by the worker without calling `finish`.
    #[test]
    fn test_execute_delayed() {
        let counter = Arc::new(AtomicU64::new(0));
        let cr = counter.clone();
        let mut bg = BackgroundOps::new(
            || {
                sleep(Duration::from_millis(10));
                Ok(false)
            },
            move |ops| {
                cr.fetch_add(ops.len() as u64, Ordering::SeqCst);
                Ok(false)
            },
        )
        .unwrap();
        bg.add_pending(1).unwrap();
        sleep(Duration::from_millis(1));
        bg.add_pending(2).unwrap();
        bg.add_pending(3).unwrap();
        // Poll against a generous deadline instead of one fixed sleep, so a slow or loaded
        // machine cannot make the assertion race the worker thread.
        assert_eq!(wait_for(&counter, 3, Duration::from_secs(5)), 3);
        // Stop and join the worker so the test does not leak a thread.
        bg.finish();
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// `finish` drains everything queued before it, deterministically (no timing involved).
    #[test]
    fn test_finish_releases_queued_ops() {
        let counter = Arc::new(AtomicU64::new(0));
        let cr = counter.clone();
        let mut bg = BackgroundOps::new(
            || Ok(false),
            move |ops| {
                cr.fetch_add(ops.len() as u64, Ordering::SeqCst);
                Ok(false)
            },
        )
        .unwrap();
        for op in 0..10 {
            bg.add_pending(op).unwrap();
        }
        bg.finish();
        assert_eq!(counter.load(Ordering::SeqCst), 10);
        // A second call is a no-op because the join handle has already been taken.
        bg.finish();
    }
}