use std::{
borrow::Cow,
thread::{spawn, JoinHandle, Scope, ScopedJoinHandle},
};
use anyhow::Result;
use crossbeam::channel::{bounded, Receiver, Sender};
use indicatif::ProgressBar;
use crate::util::logging::{measure_and_recv, measure_and_send, meter_bar};
use super::ObjectWriter;
/// Join handle for the background thread that drives a [`ThreadObjectWriter`].
///
/// `spawn` stores a handle from [`std::thread::spawn`] in `Static`, while
/// `spawn_scoped` stores a handle tied to a thread scope's `'scope` lifetime in
/// `Scoped`. In both cases the thread's return value is a `Result<usize>`.
enum WorkHandle<'scope> {
    /// Handle to a free-standing (unscoped) thread.
    Static(JoinHandle<Result<usize>>),
    /// Handle to a thread spawned on a scope with lifetime `'scope`.
    Scoped(ScopedJoinHandle<'scope, Result<usize>>),
}
/// Worker-thread body: pump objects from `recv` into `writer`, then finish it.
///
/// Objects are obtained with `measure_and_recv(&recv, &pb)` until it returns
/// `None` (presumably once every sender has been dropped — depends on that
/// helper), and each one is passed to `writer.write_object`. After the loop
/// the progress bar is cleared and `writer.finish()` provides the return value.
///
/// # Errors
///
/// Returns the first error from `write_object`, or the error from `finish`.
/// The progress bar is cleared on every exit path, including a failed write.
fn ferry<T, W>(recv: Receiver<T>, mut writer: W, pb: ProgressBar) -> Result<usize>
where
    T: Send + Sync + 'static,
    W: ObjectWriter<T>,
{
    while let Some(obj) = measure_and_recv(&recv, &pb) {
        if let Err(e) = writer.write_object(obj) {
            // Without this, an early `?` return would leave the buffer meter
            // drawn on screen after the writer fails.
            pb.finish_and_clear();
            return Err(e);
        }
    }
    pb.finish_and_clear();
    writer.finish()
}
/// Configuration for a [`ThreadObjectWriter`] whose background thread has not
/// been started yet.
///
/// Obtained from [`ThreadObjectWriter::wrap`]. Adjust it with `with_capacity` /
/// `with_name`, then start the thread with `spawn` or `spawn_scoped`.
#[must_use = "the builder does nothing until `spawn` or `spawn_scoped` is called"]
pub struct ThreadObjectWriterBuilder<W> {
    /// Writer that will be moved onto the background thread.
    writer: W,
    /// Label for the buffer progress meter (rendered as `"{name} buffer"`).
    name: String,
    /// Capacity of the bounded channel feeding the background thread.
    capacity: usize,
}
/// An [`ObjectWriter`] that forwards objects over a bounded channel to a
/// background thread, which owns the real writer.
///
/// This handle keeps only the sending side of the channel, the background
/// thread's join handle, and the buffer meter shared with that thread.
pub struct ThreadObjectWriter<'scope, T: Send + Sync + 'static> {
    /// Sending half of the channel to the background thread.
    sender: Sender<T>,
    /// Join handle for the background thread (scoped or unscoped).
    handle: WorkHandle<'scope>,
    /// Buffer meter; a clone of it is handed to the background thread.
    meter: ProgressBar,
}
impl<'scope, T> ThreadObjectWriter<'scope, T>
where
    T: Send + Sync + 'scope,
{
    /// Start configuring a threaded writer around `writer`.
    ///
    /// The returned builder defaults to the name `"unnamed"` and a channel
    /// capacity of 100; call `spawn` or `spawn_scoped` on it to start the thread.
    pub fn wrap<W>(writer: W) -> ThreadObjectWriterBuilder<W>
    where
        W: ObjectWriter<T> + Send + Sync + 'scope,
    {
        let name = String::from("unnamed");
        let capacity = 100;
        ThreadObjectWriterBuilder {
            writer,
            name,
            capacity,
        }
    }
}
impl<W> ThreadObjectWriterBuilder<W> {
    /// Set the capacity of the bounded channel between callers and the
    /// background thread. The buffer meter is created with the same value.
    pub fn with_capacity(self, cap: usize) -> Self {
        ThreadObjectWriterBuilder {
            capacity: cap,
            ..self
        }
    }

