1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127

pub use tokio::task::{JoinError,LocalKey,LocalSet};
pub use tokio::task::{block_in_place,spawn_blocking,yield_now};

use std::future::Future;
use tokio::sync::mpsc;
use std::result::Result;
use pin_utils::unsafe_pinned;

/// Join handle for a task started with `spawn_for_join`.
///
/// Thin wrapper around [`tokio::task::JoinHandle`] that carries
/// `#[must_use]`, so discarding the handle instead of keeping it for a
/// later join produces a compiler warning.
#[must_use="you need to save this JoinHandle so you can join later"]
pub struct JoinHandle<T> { jh : tokio::task::JoinHandle<T> }

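/*
// Disabled: Deref / DerefMut / unwrap access to the inner tokio handle.
// Written against a tuple-struct layout (`self.0`); `JoinHandle` now has the
// named field `jh`, so this needs updating before it can be re-enabled.
impl<T> std::ops::Deref for JoinHandle<T> {
  type Target = tokio::task::JoinHandle<T>;
  fn deref(&self) -> &Self::Target { &self.0 }
}
impl<T> std::ops::DerefMut for JoinHandle<T> {
  fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 }
}
impl<T> JoinHandle<T> {
  fn unwrap(self) -> tokio::task::JoinHandle<T> { self.0 }
}*/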

use std::pin::Pin;
use std::task::{Context,Poll};

impl<T> JoinHandle<T> {
  // Generates a private `jh()` accessor that projects a pinned
  // `JoinHandle<T>` to its pinned inner tokio handle; the `Future` impl's
  // `poll` relies on it.
  // NOTE(review): `unsafe_pinned!` treats `jh` as structurally pinned, so the
  // field must never be moved out of a pinned `JoinHandle`. Nothing in the
  // visible code does so -- re-check if a `Drop` impl or field moves are added.
  unsafe_pinned!(jh: tokio::task::JoinHandle<T>);
}

impl<T> Future for JoinHandle<T> {
  type Output = Result<T, JoinError>;

  /// Polls the wrapped tokio handle and forwards its result unchanged:
  /// `Ok(value)` when the task completed, `Err(JoinError)` otherwise.
  fn poll(self : Pin<&mut Self>, cx : &mut Context<'_>) -> Poll<Self::Output> {
    // Project the pin onto the inner handle, then delegate to it.
    let inner = self.jh();
    inner.poll(cx)
  }
}

/// Spawns `task` with `tokio::task::spawn` and returns this module's
/// `#[must_use]` [`JoinHandle`] wrapping the resulting tokio handle.
pub fn spawn_for_join<T>(task : T) -> JoinHandle<T::Output>
where
  T: Future + Send + 'static,
  T::Output: Send + 'static,
{
  let jh = tokio::task::spawn(task);
  JoinHandle { jh }
}

/// Result type returned by tracked tasks: `Ok(())` on success, otherwise an
/// [`anyhow::Error`] describing the failure.
pub type TaskResult = anyhow::Result<()>;

/// Failure report carried over the tracker's channel.
type Report = anyhow::Error;

/// Collects failure reports from detached tasks registered with it.
pub struct Tracker {
  // Sender that `register` clones for each task's watcher; set to `None` by
  // `failfast`/`failslow` so the channel can close once the clones are gone.
  sender : Option<mpsc::Sender<Report>>,
  // Receiving end on which `failfast`/`failslow` wait for reports.
  receiver : mpsc::Receiver<Report>,
}

impl Tracker {
  /// Creates a tracker with no registered tasks.
  ///
  /// Failure reports travel over a bounded channel of capacity 5.
  pub fn new() -> Tracker {
    let (sender, receiver) = mpsc::channel(5);
    Tracker { sender : Some(sender), receiver }
  }

  /// Spawns `task` via `spawn_for_join` and registers it with this tracker
  /// under the description `what` (used in log and error messages).
  pub fn spawn<T,S>(&mut self, what : S, task: T) where
    T: Future<Output=TaskResult> + Send + 'static,
    T::Output: Send + 'static,
    S : Into<String>,
  {
    let jh = spawn_for_join(task);
    self.register(jh, what.into());
  }

  /// Registers an already-spawned task, described by `what`.
  ///
  /// A detached watcher task awaits `jh` and classifies the outcome:
  ///
  /// * task returned `Ok(())`: prints a "terminated ok" line to stdout;
  /// * task returned `Err(_)`, or the join failed for a reason other than
  ///   cancellation: prints the failure to stderr and sends it, with the
  ///   same message attached as context, to the tracker's channel;
  /// * join error with `is_cancelled()`: silently ignored.
  ///
  /// # Panics
  ///
  /// Panics if the tracker's sender has already been dropped. In the code
  /// visible here that only happens inside `failfast`/`failslow`, which
  /// consume the tracker.
  pub fn register(&mut self, jh : JoinHandle<TaskResult>, what : String) {
    // Cloning only needs a shared borrow of the stored sender.
    // NOTE(review): `mut` is kept from the original; older tokio releases
    // take `&mut self` in `Sender::send` -- confirm against the pinned version.
    let mut sender = self.sender.as_ref()
      .expect("tracker not waiting yet")
      .clone();

    tokio::task::spawn(async move {
      let got = jh.await;

      // `Some((error, message))` when there is something to report.
      let report : Option<(anyhow::Error, String)> = match got {
        Ok(Ok(())) => {
          println!("terminated ok: {}", what);
          None
        },
        Ok(Err(task_err)) => {
          let m = format!("detached tracked task failed: {} {:?}",
                          what, task_err);
          Some((task_err, m))
        },
        Err(join_err) => {
          if join_err.is_cancelled() {
            // mystery, but probably ok
            // println!("cancelled: {}", what);
            None
          } else {
            let m = format!("detached tracked task failed: {} {:?}",
                            what, join_err);
            Some((join_err.into(), m))
          }
        },
      };

      if let Some((report_err, m)) = report {
        eprintln!("{}", m);
        sender.send(report_err.context(m)).await.ok();
        // ^ Err means wait went away, or we are crashing, or something
      }
    });
  }

  /// Waits for the first failure report and returns it as `Err`.
  ///
  /// Returns `Ok(())` if the channel closes without any report, i.e. once
  /// every watcher spawned by `register` has finished and dropped its sender
  /// clone. The tracker is consumed, so its receiver is dropped on return;
  /// watchers reporting afterwards get a send error, which they discard.
  pub async fn failfast(mut self) -> TaskResult {
    // Drop our own sender so only the watchers keep the channel open.
    self.sender = None;
    match self.receiver.recv().await {
      Some(error) => Err(error),
      None => Ok(()),
    }
  }

  /// Waits until every watcher has finished, then returns the first failure
  /// report received as `Err`, or `Ok(())` if there was none.
  ///
  /// Later reports are received and dropped; each has already been printed
  /// to stderr by its watcher.
  pub async fn failslow(mut self) -> TaskResult {
    // Drop our own sender so only the watchers keep the channel open.
    self.sender = None;
    let mut result : TaskResult = Ok(());
    while let Some(error) = self.receiver.recv().await {
      // Keep only the first error; keep draining so the watchers can finish.
      if result.is_ok() { result = Err(error); }
    }
    result
  }
}