chiark / gitweb /
It works a bit more, so that's probably progress.
[tripe-android] / peers.scala
1 /* -*-scala-*-
2  *
3  * The database of known peers
4  *
5  * (c) 2018 Straylight/Edgeware
6  */
7
8 /*----- Licensing notice --------------------------------------------------*
9  *
10  * This file is part of the Trivial IP Encryption (TrIPE) Android app.
11  *
12  * TrIPE is free software: you can redistribute it and/or modify it under
13  * the terms of the GNU General Public License as published by the Free
14  * Software Foundation; either version 3 of the License, or (at your
15  * option) any later version.
16  *
17  * TrIPE is distributed in the hope that it will be useful, but WITHOUT
18  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
19  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
20  * for more details.
21  *
22  * You should have received a copy of the GNU General Public License
23  * along with TrIPE.  If not, see <https://www.gnu.org/licenses/>.
24  */
25
26 package uk.org.distorted.tripe; package object peers {
27
28 /*----- Imports -----------------------------------------------------------*/
29
30 import java.io.{BufferedReader, File, FileReader, Reader};
31 import java.net.{InetAddress, Inet4Address, Inet6Address,
32                  UnknownHostException};
33
34 import scala.collection.mutable.{HashMap, HashSet};
35 import scala.concurrent.Channel;
36 import scala.util.control.Breaks;
37 import scala.util.matching.Regex;
38
39 /*----- Handy regular expressions -----------------------------------------*/
40
41 private final val RX_COMMENT = """(?x) ^ \s* (?: [;\#] .* )? $""".r;
42 private final val RX_GRPHDR = """(?x) ^ \s* \[ (.*) \] \s* $""".r;
43 private final val RX_ASSGN = """(?x) ^
44         ([^\s:=] (?: [^:=]* [^\s:=])?)
45         \s* [:=] \s*
46         (| [^\s\#;]\S* (?: \s+ [^\s\#;]\S*)*)
47         (?: \s+ (?: [;\#].*)? )? $""".r;
48 private final val RX_CONT = """(?x) ^ \s+
49         (| [^\s\#;]\S* (?: \s+ [^\s\#;]\S*)*)
50         (?: \s+ (?: [;\#].*)? )? $""".r;
51 private final val RX_REF = """(?x) \$ \( ([^)]+) \)""".r;
52 private final val RX_RESOLVE = """(?x) \$ ([46*]*) \[ ([^\]]+) \]""".r;
53 private final val RX_PARENT = """(?x) [^\s,]+""".r
54
55 /*----- Name resolution ---------------------------------------------------*/
56
57 private object BulkResolver {
58   private val BREAK = new Breaks;
59 }
60
61 private class BulkResolver(val nthreads: Int = 8) {
62   import BulkResolver.BREAK.{breakable, break};
63
64   class Host(val name: String) {
65     var a4, a6: Seq[InetAddress] = Seq.empty;
66
67     def addaddr(a: InetAddress) { a match {
68       case _: Inet4Address => a4 +:= a;
69       case _: Inet6Address => a6 +:= a;
70       case _ => ();
71     } }
72
73     def get(flags: String): Seq[InetAddress] = {
74       var wanta4, wanta6, any, all = false;
75       var b = Seq.newBuilder[InetAddress];
76       for (ch <- flags) ch match {
77         case '*' => all = true;
78         case '4' => wanta4 = true; any = true;
79         case '6' => wanta6 = true; any = true;
80         case _ => ???
81       }
82       if (!any) { wanta4 = true; wanta6 = true; }
83       if (wanta4) b ++= a4;
84       if (wanta6) b ++= a6;
85       (all, b.result) match {
86         case (true, aa) => aa
87         case (false, aa@(Nil | Seq(_))) => aa
88         case (false, Seq(a, _*)) => Seq(a)
89       }
90     }
91   }
92
93   val ch = new Channel[Host];
94   val map = HashMap[String, Host]();
95   var preparing = true;
96
97   val workers = Array.tabulate(nthreads) { i =>
98     thread(s"resolver worker #$i") {
99       loopUnit { exit =>
100         val host = ch.read; if (host == null) exit;
101 println(s";; ${Thread.currentThread.getName} resolving `${host.name}'");
102         try {
103           for (a <- InetAddress.getAllByName(host.name)) host.addaddr(a);
104         } catch { case e: UnknownHostException => () }
105       }
106 println(s";; ${Thread.currentThread.getName} done'");
107       ch.write(null);
108     }
109   }
110
111   def prepare(name: String) {
112 println(s";; prepare host `$name'");
113     assert(preparing);
114     if (!(map contains name)) {
115       val host = new Host(name);
116       map(name) = host;
117       ch.write(host);
118     }
119   }
120
121   def finish() {
122     assert(preparing);
123     preparing = false;
124     ch.write(null);
125     for (t <- workers) t.join();
126   }
127
128   def resolve(name: String, flags: String): Seq[InetAddress] =
129     map(name).get(flags);
130 }
131
132 /*----- The peer configuration --------------------------------------------*/
133
134 def fmtpath(path: Seq[String]) =
135   path.reverse map { i => s"`$i'" } mkString " -> ";
136
137 class ConfigSyntaxError(val file: File, val lno: Int, val msg: String)
138         extends Exception {
139   override def getMessage(): String = s"$file:$lno: $msg";
140 }
141
142 class MissingConfigSection(val sect: String) extends Exception {
143   override def getMessage(): String =
144     s"missing configuration section `$sect'";
145 }
146
147 class MissingConfigItem(val sect: String, val key: String,
148                         val path: Seq[(String)]) extends Exception {
149   override def getMessage(): String = {
150     val msg = s"missing configuration item `$key' in section `$sect'";
151     if (path == Nil) msg
152     else msg + s" (wanted while expanding ${fmtpath(path)})"
153   }
154 }
155
156 class AmbiguousConfig(val key: String,
157                       val v0: String, val p0: Seq[String],
158                       val v1: String, val p1: Seq[String])
159         extends Exception {
160   override def getMessage(): String =
161     s"ambiguous answer resolving key `$key': " +
162     s"path ${fmtpath(p0)} yields `$v0', but ${fmtpath(p1)} yields `$v1'";
163 }
164
165 class ConfigCycle(val key: String, path: Seq[String]) extends Exception {
166   override def getMessage(): String =
167     s"found a cycle ${fmtpath(path)} looking up key `$key'";
168 }
169
170 class NoHostAddresses(val sect: String, val key: String, val host: String)
171         extends Exception {
172   override def getMessage(): String =
173     s"no addresses found for `$host' (key `$key' in section `$sect')";
174 }
175
176 private sealed abstract class ConfigCacheEntry;
177 private case object StillLooking extends ConfigCacheEntry;
178 private case object NotFound extends ConfigCacheEntry;
179 private case class Found(value: String, path: Seq[String])
180         extends ConfigCacheEntry;
181
182 class Config { conf =>
183
184   class Section private(val name: String) {
185     private val itemmap = HashMap[String, String]();
186     private[this] val cache = HashMap[String, ConfigCacheEntry]();
187
188     override def toString: String = s"${getClass.getName}($name)";
189
190     def parents: Seq[Section] =
191       (itemmap.get("@inherit")
192        map { pp => (RX_PARENT.findAllIn(pp) map { conf.section _ }).toList }
193        getOrElse Nil);
194
195     private def get_internal(key: String, path: Seq[String] = Nil):
196               Option[(String, Seq[String])] = {
197       val incpath = name +: path;
198
199       for (r <- cache.get(key)) r match {
200         case StillLooking => throw new ConfigCycle(key, incpath)
201         case NotFound => return None
202         case Found(v, p) => return Some((v, p ++ path));
203       }
204
205       for (v <- itemmap.get(key)) {
206         cache(key) = Found(v, Seq(name));
207         return Some((v, incpath));
208       }
209
210       cache(key) = StillLooking;
211
212       ((None: Option[(String, Seq[String])]) /: parents) { (st, parent) =>
213         parent.get_internal(key, incpath) match {
214           case None => st;
215           case newst@Some((v, p)) => st match {
216             case None => newst
217             case Some((vv, _)) if v == vv => st
218             case Some((vv, pp)) =>
219               throw new AmbiguousConfig(key, v, p, vv, pp)
220           }
221         }
222       } match {
223         case None => cache(key) = NotFound; None
224         case Some((v, p)) =>
225           cache(key) = Found(v, p dropRight path.length);
226           Some((v, p))
227       }
228     }
229
230     def get(key: String, resolve: Boolean = true,
231             path: Seq[String] = Nil): String = {
232       val v0 = key match {
233         case "name" => itemmap.getOrElse("name", name)
234         case _ => get_internal(key).
235           getOrElse(throw new MissingConfigItem(name, key, path))._1
236       }
237       expand(key, v0, resolve, path)
238     }
239
240     private def expand(key: String, value: String, resolve: Boolean,
241                        path: Seq[String]): String = {
242       val v1 = RX_REF.replaceAllIn(value, { m =>
243         Regex.quoteReplacement(get(m.group(1), resolve, path))
244       });
245       val v2 = if (!resolve) v1
246                else RX_RESOLVE.replaceAllIn(v1, { m =>
247                  resolver.resolve(m.group(2), m.group(1)) match {
248                    case Nil =>
249                      throw new NoHostAddresses(name, key, m.group(2));
250                    case addrs =>
251                      Regex.quoteReplacement((addrs map { _.getHostAddress }).
252                                             mkString(" "));
253                  }
254                })
255       v2
256     }
257
258     def items: Seq[String] = {
259       val b = Seq.newBuilder[String];
260       val seen = HashSet[String]();
261       val visiting = HashSet[String](name);
262       var stack = List(this);
263
264       while (stack != Nil) {
265         val sect = stack.head; stack = stack.tail;
266         for (k <- sect.itemmap.keys)
267           if (!(seen contains k)) { b += k; seen += k; }
268         for (p <- sect.parents)
269           if (!(visiting contains p.name))
270             { stack ::= p; visiting += p.name; }
271       }
272       b.result
273     }
274   }
275
276   private[this] val sectmap = new HashMap[String, Section];
277   def sections: Iterator[Section] = sectmap.values.iterator;
278   def section(name: String): Section =
279     sectmap.getOrElse(name, throw new MissingConfigSection(name));
280
281   private[this] val resolver = new BulkResolver;
282
283   private[this] def parseFile(path: File): this.type = {
284 println(s";; parse ${path.getPath}");
285     withCleaner { clean =>
286       val in = new FileReader(path); clean { in.close(); }
287
288       val lno = 1;
289       val b = new StringBuilder;
290       var key: String = null;
291       var sect: Section = null;
292       def flush() {
293         if (key != null) {
294           sect.itemmap(key) = b.result;
295 println(s";; in `${sect.name}', set `$key' to `${b.result}'");
296           b.clear();
297           key = null;
298         }
299       }
300       for (line <- lines(in)) line match {
301         case RX_COMMENT() =>
302           ();
303         case RX_GRPHDR(grp) =>
304           flush();
305           sect = sectmap.getOrElseUpdate(grp, new Section(grp));
306         case RX_CONT(v) =>
307           if (key == null) {
308             throw new ConfigSyntaxError(
309               path, lno, "no config value to continue");
310           }
311           b += '\n'; b ++= v;
312         case RX_ASSGN(k, v) =>
313           if (sect == null) {
314             throw new ConfigSyntaxError(
315               path, lno, "no active section to update");
316           }
317           flush();
318           key = k; b ++= v;
319         case _ =>
320           throw new ConfigSyntaxError(path, lno, "incomprehensible line");
321       }
322       flush();
323     }
324     this
325   }
326
327   def parse(path: File): this.type = {
328     if (!path.isDirectory) parseFile(path);
329     else for {
330       f <- path.listFiles sortBy { _.getName };
331       name = f.getName;
332       if name.length > 0;
333       tail = name(name.length - 1);
334       if tail != '#' && tail != '~'
335     } parseFile(f);
336     this
337   }
338   def parse(path: String): this.type = parse(new File(path));
339
340   def analyse() {
341 println(";; resolving all...");
342     for ((_, sect) <- sectmap) {
343 println(s";; resolving in section `${sect.name}'...");
344       for (key <- sect.items) {
345 println(s";;    resolving in key `$key'...");
346         val mm = RX_RESOLVE.findAllIn(sect.get(key, false));
347         for (host <- mm) { resolver.prepare(mm.group(2)); }
348       }
349     }
350     resolver.finish();
351
352     def dumpsect(sect: Section) {
353       for (k <- sect.items.filterNot(_.startsWith("@")).sorted)
354         println(s";;    `$k' -> `${sect.get(k)}'")
355     }
356     for (sect <- sectmap.values.toSeq sortBy { _.name })
357       if (sect.name.startsWith("@")) ();
358       else if (sect.name.startsWith("$")) {
359         println(s";; special section `${sect.name}'");
360         dumpsect(sect);
361       } else {
362         println(s";; peer section `${sect.name}'");
363         dumpsect(sect);
364       }
365   }
366 }
367
368 /*----- That's all, folks -------------------------------------------------*/
369
370 }