chiark / gitweb /
config, wip macro, generetes some code
[hippotat.git] / src / config.rs
1 // Copyright 2021 Ian Jackson and contributors to Hippotat
2 // SPDX-License-Identifier: AGPL-3.0-or-later
3 // There is NO WARRANTY.
4
5 use crate::prelude::*;
6
7 use configparser::ini::Ini;
8
9 #[derive(StructOpt,Debug)]
10 pub struct Opts {
11   /// Top-level config file or directory
12   ///
13   /// Look for `main.cfg`, `config.d` and `secrets.d` here.
14   ///
15   /// Or if this is a file, just read that file.
16   #[structopt(long, default_value="/etc/hippotat")]
17   pub config: PathBuf,
18   
19   /// Additional config files or dirs, which can override the others
20   #[structopt(long, multiple=true, number_of_values=1)]
21   pub extra_config: Vec<PathBuf>,
22 }
23
24 pub struct CidrString(pub String);
25
26 #[derive(hippotat_macros::ResolveConfig)]
27 pub struct InstanceConfig {
28 /*
29   // Exceptional settings
30   #[special(special_name, SKL::ServerName)] pub server: String,
31   pub                                           secret: String, // xxx newytpe
32   #[special(special_ipif, SKL::Ordinary)]   pub ipif:   String,
33
34   // Capped settings:
35 */
36   #[limited]    pub max_batch_down:               u32,
37 /*
38   #[limited]    pub max_queue_time:               Duration,
39   #[limited]    pub http_timeout:                 Duration,
40   #[limited]    pub target_requests_outstanding:  u32,
41
42   // Ordinary settings:
43   pub addrs:                        Vec<IpAddr>,
44   pub vnetwork:                     Vec<CidrString>,
45   pub vaddr:                        Vec<IpAddr>,
46   pub vrelay:                       IpAddr,
47   pub port:                         u16,
48   pub mtu:                          u32,
49   pub ifname_server:                String,
50   pub ifname_client:                String,
51
52   // Ordinary settings, used by server only:
53   #[server]  pub max_clock_skew:               Duration,
54
55   // Ordinary settings, used by client only:
56   #[client]  pub http_timeout_grace:           Duration,
57   #[client]  pub max_requests_outstanding:     u32,
58   #[client]  pub max_batch_up:                 u32,
59   #[client]  pub http_retry:                   Duration,
60   #[client]  pub url:                          Uri,
61   #[client]  pub vroutes:                      Vec<CidrString>,
62 */
63 }
64
65 #[derive(Debug,Clone,Hash,Eq,PartialEq)]
66 pub enum SectionName {
67   Link(LinkName),
68   Client(ClientName),
69   Server(ServerName), // includes SERVER, which is slightly special
70   ServerLimit(ServerName),
71   GlobalLimit,
72   Common,
73   Default,
74 }
75 pub use SectionName as SN;
76
77 #[derive(Debug,Clone)]
78 struct RawVal { val: Option<String>, loc: Arc<PathBuf> }
79 type SectionMap = HashMap<String, RawVal>;
80
81 pub struct Config {
82   opts: Opts,
83 }
84
85 static OUTSIDE_SECTION: &str = "[";
86
87 #[derive(Default,Debug)]
88 struct Aggregate {
89   sections: HashMap<SectionName, SectionMap>,
90 }
91
92 type OkAnyway<'f,A> = &'f dyn Fn(ErrorKind) -> Option<A>;
93 #[ext]
94 impl<'f,A> OkAnyway<'f,A> {
95   fn ok<T>(self, r: &Result<T, io::Error>) -> Option<A> {
96     let e = r.as_ref().err()?;
97     let k = e.kind();
98     let a = self(k)?;
99     Some(a)
100   }
101 }
102
103 impl FromStr for SectionName {
104   type Err = AE;
105   #[throws(AE)]
106   fn from_str(s: &str) -> Self {
107     match s {
108       "COMMON" => return SN::Common,
109       "DEFAULT" => return SN::Default,
110       "LIMIT" => return SN::GlobalLimit,
111       _ => { }
112     };
113     if let Ok(n@ ServerName(_)) = s.parse() { return SN::Server(n) }
114     if let Ok(n@ ClientName(_)) = s.parse() { return SN::Client(n) }
115     let (server, client) = s.split_ascii_whitespace().collect_tuple()
116       .ok_or_else(|| anyhow!(
117         "bad section name {:?} \
118          (must be COMMON, DEFAULT, <server>, <client>, or <server> <client>",
119         s
120       ))?;
121     let server = server.parse().context("server name in link section name")?;
122     if client == "LIMIT" { return SN::ServerLimit(server) }
123     let client = client.parse().context("client name in link section name")?;
124     SN::Link(LinkName { server, client })
125   }
126 }
127
128 impl Aggregate {
129   #[throws(AE)] // AE does not include path
130   fn read_file<A>(&mut self, path: &Path, anyway: OkAnyway<A>) -> Option<A>
131   {
132     let f = fs::File::open(path);
133     if let Some(anyway) = anyway.ok(&f) { return Some(anyway) }
134     let mut f = f.context("open")?;
135
136     let mut s = String::new();
137     let y = f.read_to_string(&mut s);
138     if let Some(anyway) = anyway.ok(&y) { return Some(anyway) }
139     y.context("read")?;
140
141     let mut ini = Ini::new_cs();
142     ini.set_default_section(OUTSIDE_SECTION);
143     ini.read(s).map_err(|e| anyhow!("{}", e)).context("parse as INI")?;
144     let map = mem::take(ini.get_mut_map());
145     if map.get(OUTSIDE_SECTION).is_some() {
146       throw!(anyhow!("INI file contains settings outside a section"));
147     }
148
149     let loc = Arc::new(path.to_owned());
150
151     for (sn, vars) in map {
152       dbg!( InstanceConfig::FIELDS );// check xxx vars are in fields
153
154       let sn = sn.parse().dcontext(&sn)?;
155         self.sections.entry(sn)
156         .or_default()
157         .extend(
158           vars.into_iter()
159             .map(|(k,val)| {
160               (k.replace('-',"_"),
161                RawVal { val, loc: loc.clone() })
162             })
163         );
164     }
165     None
166   }
167
168   #[throws(AE)] // AE includes path
169   fn read_dir_d<A>(&mut self, path: &Path, anyway: OkAnyway<A>) -> Option<A>
170   {
171     let dir = fs::read_dir(path);
172     if let Some(anyway) = anyway.ok(&dir) { return Some(anyway) }
173     let dir = dir.context("open directory").dcontext(path)?;
174     for ent in dir {
175       let ent = ent.context("read directory").dcontext(path)?;
176       let leaf = ent.file_name();
177       let leaf = leaf.to_str();
178       let leaf = if let Some(leaf) = leaf { leaf } else { continue }; //utf8?
179       if leaf.len() == 0 { continue }
180       if ! leaf.chars().all(
181         |c| c=='-' || c=='_' || c.is_ascii_alphanumeric()
182       ) { continue }
183
184       // OK we want this one
185       let ent = ent.path();
186       self.read_file(&ent, &|_| None::<Void>).dcontext(&ent)?;
187     }
188     None
189   }
190
191   #[throws(AE)] // AE includes everything
192   fn read_toplevel(&mut self, toplevel: &Path) {
193     enum Anyway { None, Dir }
194     match self.read_file(toplevel, &|k| match k {
195       EK::NotFound => Some(Anyway::None),
196       EK::IsADirectory => Some(Anyway::Dir),
197       _ => None,
198     })
199       .dcontext(toplevel).context("top-level config directory (or file)")?
200     {
201       None | Some(Anyway::None) => { },
202
203       Some(Anyway::Dir) => {
204         struct AnywayNone;
205         let anyway_none = |k| match k {
206           EK::NotFound => Some(AnywayNone),
207           _ => None,
208         };
209
210         let mk = |leaf: &str| {
211           [ toplevel, &PathBuf::from(leaf) ]
212             .iter().collect::<PathBuf>()
213         };
214
215         for &(try_main, desc) in &[
216           ("main.cfg", "main config file"),
217           ("master.cfg", "obsolete-named main config file"),
218         ] {
219           let main = mk(try_main);
220
221           match self.read_file(&main, &anyway_none)
222             .dcontext(main).context(desc)?
223           {
224             None => break,
225             Some(AnywayNone) => { },
226           }
227         }
228
229         for &(try_dir, desc) in &[
230           ("config.d", "per-link config directory"),
231           ("secrets.d", "per-link secrets directory"),
232         ] {
233           let dir = mk(try_dir);
234           match self.read_dir_d(&dir, &anyway_none).context(desc)? {
235             None => { },
236             Some(AnywayNone) => { },
237           }
238         }
239       }
240     }
241   }
242
243   #[throws(AE)] // AE includes extra, but does that this is extra
244   fn read_extra(&mut self, extra: &Path) {
245     struct AnywayDir;
246
247     match self.read_file(extra, &|k| match k {
248       EK::IsADirectory => Some(AnywayDir),
249       _ => None,
250     })
251       .dcontext(extra)?
252     {
253       None => return,
254       Some(AnywayDir) => {
255         self.read_dir_d(extra, &|_| None::<Void>)?;
256       }
257     }
258
259   }
260 }
261
262 enum LinkEnd { Server, Client }
263
264 struct ResolveContext<'c> {
265   agg: &'c Aggregate,
266   link: &'c LinkName,
267   end: LinkEnd,
268   all_sections: Vec<SectionName>,
269 }
270
271 trait Parseable: Sized {
272   fn parse(s: &Option<String>) -> Result<Self, AE>;
273 }
274
275 impl Parseable for Duration {
276   #[throws(AE)]
277   fn parse(s: &Option<String>) -> Duration {
278     let s = s.as_ref().ok_or_else(|| anyhow!("value needed"))?;
279     if let Ok(u64) = s.parse() { return Duration::from_secs(u64) }
280     throw!(anyhow!("xxx parse with humantime"))
281   }
282 }
283 macro_rules! parseable_from_str { ($t:ty) => {
284   impl Parseable for $t {
285     #[throws(AE)]
286     fn parse(s: &Option<String>) -> $t {
287       let s = s.as_ref().ok_or_else(|| anyhow!("value needed"))?;
288       s.parse()?
289     }
290   }
291 } }
292 parseable_from_str!{u32}
293
294 #[derive(Debug,Copy,Clone)]
295 enum SectionKindList {
296   Ordinary,
297   Limited,
298   Limits,
299   ClientAgnostic,
300   ServerName,
301 }
302 use SectionKindList as SKL;
303
304 impl SectionKindList {
305   fn contains(self, s: &SectionName) -> bool {
306     match self {
307       SKL::Ordinary       => matches!(s, SN::Link(_)
308                                        | SN::Client(_)
309                                        | SN::Server(_)
310                                        | SN::Common),
311
312       SKL::Limits         => matches!(s, SN::ServerLimit(_)
313                                        | SN::GlobalLimit),
314
315       SKL::ClientAgnostic => matches!(s, SN::Common
316                                        | SN::Server(_)),
317
318       SKL::Limited        => SKL::Ordinary.contains(s)
319                            | SKL::Limits  .contains(s),
320
321       SKL::ServerName     => matches!(s, SN::Common)
322                            | matches!(s, SN::Server(ServerName(name))
323                                          if name == "SERVER"),
324     }
325   }
326 }
327
328 impl<'c> ResolveContext<'c> {
329   fn first_of_raw(&self, key: &'static str, sections: SectionKindList)
330                   -> Option<&'c RawVal> {
331     for section in self.all_sections.iter()
332       .filter(|s| sections.contains(s))
333     {
334       if let Some(raw) = self.agg.sections
335         .get(section)
336         .and_then(|vars: &SectionMap| vars.get(key))
337       {
338         return Some(raw)
339       }
340     }
341     None
342   }
343
344   #[throws(AE)]
345   fn first_of<T>(&self, key: &'static str, sections: SectionKindList)
346                  -> Option<T>
347   where T: Parseable
348   {
349     match self.first_of_raw(key, sections) {
350       None => None,
351       Some(raw) => Some({
352         Parseable::parse(&raw.val)
353           .context(key)
354 //          .with_context(|| format!(r#"in section "{}""#, &section))
355           .dcontext(&raw.loc)?
356       }),
357     }
358   }
359
360   #[throws(AE)]
361   pub fn ordinary<T>(&self, key: &'static str) -> T
362   where T: Parseable + Default
363   {
364     self.first_of(key, SKL::Ordinary)?
365       .unwrap_or_default()
366   }
367
368   #[throws(AE)]
369   pub fn limited<T>(&self, key: &'static str) -> T
370   where T: Parseable + Default + Ord
371   {
372     let val = self.ordinary(key)?;
373     if let Some(limit) = self.first_of(key, SKL::Limits)? {
374       min(val, limit)
375     } else {
376       val
377     }
378   }
379
380   #[throws(AE)]
381   pub fn client<T>(&self, key: &'static str) -> T
382   where T: Parseable + Default {
383     match self.end {
384       LinkEnd::Client => self.ordinary(key)?,
385       LinkEnd::Server => default(),
386     }
387   }
388   #[throws(AE)]
389   pub fn server<T>(&self, key: &'static str) -> T
390   where T: Parseable + Default {
391     match self.end {
392       LinkEnd::Server => self.ordinary(key)?,
393       LinkEnd::Client => default(),
394     }
395   }
396
397   #[throws(AE)]
398   pub fn special_ipif<T>(&self, key: &'static str) -> T
399   where T: Parseable + Default
400   {
401     match self.end {
402       LinkEnd::Client => self.ordinary(key)?,
403       LinkEnd::Server => {
404         self.first_of(key, SKL::ClientAgnostic)?
405           .unwrap_or_default()
406       },
407     }
408   }
409 }
410
411 /*
412 impl<'c> ResolveContext<'c> {
413   #[throws(AE)]
414   fn resolve_instance(&self) -> InstanceConfig {
415     InstanceConfig {
416       max_batch_down: self.limited::<u32>("max_batch_down")?,
417     }
418   }
419 }
420 */
421
422 #[throws(AE)]
423 pub fn read() {
424   let opts = config::Opts::from_args();
425
426   let agg = (||{
427     let mut agg = Aggregate::default();
428
429     agg.read_toplevel(&opts.config)?;
430     for extra in &opts.extra_config {
431       agg.read_extra(extra).context("extra config")?;
432     }
433
434     eprintln!("GOT {:#?}", agg);
435
436     Ok::<_,AE>(agg)
437   })().context("read configuration")?;
438
439   let link = LinkName {
440     server: "fooxxx".parse().unwrap(),
441     client: "127.0.0.1".parse().unwrap(),
442   };
443
444   let rctx = ResolveContext {
445     agg: &agg,
446     link: &link,
447     end: LinkEnd::Server,
448     all_sections: vec![
449       SN::Link(link.clone()),
450       SN::Client(link.client.clone()),
451       SN::Server(link.server.clone()),
452       SN::Common,
453       SN::ServerLimit(link.server.clone()),
454       SN::GlobalLimit,
455     ],
456   };
457 }