chiark / gitweb /
packetframe: Completely rewrite FrameReader
authorIan Jackson <ijackson@chiark.greenend.org.uk>
Sun, 18 Apr 2021 12:30:34 +0000 (13:30 +0100)
committerIan Jackson <ijackson@chiark.greenend.org.uk>
Fri, 23 Apr 2021 18:32:25 +0000 (19:32 +0100)
The was growing more and more epicycles, every time I fixed a test
failure.  And it still wouldn't cope with an empty frame!

Signed-off-by: Ian Jackson <ijackson@chiark.greenend.org.uk>
src/packetframe.rs

index fe87ec58c944388d32c3725b8df65d789a7f4d6f..c658dc40061b825bf924f9c2ed5be4a39a4967f5 100644 (file)
@@ -23,6 +23,8 @@ const CHUNK_MAX: ChunkLen = 65534;
 const CHUNK_ERR: ChunkLen = 65535;
 const CHUNK_DEF: ChunkLen = 8192;
 
+pub const BUFREADER_CAPACITY: usize = CHUNK_DEF as usize + 4;
+
 type BO = BigEndian;
 
 #[derive(Debug,Copy,Clone,Error)]
@@ -43,32 +45,31 @@ pub struct Broken {
 
 #[derive(Debug)]
 pub struct FrameReader<R: Read> {
-  inner: Fuse<R>,
   state: ReaderState,
+  inner: BufReader<Fuse<R>>,
 }
 
 #[derive(Debug)]
 pub struct ReadFrame<'r,R:Read> {
-  fr: Result<&'r mut FrameReader<R>, Option<SenderError>>,
+  fr: &'r mut FrameReader<R>,
 }
 
 #[derive(Debug,Copy,Clone)]
 enum ReaderState {
-  Idle,
-  FrameStart,
-  InFrame(usize),
-  HadEof,
+  InBuffer { ibuf_used: ChunkLen, chunk_remaining: ChunkLen },
+  InChunk { remaining: ChunkLen },
+  HadFrameEnd(Result<(), SenderError>),
+  // xxx HadUnexpectedStreamEof,
 }
 use ReaderState::*;
 
 #[derive(Debug,Error)]
