chiark / gitweb /
read_limited_bytes: take a `capacity` argument
[hippotat.git] / src / utils.rs
index a47b84802b595aacd4c2452a9197a3ec0ba2fe4e..c210a7d3ce2ab7a0801653cbfb40060f9465058b 100644 (file)
@@ -1,5 +1,5 @@
 // Copyright 2021 Ian Jackson and contributors to Hippotat
-// SPDX-License-Identifier: AGPL-3.0-or-later
+// SPDX-License-Identifier: GPL-3.0-or-later
 // There is NO WARRANTY.
 
 use crate::prelude::*;
@@ -16,16 +16,68 @@ impl<T,E> Result<T,E> where AE: From<E> {
   }
 }
 
+#[derive(Error,Debug)]
+pub enum ReadLimitedError {
+  #[error("maximum size {limit} exceeded")]
+  Truncated { sofar: Box<[u8]>, limit: usize },
+
+  #[error("HTTP error {0}")]
+  Hyper(#[from] hyper::Error),
+}
+
+impl ReadLimitedError {
+  pub fn discard_data(&mut self) { match self {
+    ReadLimitedError::Truncated { sofar,.. } => { mem::take(sofar); },
+    _ => { },
+  } }
+}
+#[ext(pub)]
+impl<T> Result<T,ReadLimitedError> {
+  fn discard_data(self) -> Self {
+    self.map_err(|mut e| { e.discard_data(); e })
+  }
+}
+
+#[throws(ReadLimitedError)]
+pub async fn read_limited_bytes<S>(limit: usize, initial: Box<[u8]>,
+                                   capacity: usize,
+                                   stream: &mut S) -> Box<[u8]>
+where S: futures::Stream<Item=Result<hyper::body::Bytes,hyper::Error>>
+         + Debug + Unpin,
+      // we also require that the Stream is cancellation-safe
+{
+  let mut accum = initial.into_vec();
+  let capacity = min(limit, capacity);
+  if capacity > accum.len() { accum.reserve(capacity - accum.len()); }
+  while let Some(item) = stream.next().await {
+    let b = item?;
+    accum.extend(b);
+    if accum.len() > limit {
+      throw!(ReadLimitedError::Truncated { limit, sofar: accum.into() })
+    }
+  }
+  accum.into()
+}
+
+pub fn time_t_now() -> u64 {
+  SystemTime::now()
+    .duration_since(UNIX_EPOCH)
+    .unwrap_or_else(|_| Duration::default()) // clock is being weird
+    .as_secs()
+}
+
 use sha2::Digest as _;
 
-type HmacH = sha2::Sha256;
-const HMAC_L: usize = 32;
+pub type HmacH = sha2::Sha256;
+pub const HMAC_B: usize = 64;
+pub const HMAC_L: usize = 32;
 
-fn token_hmac(key: &[u8], message: &[u8]) -> [u8; HMAC_L] {
+pub fn token_hmac(key: &[u8], message: &[u8]) -> [u8; HMAC_L] {
   let key = {
-    let mut padded = [0; HMAC_L];
-    if key.len() > HMAC_L {
-      padded = HmacH::digest(key).into();
+    let mut padded = [0; HMAC_B];
+    if key.len() > padded.len() {
+      let digest: [u8; HMAC_L] = HmacH::digest(key).into();
+      padded[0..HMAC_L].copy_from_slice(&digest);
     } else {
       padded[0.. key.len()].copy_from_slice(key);
     }
@@ -34,6 +86,8 @@ fn token_hmac(key: &[u8], message: &[u8]) -> [u8; HMAC_L] {
   let mut ikey = key;  for k in &mut ikey { *k ^= 0x36; }
   let mut okey = key;  for k in &mut okey { *k ^= 0x5C; }
 
+  //dbg!(DumpHex(&key), DumpHex(message), DumpHex(&ikey), DumpHex(&okey));
+
   let h1 = HmacH::new()
     .chain(&ikey)
     .chain(message)
@@ -47,6 +101,7 @@ fn token_hmac(key: &[u8], message: &[u8]) -> [u8; HMAC_L] {
 
 #[test]
 fn hmac_test_vectors(){
+  // C&P from RFC 4231
   let vectors = r#"
    Key =          0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b
                   0b0b0b0b                          (20 bytes)
@@ -129,7 +184,7 @@ fn hmac_test_vectors(){
 "#;
   let vectors = regex_replace_all!{
     r#"\(.*\)"#,
-    vectors,
+    vectors.trim_end(),
     |_| "",
   };
   let vectors = regex_replace_all!{