chiark / gitweb /
changelog: document further make-release changes
[otter.git] / support / timedfd.rs
1 // Copyright 2020-2021 Ian Jackson and contributors to Otter
2 // SPDX-License-Identifier: AGPL-3.0-or-later
3 // There is NO WARRANTY.
4
5 use crate::prelude::*;
6
7 pub trait Timed {
8   fn set_deadline(&mut self, deadline: Option<Instant>);
9   fn set_timeout(&mut self, timeout: Option<Duration>);
10 }
11 pub trait TimedRead : Timed + Read  { }
12 pub trait TimedWrite: Timed + Write { }
13
14 use nix::fcntl::{fcntl, OFlag, FcntlArg};
15 use nix::Error as NE;
16
17 use mio::Token;
18
19 pub struct TimedFd<RW: TimedFdReadWrite> {
20   fd: Fd,
21   poll: mio::Poll,
22   events: mio::event::Events,
23   deadline: Option<Instant>,
24   rw: PhantomData<RW>,
25 }
26
27 pub trait TimedFdReadWrite {
28   const INTEREST: mio::Interest;
29 }
30
31 pub type TimedFdReader = TimedFd<TimedFdRead>;
32 pub type TimedFdWriter = TimedFd<TimedFdWrite>;
33
34 #[derive(Debug,Copy,Clone)] pub struct TimedFdRead;
35 impl TimedFdReadWrite for TimedFdRead {
36   const INTEREST : mio::Interest = mio::Interest::READABLE;
37 }
38 #[derive(Debug,Copy,Clone)] pub struct TimedFdWrite;
39 impl TimedFdReadWrite for TimedFdWrite {
40   const INTEREST : mio::Interest = mio::Interest::WRITABLE;
41 }
42
43 pub struct Fd(RawFd);
44 impl Fd {
45   pub fn from_raw_fd(fd: RawFd) -> Self { Fd(fd) }
46   fn extract_raw_fd(&mut self) -> RawFd { mem::replace(&mut self.0, -1) }
47 }
48 impl IntoRawFd for Fd {
49   fn into_raw_fd(mut self) -> RawFd { self.extract_raw_fd() }
50 }
51 impl AsRawFd for Fd {
52   fn as_raw_fd(&self) -> RawFd { self.0 }
53 }
54
55 impl<RW> TimedFd<RW> where RW: TimedFdReadWrite {
56   /// Takes ownership of the fd
57   ///
58   /// Will change the fd's open-file to nonblocking.
59   #[throws(io::Error)]
60   pub fn new<F>(fd: F) -> TimedFd<RW> where F: IntoRawFd {
61     Self::from_fd( Fd::from_raw_fd( fd.into_raw_fd() ))?
62   }
63
64   /// Takes ownership of the fd
65   ///
66   /// Will change the fd's open-file to nonblocking.
67   #[throws(io::Error)]
68   fn from_fd(fd: Fd) -> Self {
69     fcntl(fd.as_raw_fd(), FcntlArg::F_SETFL(OFlag::O_NONBLOCK))
70       .map_err(|e| io::Error::from(e))?;
71
72     let poll = mio::Poll::new()?;
73     poll.registry().register(
74       &mut mio::unix::SourceFd(&fd.as_raw_fd()),
75       Token(0),
76       RW::INTEREST,
77     )?;
78     let events = mio::event::Events::with_capacity(1);
79     TimedFd { fd, poll, events, deadline: None, rw: PhantomData }
80   }
81 }
82
83 impl<RW> Timed for TimedFd<RW> where RW: TimedFdReadWrite {
84   fn set_deadline(&mut self, deadline: Option<Instant>) {
85     self.deadline = deadline;
86   }
87   fn set_timeout(&mut self, timeout: Option<Duration>) {
88     self.set_deadline(timeout.map(|timeout|{
89       Instant::now() + timeout
90     }));
91   }
92 }
93
94 impl<RW> TimedFd<RW> where RW: TimedFdReadWrite {
95   #[throws(io::Error)]
96   fn rw<F,O>(&mut self, mut f: F) -> O
97   where F: FnMut(i32) -> Result<O, nix::Error>
98   {
99     'again: loop {
100       for event in &self.events {
101         if event.token() == Token(0) {
102           match f(self.fd.as_raw_fd()) {
103             Ok(got) => { break 'again got },
104             Err(NE::EINTR) => continue 'again,
105             Err(NE::EAGAIN) => break,
106             Err(ne) => throw!(ne),
107           }
108         }
109       }
110
111       let timeout = if let Some(deadline) = self.deadline {
112         let now = Instant::now();
113         if now >= deadline { throw!(io::ErrorKind::TimedOut) }
114         Some(deadline - now)
115       } else {
116         None
117       };
118       loop {
119         match self.poll.poll(&mut self.events, timeout) {
120           Err(e) if e.kind() == ErrorKind::Interrupted => continue,
121           Err(e) => throw!(e),
122           Ok(()) => break,
123         }
124       }
125       if self.events.is_empty() { throw!(io::ErrorKind::TimedOut) }
126     }
127   }
128 }
129
130 impl Read for TimedFd<TimedFdRead> {
131   #[throws(io::Error)]
132   fn read(&mut self, buf: &mut [u8]) -> usize {
133     self.rw(|fd| unistd::read(fd, buf))?
134   }
135 }
136
137 impl Write for TimedFd<TimedFdWrite> {
138   #[throws(io::Error)]
139   fn write(&mut self, buf: &[u8]) -> usize {
140     self.rw(|fd| unistd::write(fd, buf))?
141   }
142   #[throws(io::Error)]
143   fn flush(&mut self) {
144   }
145 }
146
147 impl Drop for Fd {
148   fn drop(&mut self) {
149     let fd = self.extract_raw_fd();
150     if fd >= 2 { let _ = nix::unistd::close(fd); }
151   }
152 }
153