chiark / gitweb /
Improved first couple of client operations.
authorSimon Tatham <anakin@pobox.com>
Thu, 28 Dec 2023 16:49:06 +0000 (16:49 +0000)
committerSimon Tatham <anakin@pobox.com>
Thu, 28 Dec 2023 16:49:06 +0000 (16:49 +0000)
Now we check errors sensibly rather than via cavalier .unwrap(); we
cache accounts as well as statuses per id; and we verify the returned
id against the one we asked for (to make it not an obvious attack
avenue to file things under the returned id).

src/client.rs

index 2c91b868091cf3e2cb969c8fa282d01643abc5b9..ac38593416fe821667700aa6897c309d3add8179 100644 (file)
@@ -1,3 +1,4 @@
+use reqwest::Url;
 use std::collections::HashMap;
 
 use super::auth::{AuthConfig,AuthError};
@@ -6,43 +7,106 @@ use super::types::*;
 pub struct Client {
     auth: AuthConfig,
     client: reqwest::blocking::Client,
+    accounts: HashMap<String, Account>,
     statuses: HashMap<String, Status>,
 }
 
+pub enum ClientError {
+    InternalError(String), // message
+    UrlParseError(String, String), // url, message
+    UrlError(String, String), // url, message
+}
+
+impl From<reqwest::Error> for ClientError {
+    fn from(err: reqwest::Error) -> Self {
+        match err.url() {
+            Some(url) => ClientError::UrlError(
+                url.to_string(), err.to_string()),
+            None => ClientError::InternalError(err.to_string()),
+        }
+    }
+}
+
 impl Client {
     pub fn new() -> Result<Self, AuthError> {
         Ok(Client {
             auth: AuthConfig::load()?,
             client: reqwest::blocking::Client::new(),
+            accounts: HashMap::new(),
             statuses: HashMap::new(),
         })
     }
 
     fn api_request(&self, method: reqwest::Method, url_suffix: &str)
-               -> reqwest::blocking::RequestBuilder {
-        let url = reqwest::Url::parse(
-            &(self.auth.instance_url.clone() + "/api/v1/" + url_suffix));
-        // FIXME: add params to that
+               -> Result<(String, reqwest::blocking::RequestBuilder),
+                         ClientError> {
+        let urlstr = self.auth.instance_url.clone() + "/api/v1/" + url_suffix;
+        let url = match Url::parse(&urlstr) {
+            Ok(url) => Ok(url),
+            Err(e) => Err(ClientError::UrlParseError(
+               urlstr.clone(), e.to_string())),
+        }?;
 
-        let url = url.unwrap(); // FIXME: handle url::parser::ParseError
+        Ok((urlstr, self.client.request(method, url)
+            .bearer_auth(&self.auth.user_token)))
+    }
 
-        self.client.request(method, url)
-            .bearer_auth(&self.auth.user_token)
+    pub fn cache_account(&mut self, ac: &Account) {
+        self.accounts.insert(ac.id.to_string(), ac.clone());
     }
 
-    pub fn status_by_id(&mut self, id: &str) -> Option<Status> {
+    pub fn cache_status(&mut self, st: &Status) {
+        self.cache_account(&st.account);
+        self.statuses.insert(st.id.to_string(), st.clone());
+    }
+
+    pub fn account_by_id(&mut self, id: &str) -> Result<Account, ClientError> {
+        if let Some(st) = self.accounts.get(id) {
+            return Ok(st.clone());
+        }
+
+        let (url, req) = self.api_request(reqwest::Method::GET,
+                                          &("accounts/".to_owned() + id))?;
+        let body = req.send()?.text()?;
+        let ac: Account = match serde_json::from_str(&body) {
+            Ok(ac) => Ok(ac),
+            Err(e) => Err(ClientError::UrlError(url.clone(), e.to_string())),
+        }?;
+        if ac.id != id {
+            return Err(ClientError::UrlError(
+                url.clone(), format!("request returned wrong account id {}",
+                                     &ac.id)));
+        }
+        self.accounts.insert(id.to_string(), ac.clone());
+        Ok(ac)
+    }
+
+    pub fn status_by_id(&mut self, id: &str) -> Result<Status, ClientError> {
         if let Some(st) = self.statuses.get(id) {
-            return Some(st.clone());
+            let mut st = st.clone();
+            if let Some(ac) = self.accounts.get(&st.account.id) {
+                // Update the account details with the latest version
+                // we had cached
+                st.account = ac.clone();
+            }
+            return Ok(st);
         }
 
-        let req = self.api_request(reqwest::Method::GET,
-                                   &("statuses/".to_owned() + id));
-        // FIXME: if this goes wrong, log it
-        let body = req.send().unwrap().text().unwrap();
-        dbg!(&body);
-        let st: Status = serde_json::from_str(&body).unwrap();
-        self.statuses.insert(id.to_string(), st);
-        Some(self.statuses.get(id).unwrap().clone())
+        let (url, req) = self.api_request(reqwest::Method::GET,
+                                          &("statuses/".to_owned() + id))?;
+        let body = req.send()?.text()?;
+        let st: Status = match serde_json::from_str(&body) {
+            Ok(st) => Ok(st),
+            Err(e) => Err(ClientError::UrlError(url.clone(), e.to_string())),
+        }?;
+        if st.id != id {
+            return Err(ClientError::UrlError(
+                url.clone(), format!("request returned wrong status id {}",
+                                     &st.id)));
+        }
+        self.accounts.insert(id.to_string(), st.account.clone());
+        self.statuses.insert(id.to_string(), st.clone());
+        Ok(st)
     }
 }