    /// Set the name used to label this writer's buffer meter.
    pub fn with_name<S: Into<Cow<'static, str>>>(self, name: S) -> Self {
        let name: Cow<'static, str> = name.into();
        ThreadObjectWriterBuilder {
            // `into_owned` moves an already-owned `String` out of the `Cow`
            // instead of copying it as `to_string` would.
            name: name.into_owned(),
            ..self
        }
    }

    /// Start the background writer thread on `scope`.
    ///
    /// Because the thread is spawned on a [`Scope`], the writer only has to
    /// outlive `'scope` rather than `'static`. Objects written to the returned
    /// [`ThreadObjectWriter`] pass through a bounded channel of `capacity`
    /// items to the thread, which feeds them to the wrapped writer.
    pub fn spawn_scoped<'scope, 'env, T>(
        self,
        scope: &'scope Scope<'scope, 'env>,
    ) -> ThreadObjectWriter<'scope, T>
    where
        W: ObjectWriter<T> + Send + Sync + 'scope,
        T: Send + Sync + 'scope,
    {
        let (sender, receiver) = bounded(self.capacity);
        let pb = meter_bar(self.capacity, &format!("{} buffer", self.name));
        // The worker gets its own handle to the same meter.
        let rpb = pb.clone();
        let h = scope.spawn(move || ferry(receiver, self.writer, rpb));
        ThreadObjectWriter {
            meter: pb,
            sender,
            handle: WorkHandle::Scoped(h),
        }
    }
}
impl<W> ThreadObjectWriterBuilder<W> {
pub fn spawn<T>(self) -> ThreadObjectWriter<'static, T>
where
W: ObjectWriter<T> + Send + Sync + 'static,
T: Send + Sync + 'static,
{
let (sender, receiver) = bounded(self.capacity);
let pb = meter_bar(self.capacity, &format!("{} buffer", self.name));
let rpb = pb.clone();
let h = spawn(move || ferry(receiver, self.writer, rpb));
ThreadObjectWriter {
meter: pb,
sender,
handle: WorkHandle::Static(h),
}
}
}
impl<'scope, T: Send + Sync + 'scope> ThreadObjectWriter<'scope, T> {
    /// Create a satellite writer that feeds the same background thread.
    ///
    /// The satellite borrows `self`, so every satellite must be dropped before
    /// `self` can be consumed by `finish`.
    pub fn satellite<'a>(&'a self) -> ThreadWriterSatellite<'a, 'scope, T>
    where
        'scope: 'a,
    {
        ThreadWriterSatellite::<'a, 'scope, T>::create(self)
    }
}
impl<'scope, T: Send + Sync + 'static> ObjectWriter<T> for ThreadObjectWriter<'scope, T> {
    /// Send `object` to the background thread with `measure_and_send`,
    /// passing along the buffer meter.
    fn write_object(&mut self, object: T) -> Result<()> {
        measure_and_send(&self.sender, object, &self.meter)?;
        Ok(())
    }

    /// Close the channel and wait for the background thread, returning its result.
    ///
    /// If the background thread panicked, the panic is re-raised on the
    /// calling thread.
    fn finish(self) -> Result<usize> {
        // Dropping our sender closes the channel so the worker can stop. Any
        // satellite senders are already gone, because satellites borrow
        // `self` and `finish` consumes it.
        drop(self.sender);
        match self.handle {
            WorkHandle::Static(h) => h.join().unwrap_or_else(|e| std::panic::resume_unwind(e)),
            WorkHandle::Scoped(h) => h.join().unwrap_or_else(|e| std::panic::resume_unwind(e)),
        }
    }
}
/// An extra writer that sends into a [`ThreadObjectWriter`]'s channel.
///
/// Objects written through a satellite go to the parent's background thread
/// and use the parent's buffer meter. A satellite borrows its parent, so the
/// parent cannot be finished while any satellite is alive.
pub struct ThreadWriterSatellite<'a, 'scope, T>
where
    T: Send + Sync + 'static,
    'scope: 'a,
{
    /// Parent writer, borrowed for its buffer meter.
    delegate: &'a ThreadObjectWriter<'scope, T>,
    /// This satellite's own clone of the parent's channel sender.
    sender: Sender<T>,
}

// Implemented by hand: `#[derive(Clone)]` would add a `T: Clone` bound, but
// cloning a satellite only copies a reference and clones a `Sender<T>`, and
// neither of those requires `T: Clone`.
impl<'a, 'scope, T> Clone for ThreadWriterSatellite<'a, 'scope, T>
where
    T: Send + Sync + 'static,
    'scope: 'a,
{
    fn clone(&self) -> Self {
        ThreadWriterSatellite {
            delegate: self.delegate,
            sender: self.sender.clone(),
        }
    }
}
impl<'a, 'scope, T> ThreadWriterSatellite<'a, 'scope, T>
where
T: Send + Sync + 'static,
'scope: 'a,
{
fn create(delegate: &'a ThreadObjectWriter<'scope, T>) -> ThreadWriterSatellite<'a, 'scope, T> {
ThreadWriterSatellite {
delegate,
sender: delegate.sender.clone(),
}
}
}
impl<'a, 'scope, T> ObjectWriter<T> for ThreadWriterSatellite<'a, 'scope, T>
where
    T: Send + Sync + 'static,
    'scope: 'a,
{
    /// Send `object` on this satellite's sender, using the parent's buffer meter.
    fn write_object(&mut self, object: T) -> Result<()> {
        let meter = &self.delegate.meter;
        measure_and_send(&self.sender, object, meter)?;
        Ok(())
    }

    /// Finishing a satellite does nothing except consume it, which drops its
    /// sender. It always reports 0 and leaves the parent writer running.
    fn finish(self) -> Result<usize> {
        Ok(0)
    }
}