-enum ReadError {
-  GoodEof,
+enum ReadHeaderError {
+  TolerableEof,
   IO(#[from] io::Error),
-  SE(#[from] SenderError),
 }
-display_as_debug!{ReadError}
-use ReadError as RE;
+display_as_debug!{ReadHeaderError}
+use ReadHeaderError as RHE;
 
 // ---------- write ----------
 
@@ -152,107 +153,137 @@ impl<W:Write> Write for Fuse<W> {
 
 // ---------- read ----------
 
-impl ReaderState {
-  fn idle(&self) -> bool {
-    matches_doesnot!(self,
-                     = Idle | HadEof,
-                     ! FrameStart | InFrame(_))
-  }
-}
-
-fn badeof() -> ReadError { RE::IO(io::ErrorKind::UnexpectedEof.into()) }
+fn badeof() -> io::Error { io::ErrorKind::UnexpectedEof.into() }
 
 impl<R:Read> FrameReader<R> {
-  pub fn new(r: R) -> FrameReader<R> where R:BufRead {
-    Self::new_raw(Fuse::new(r))
+  pub fn new(r: R) -> FrameReader<R> {
+    let r = Fuse::new(r);
+    let r = BufReader::with_capacity(BUFREADER_CAPACITY, r);
+    Self::from_bufreader(r)
   }
-  fn new_raw(r: Fuse<R>) -> FrameReader<R> {
-    FrameReader { inner: r, state: Idle }
+  pub fn from_bufreader(r: BufReader<Fuse<R>>) -> FrameReader<R> {
+    FrameReader { inner: r, state: HadFrameEnd(Ok(())) }
+  }
+
+  #[throws(MgmtChannelReadError)]
+  pub fn read_rmp<T:DeserializeOwned>(&mut self) -> Option<T> {
+    let frame = self.new_frame()?;
+    if_let!{ Some(mut frame) = frame; else return Ok(None); };
+    let v = rmp_serde::decode::from_read(&mut frame)
+      .map_err(|e| MgmtChannelReadError::Parse(format!("{}", &e)))?;
+    Some(v)
   }
 
   #[throws(io::Error)]
   pub fn new_frame<'r>(&'r mut self) -> Option<ReadFrame<'r,R>> {
-    if ! self.state.idle() {
-      let mut buf = vec![0u8; CHUNK_DEF.into()];
-      while ! self.state.idle() {
-        match self.do_read(&mut buf) {
-          Ok(_) | Err(RE::SE(_)) => {},
-          Err(RE::GoodEof) => break,
-          Err(RE::IO(ioe)) => throw!(ioe),
-        }
-      }
-    }
-    self.state = FrameStart;
-    match Self::chunk_remaining(&mut self.inner, &mut self.state) {
+    self.finish_reading_frame()?;
+
+    match self.read_chunk_header() {
       Ok(_) => {},
-      Err(RE::GoodEof) => { self.state = HadEof; return None },
-      Err(RE::IO(e)) => throw!(e),
-      Err(RE::SE(e)) => throw!(e),
-    }
-    Some(ReadFrame { fr: Ok(self) })
-  }
-
-  #[throws(ReadError)]
-  fn chunk_remaining<'s>(inner: &mut Fuse<R>, state: &'s mut ReaderState)
-                         -> &'s mut usize {
-    match *state {
-      Idle => panic!(),
-      HadEof => throw!(RE::GoodEof),
-      FrameStart | InFrame(0) => {
-        *state = InFrame(match match {
-          let mut lbuf = [0u8;2];
-          let mut q = &mut lbuf[..];
-          match io::copy(
-            &mut inner.take(2),
-            &mut q,
-          )? {
-            // length of chunk header
-            0 => { match state { FrameStart => throw!(RE::GoodEof),
-                                 InFrame(0) => throw!(badeof()),
-                                 _ => panic!(), } },
-            1 => throw!(badeof()),
-            2 => (&lbuf[..]).read_u16::<BO>().unwrap(),
-            _ => panic!(),
-          }
-        } {
-          // value in chunk header
-          0         => Left(RE::GoodEof),
-          CHUNK_ERR => Left(RE::SE(SenderError)),
-          x         => Right(x as usize),
-        } {
-          // Left( end of frame )  Right( nonempty chunk len )
-          Left(e) => { *state = Idle; throw!(e); }
-          Right(x) => x,
-        });
-        match *state { InFrame(ref mut x) => x, _ => panic!() }
-      },
-      InFrame(ref mut remaining) => remaining,
+      Err(RHE::TolerableEof) => return None,
+      Err(RHE::IO(e)) => throw!(e),
     }
+    Some(ReadFrame { fr: self })
   }
 
-  #[throws(ReadError)]
-  fn do_read(&mut self, buf: &mut [u8]) -> usize {
-    assert_ne!(buf.len(), 0);
-    let remaining = Self::chunk_remaining(&mut self.inner, &mut self.state)?;
+  #[throws(io::Error)]
+  fn finish_reading_frame(&mut self) {
+    while matches_doesnot!(
+      self.state,
+      = InBuffer{..} | InChunk{..},
+      ! HadFrameEnd(..),
+    ) {
+      struct Discard;
+      impl ReadOutput for Discard {
+        #[inline]
+        fn copy_from_buf(&mut self, input: &[u8]) -> usize { input.len() }
+      }
+      self.read_from_frame(&mut Discard)?;
+    }
+  }
 
-    //dbgc!(buf.len(), &remaining);
+  #[throws(ReadHeaderError)]
+  fn read_chunk_header(&mut self) {
+    assert!(matches_doesnot!(
+      self.state,
+      = InChunk { remaining: 0 },
+      = HadFrameEnd(..),
+      ! InChunk { remaining: _ },
+      ! InBuffer{..},
+    ), "bad state {:?}", self.state);
+
+    let header_value = {
+      let mut lbuf = [0u8;2];
+      let mut q = &mut lbuf[..];
+      match io::copy(
+        &mut (&mut self.inner).take(2),
+        &mut q,
+      )? {
+        // length of chunk header read
+        0 => throw!(RHE::TolerableEof), // EOF on underlying stream
+        1 => throw!(badeof()),
+        2 => (&lbuf[..]).read_u16::<BO>().unwrap(),
+        _ => panic!(),
+      }
+    };
 
-    let n = min(buf.len(), *remaining);
-    let r = self.inner.read(&mut buf[0..n])?;
-    assert!(r <= n);
-    if r == 0 { throw!(badeof()); }
-    *remaining -= r;
-    //dbgc!(r, self.in_frame);
-    r
+    self.state = match header_value {
+      0         => HadFrameEnd(Ok(())),
+      CHUNK_ERR => HadFrameEnd(Err(SenderError)),
+      len       => InChunk { remaining: len },
+    }
   }
 
-  #[throws(MgmtChannelReadError)]
-  pub fn read_rmp<T:DeserializeOwned>(&mut self) -> Option<T> {
-    let frame = self.new_frame()?;
-    if_let!{ Some(mut frame) = frame; else return Ok(None); };
-    let v = rmp_serde::decode::from_read(&mut frame)
-      .map_err(|e| MgmtChannelReadError::Parse(format!("{}", &e)))?;
-    Some(v)
+  #[throws(io::Error)]
+  fn read_from_frame<O:ReadOutput+?Sized>(&mut self, output: &mut O) -> usize {
+    loop {
+      if let InBuffer { ref mut ibuf_used, chunk_remaining } = self.state {
+        let ibuf = self.inner.buffer();
+        let cando = &ibuf[ (*ibuf_used).into() ..
+                             min(ibuf.len(), chunk_remaining.into()) ];
+        let got = output.copy_from_buf(cando);
+        *ibuf_used += ChunkLen::try_from(got).unwrap();
+        if got != 0 { break got }
+        assert_eq!(cando.len(), 0);
+        self.inner.consume((*ibuf_used).into());
+        let remaining = chunk_remaining - *ibuf_used;
+        self.state = InChunk { remaining };
+      }
+
+      if let InChunk { remaining } = self.state {
+        if remaining > 0 {
+          let got = self.inner.fill_buf()?.len();
+          if got == 0 { throw!(badeof()) }
+          self.state = InBuffer { ibuf_used: 0, chunk_remaining: remaining };
+          continue;
+        }
+      }
+
+      match self.state {
+        InChunk { remaining: 0 } => { },
+        HadFrameEnd(Ok(())) => break 0,
+        HadFrameEnd(Err(e)) => throw!(e),
+        _ => panic!("bad state {:?}", self.state),
+      }
+
+      match self.read_chunk_header() {
+        Ok(()) => { },
+        Err(RHE::TolerableEof) => throw!(badeof()),
+        Err(RHE::IO(e)) => throw!(e),
+      }
+    }
+  }   
+}
+
+trait ReadOutput {
+  fn copy_from_buf(&mut self, input: &[u8]) -> usize;
+}
+    
+impl ReadOutput for [u8] {
+  #[inline]
+  fn copy_from_buf(&mut self, input: &[u8]) -> usize {
+    let mut p = self;
+    p.write(input).unwrap()
   }
 }
 
@@ -260,19 +291,7 @@ impl<'r, R:Read> Read for ReadFrame<'r, R> {
   #[throws(io::Error)]
   fn read(&mut self, buf: &mut [u8]) -> usize {
     if buf.len() == 0 { return 0 }
-    //dbgc!(buf.len(), self.fr.as_ref().err());
-    let fr = match self.fr {
-      Ok(ref mut fr) => fr,
-      Err(None) => return 0,
-      Err(Some(e@ SenderError)) => throw!(e),
-    };
-    //dbgc!(fr.in_frame);
-    match fr.do_read(buf) {
-      Ok(0) | Err(RE::GoodEof) => { self.fr = Err(None); 0 },
-      Ok(x) => x,
-      Err(RE::IO(ioe)) => throw!(ioe),
-      Err(RE::SE(e@ SenderError)) => { self.fr = Err(Some(e)); throw!(e) },
-    }
+    self.fr.read_from_frame(buf)?
   }
 }
 
@@ -512,9 +531,9 @@ fn write_test(){
     for bufsize in 1..=msg.buf.len()+1 {
       dbgc!(lumpsize, bufsize);
       let rd = LumpReader::new(lumpsize, &*msg.buf);
-      let rd = BufReader::with_capacity(bufsize, rd);
       let rd = Fuse::new(rd);
-      let mut rd = FrameReader::new_raw(rd);
+      let rd = BufReader::with_capacity(bufsize, rd);
+      let mut rd = FrameReader::from_bufreader(rd);
 
       expect_good(&mut rd, b"hello");
       expect_boom(&mut rd);