chiark / gitweb /
Integrate the TrIPE server into the Java edifice.
authorMark Wooding <mdw@distorted.org.uk>
Sat, 16 Jun 2018 18:32:22 +0000 (19:32 +0100)
committerMark Wooding <mdw@distorted.org.uk>
Sat, 16 Jun 2018 18:32:22 +0000 (19:32 +0100)
And probably other things too.  We're still in broad brushstrokes mode
here.

Makefile
admin.scala
jni.c
keys.scala
sys.scala

index 6e5ca0e..5edb29a 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -40,7 +40,7 @@ CC                     = gcc
 CFLAGS                  = -O2 -g -Wall -pedantic -Werror
 
 ## Native linker.
-LD                      = gcc
+LD                      = gcc -Wl,-z,defs
 LDFLAGS.so              = -shared
 
 ## External `pkg-config' packages required.
@@ -107,14 +107,30 @@ V_AT_0                     = @
 ###--------------------------------------------------------------------------
 ### External native packages.
 
-PKGS_CFLAGS            := $(foreach p,$(PKGS),$(shell pkg-config --cflags $p))
-PKGS_LIBS              := $(foreach p,$(PKGS),$(shell pkg-config --libs $p))
+EXTPREFIX               = $(abs_builddir)/$(OUTDIR)/inst
+
+join-paths              = $(if $(filter /%,$2),$2,$1/$2)
+ext-srcdir              = $(or $($1_SRCDIR),../$1)
+
+PKG_CONFIG              = PKG_CONFIG_LIBDIR=$(OUTDIR)/inst/lib/pkgconfig \
+                               pkg-config --static
+
+PKGS_CFLAGS            := $(foreach p,$(PKGS),$(shell $(PKG_CONFIG) --cflags $p))
+PKGS_LIBS              := $(foreach p,$(PKGS),$(shell $(PKG_CONFIG) --libs $p))
 
 ALL_CFLAGS              = $(CFLAGS) -fPIC \
                                $(addprefix -I,$(JNI_INCLUDES)) \
+                               -I$(OUTDIR)/inst/include \
+                               -I$(call ext-srcdir,tripe)/common \
+                               -I$(call ext-srcdir,tripe)/priv \
+                               -I$(call ext-srcdir,tripe)/server \
+                               -I$(OUTDIR)/build/tripe/config \
                                $(PKGS_CFLAGS)
 
-LIBS                    = $(PKGS_LIBS)
+LIBS                    = $(OUTDIR)/build/tripe/server/libtripe.a \
+                               $(OUTDIR)/build/tripe/priv/libpriv.a \
+                               $(OUTDIR)/build/tripe/common/libcommon.a \
+                               -L$(OUTDIR)/inst/lib $(PKGS_LIBS)
 
 ###--------------------------------------------------------------------------
 ### Various other tweaks and overrides.
@@ -177,14 +193,16 @@ $(OUTDIR)/%.class-stamp: %.scala
 ###--------------------------------------------------------------------------
 ### Native-code libraries.
 
-SHLIBS                 += toy
-toy_SOURCES             = jni.c
+SHLIBS                 += tripe
+tripe_SOURCES           = jni.c
 
 shlibfile               = $(patsubst %,$(OUTDIR)/lib%.so,$1)
 SHLIBFILES              = $(call shlibfile,$(SHLIBS))
 TARGETS                        += $(SHLIBFILES)
 ALL_SOURCES            += $(foreach l,$(SHLIBS),$($l_SOURCES))
 
+$(call objects,$(tripe_SOURCES),.o): $(call stamps,ext,tripe)
+
 $(SHLIBFILES): $(OUTDIR)/lib%.so: $$(call objects,$$($$*_SOURCES),.o)
        $(call v_tag,LD)$(LD) $(LDFLAGS.so) -o$@ $^ $(LIBS)
 
@@ -203,7 +221,6 @@ CLASSES                     += tar:util
 CLASSES                        += progress:sys,util
 CLASSES                        += keys:progress,tar,sys,util
 CLASSES                        += terminal:progress,sys,util
-CLASSES                        += main:sys
 
 ## Machinery for parsing the `CLASSES' list.
 COMMA                   = ,
@@ -224,11 +241,6 @@ DISTFILES          += $(foreach c,$(CLASSES),\
 ###--------------------------------------------------------------------------
 ### External packages.
 
-EXTPREFIX               = $(abs_builddir)/$(OUTDIR)/inst
-
-join-paths              = $(if $(filter /%,$2),$2,$1/$2)
-ext-srcdir              = $(or $($1_SRCDIR),../$1)
-
 EXTERNALS              += adns
 adns_CONFIG             = --disable-dynamic
 
index fab8305..52a2912 100644 (file)
@@ -27,7 +27,7 @@ package uk.org.distorted.tripe; package object admin {
 
 /*----- Imports -----------------------------------------------------------*/
 
-import java.io.{BufferedReader, Reader, Writer};
+import java.io.{BufferedReader, InputStreamReader, OutputStreamWriter};
 import java.util.concurrent.locks.{Condition, ReentrantLock => Lock};
 
 import scala.collection.mutable.{HashMap, Publisher};
@@ -35,6 +35,7 @@ import scala.concurrent.Channel;
 import scala.util.control.Breaks;
 
 import Implicits._;
+import sys.{serverInput, serverOutput};
 
 /*----- Classification of server messages ---------------------------------*/
 
@@ -77,8 +78,7 @@ class CommandFailed(val msg: Seq[String]) extends Exception {
 
 class ConnectionLostException extends Exception;
 
-class Connection(val in: Reader, val out: Writer)
-       extends Publisher[AsyncMessage]
+object Connection extends Publisher[AsyncMessage]
 {
   /* Synchronization.
    *
@@ -87,12 +87,15 @@ class Connection(val in: Reader, val out: Writer)
    * hold the `Connection' lock before locking any individual `Job' objects.
    */
 
-  var livep: Boolean = true;           // Is this connection still alive?
-  var fgjob: Option[this.Job] = None;  // Foreground job, if there is one.
-  val jobmap = new HashMap[String, this.Job]; // Maps tags to extant jobs.
-  var bgseq = 0;                       // Next background job tag.
+  private var livep: Boolean = true;   // Is this connection still alive?
+  private var fgjob: Option[this.Job] = None; // Foreground job, if there is one.
+  private val jobmap = new HashMap[String, this.Job]; // Maps tags to extant jobs.
+  private var bgseq = 0;               // Next background job tag.
 
-  type Pub = Connection;
+  private val in = new BufferedReader(new InputStreamReader(serverInput));
+  private val out = new OutputStreamWriter(serverOutput);
+
+  type Pub = Connection.type;
 
   class Job extends Iterator[Seq[String]] {
     private[Connection] val ch = new Channel[JobMessage];
@@ -183,8 +186,6 @@ println(";; write command");
 
   def submit(toks: String*): this.Job = submit(false, toks: _*);
 
-  def close() { synchronized { out.close(); } }
-
   /* These two expect the connection lock to be held. */
   def foregroundJob: Job =
     fgjob.getOrElse { throw new ServerFailed("no foreground job"); }
@@ -267,7 +268,6 @@ println(s";; line: $line");
        }
       }
       publish(ConnectionLost);
-      in.close(); out.close();
     }
   }
 }
diff --git a/jni.c b/jni.c
index 9aa79f8..9ca3651 100644 (file)
--- a/jni.c
+++ b/jni.c
 #include <stdlib.h>
 #include <string.h>
 
-#include <jni.h>
-
-#include <sys/types.h>
+#include <dirent.h>
+#include <fcntl.h>
+#include <netdb.h>
+#include <unistd.h>
+#include <sys/ioctl.h>
 #include <sys/select.h>
 #include <sys/socket.h>
 #include <sys/stat.h>
 #include <sys/sysmacros.h>
+#include <sys/types.h>
 #include <sys/un.h>
-#include <fcntl.h>
-#include <unistd.h>
-#include <dirent.h>
+
+#include <jni.h>
+
+//#include <linux/if.h>
+#include <linux/if_tun.h>
 
 #include <mLib/align.h>
 #include <mLib/bits.h>
@@ -55,6 +60,9 @@
 
 #include <catacomb/ghash.h>
 
+#define TUN_INTERNALS
+#include <tripe.h>
+
 #undef sun
 
 /*----- Magic class names and similar -------------------------------------*/
 #define JNIFUNC(f) Java_uk_org_distorted_tripe_sys_package_00024_##f
 
 /* The little class for bundling up error codes. */
-#define ERRENTRY "uk/org/distorted/tripe/sys/package$ErrorEntry"
+#define ERRENTCLS "uk/org/distorted/tripe/sys/package$ErrorEntry"
+
+/* The `sys' package class. */
+#define SYSCLS "uk/org/distorted/tripe/sys/package"
+
+/* The server lock class. */
+#define LOCKCLS "uk/org/distorted/tripe/sys/package$ServerLock"
 
 /* The `stat' class. */
-#define STAT "uk/org/distorted/tripe/sys/package$FileInfo"
+#define STATCLS "uk/org/distorted/tripe/sys/package$FileInfo"
+
+/* Standard Java classes. */
+#define FDCLS "java/io/FileDescriptor"
+#define STRCLS "java/lang/String"
+#define RANDCLS "java/security/SecureRandom"
 
 /* Exception class names. */
 #define NULLERR "java/lang/NullPointerException"
 #define TYPEERR "uk/org/distorted/tripe/sys/package$NativeObjectTypeException"
 #define SYSERR "uk/org/distorted/tripe/sys/package$SystemError"
+#define NAMEERR "uk/org/distorted/tripe/sys/package$NameResolutionException"
+#define INITERR "uk/org/distorted/tripe/sys/package$InitializationException"
 #define ARGERR "java/lang/IllegalArgumentException"
+#define STERR "java/lang/IllegalStateException"
 #define BOUNDSERR "java/lang/IndexOutOfBoundsException"
 
-/*----- Miscellaneous utilities -------------------------------------------*/
+/*----- Essential state ---------------------------------------------------*/
 
-static void put_cstring(JNIEnv *jni, jbyteArray v, const char *p)
-  { if (p) (*jni)->ReleaseByteArrayElements(jni, v, (jbyte *)p, JNI_ABORT); }
+static JNIEnv *jni_tripe = 0;
+
+/*----- Miscellaneous utilities -------------------------------------------*/
 
 static void vexcept(JNIEnv *jni, const char *clsname,
                    const char *msg, va_list *ap)
@@ -164,6 +187,9 @@ static const char *get_cstring(JNIEnv *jni, jbyteArray v)
   return ((const char *)(*jni)->GetByteArrayElements(jni, v, 0));
 }
 
+static void put_cstring(JNIEnv *jni, jbyteArray v, const char *p)
+  { if (p) (*jni)->ReleaseByteArrayElements(jni, v, (jbyte *)p, JNI_ABORT); }
+
 static void vexcept_syserror(JNIEnv *jni, const char *clsname,
                             int err, const char *msg, va_list *ap)
 {
@@ -769,11 +795,11 @@ JNIEXPORT jobject JNIFUNC(errtab)(JNIEnv *jni, jobject cls)
   jobject e;
 
   eltcls =
-    (*jni)->FindClass(jni, ERRENTRY);
+    (*jni)->FindClass(jni, ERRENTCLS);
   assert(eltcls);
   v = (*jni)->NewObjectArray(jni, N(errtab), eltcls, 0); if (!v) return (0);
   init = (*jni)->GetMethodID(jni, eltcls, "<init>",
-                            "(Ljava/lang/String;I)V");
+                            "(L"STRCLS";I)V");
   assert(init);
 
   for (i = 0; i < N(errtab); i++) {
@@ -792,7 +818,7 @@ JNIEXPORT jobject JNIFUNC(strerror)(JNIEnv *jni, jobject cls, jint err)
 
 static void fdguts(JNIEnv *jni, jclass *cls, jfieldID *fid)
 {
-  *cls = (*jni)->FindClass(jni, "java/io/FileDescriptor"); assert(cls);
+  *cls = (*jni)->FindClass(jni, FDCLS); assert(cls);
   *fid = (*jni)->GetFieldID(jni, *cls, "fd", "I"); // OpenJDK
   if (!*fid) *fid = (*jni)->GetFieldID(jni, *cls, "descriptor", "I"); // Android
   assert(*fid);
@@ -1005,7 +1031,7 @@ static jobject xltstat(JNIEnv *jni, const struct stat *st)
   else if (S_ISLNK(st->st_mode)) modehack |= 0120000;
   else if (S_ISSOCK(st->st_mode)) modehack |= 0140000;
 
-  cls = (*jni)->FindClass(jni, STAT); assert(cls);
+  cls = (*jni)->FindClass(jni, STATCLS); assert(cls);
   init = (*jni)->GetMethodID(jni, cls, "<init>", "(IIJIIIIIIJIJJJJ)V");
   assert(init);
   return ((*jni)->NewObject(jni, cls, init,
@@ -1146,7 +1172,7 @@ struct trigger {
 static const struct native_type trigger_type =
        { "trigger", sizeof(struct trigger), 0x65ffd8b4 };
 
-JNIEXPORT wrapper JNICALL JNIFUNC(makeTrigger)(JNIEnv *jni, jobject cls)
+JNIEXPORT wrapper JNICALL JNIFUNC(make_1trigger)(JNIEnv *jni, jobject cls)
 {
   struct trigger trig;
   int fd[2];
@@ -1174,8 +1200,8 @@ end:
   return (ret);
 }
 
-JNIEXPORT void JNICALL JNIFUNC(destroyTrigger)(JNIEnv *jni, jobject cls,
-                                              wrapper wtrig)
+JNIEXPORT void JNICALL JNIFUNC(destroy_1trigger)(JNIEnv *jni, jobject cls,
+                                               wrapper wtrig)
 {
   struct trigger trig;
 
@@ -1185,8 +1211,8 @@ JNIEXPORT void JNICALL JNIFUNC(destroyTrigger)(JNIEnv *jni, jobject cls,
   update_wrapper(jni, &trigger_type, wtrig, &trig);
 }
 
-JNIEXPORT void JNICALL JNIFUNC(resetTrigger)(JNIEnv *jni, jobject cls,
-                                            wrapper wtrig)
+JNIEXPORT void JNICALL JNIFUNC(reset_1trigger)(JNIEnv *jni, jobject cls,
+                                             wrapper wtrig)
 {
   struct trigger trig;
   char buf[64];
@@ -1218,84 +1244,398 @@ JNIEXPORT void JNICALL JNIFUNC(trigger)(JNIEnv *jni, jobject cls,
     except_syserror(jni, SYSERR, errno, "failed to pull trigger");
 }
 
-/*----- A server connection, using a Unix-domain socket -------------------*/
+/*----- A tunnel supplied by Java -----------------------------------------*/
 
-struct conn {
-  struct native_base _base;
-  int fd;
-  unsigned f;
-#define CF_CLOSERD 1u
-#define CF_CLOSEWR 2u
-#define CF_CLOSEMASK (CF_CLOSERD | CF_CLOSEWR)
+struct tunnel {
+  const tunnel_ops *ops;
+  sel_file f;
+  struct peer *p;
 };
-static const struct native_type conn_type =
-       { "conn", sizeof(struct conn), 0xed030167 };
 
-JNIEXPORT wrapper JNICALL JNIFUNC(connect)(JNIEnv *jni, jobject cls,
-                                          jobject path, wrapper wtrig)
+static const struct tunnel_ops tun_java;
+
+static int t_init(void) { return (0); }
+
+static void t_read(int fd, unsigned mode, void *v)
 {
-  struct conn conn;
-  struct trigger trig;
-  struct sockaddr_un sun;
-  int rc, maxfd;
-  fd_set rfds, wfds;
-  const char *pathstr = 0;
-  int err;
-  socklen_t sz;
-  wrapper ret = 0;
-  int nb;
-  int fd = -1;
+  tunnel *t = v;
+  ssize_t n;
+  buf b;
 
-  if (unwrap(jni, &trig, &trigger_type, wtrig)) goto end;
-  pathstr = get_cstring(jni, path); if (!pathstr) goto end;
-  if (strlen(pathstr) >= sizeof(sun.sun_path)) {
-    except(jni, ARGERR,
-          "Unix-domain socket path `%s' too long", pathstr);
+  n = read(fd, buf_i, sizeof(buf_i));
+  if (n < 0) {
+    a_warn("TUN", "%s", p_ifname(t->p), "java",
+          "read-error", "?ERRNO", A_END);
+    return;
+  }
+  IF_TRACING(T_TUNNEL, {
+    trace(T_TUNNEL, "tun-java: packet arrived");
+    trace_block(T_PACKET, "tunnel: packet contents", buf_i, n);
+  })
+  buf_init(&b, buf_i, n);
+  p_tun(t->p, &b);
+}
+
+static tunnel *t_create(peer *p, int fd, char **ifn)
+{
+  JNIEnv *jni = jni_tripe;
+  tunnel *t = 0;
+  const char *name = p_name(p);
+  jbyteArray jname;
+  size_t n = strlen(p_name(p));
+  jclass cls, metacls;
+  jstring jclsname, jexcmsg;
+  const char *clsname, *excmsg;
+  jmethodID mid;
+  jthrowable exc;
+
+  assert(jni);
+
+  jname = wrap_cstring(jni, name);
+  cls = (*jni)->FindClass(jni, SYSCLS); assert(cls);
+  mid = (*jni)->GetStaticMethodID(jni, cls, "getTunnelFd", "([B)I");
+  assert(mid);
+  fd = (*jni)->CallStaticIntMethod(jni, cls, mid, jname);
+
+  exc = (*jni)->ExceptionOccurred(jni);
+  if (exc) {
+    cls = (*jni)->GetObjectClass(jni, exc);
+    metacls = (*jni)->GetObjectClass(jni, cls);
+    mid = (*jni)->GetMethodID(jni, metacls,
+                             "getName", "()L"STRCLS";");
+    assert(mid);
+    jclsname = (*jni)->CallObjectMethod(jni, cls, mid);
+    clsname = (*jni)->GetStringUTFChars(jni, jclsname, 0);
+    mid = (*jni)->GetMethodID(jni, cls,
+                             "getMessage", "()L"STRCLS";");
+    jexcmsg = (*jni)->CallObjectMethod(jni, exc, mid);
+    excmsg = (*jni)->GetStringUTFChars(jni, jexcmsg, 0);
+    a_warn("TUN", "-", "java", "get-tunnel-fd-failed",
+          "%s", clsname, "%s", excmsg, A_END);
+    (*jni)->ReleaseStringUTFChars(jni, jclsname, clsname);
+    (*jni)->ReleaseStringUTFChars(jni, jexcmsg, excmsg);
+    (*jni)->ExceptionClear(jni);
     goto end;
   }
 
-  INIT_NATIVE(conn, &conn);
-  fd = socket(PF_UNIX, SOCK_STREAM, 0); if (fd < 0) goto err;
-  nb = set_nonblocking(jni, fd, 1); if (nb < 0) goto end;
+  t = CREATE(tunnel);
+  t->ops = &tun_java;
+  t->p = p;
+  sel_initfile(&sel, &t->f, fd, SEL_READ, t_read, t);
 
-  sun.sun_family = AF_UNIX;
-  strcpy(sun.sun_path, (char *)pathstr);
-  if (!connect(fd, (struct sockaddr *)&sun, sizeof(sun))) goto connected;
-  else if (errno != EINPROGRESS) goto err;
+  if (!*ifn) {
+    *ifn = xmalloc(n + 5);
+    sprintf(*ifn, "vpn-%s", name);
+  }
 
-  maxfd = trig.rfd;
-  if (maxfd < fd) maxfd = fd;
-  for (;;) {
-    FD_ZERO(&rfds); FD_SET(trig.rfd, &rfds);
-    FD_ZERO(&wfds); FD_SET(fd, &wfds);
-    rc = select(maxfd + 1, &rfds, &wfds, 0, 0); if (rc < 0) goto err;
-    if (FD_ISSET(trig.rfd, &rfds)) goto end;
-    if (FD_ISSET(fd, &wfds)) {
-      sz = sizeof(sun);
-      if (!getpeername(fd, (struct sockaddr *)&sun, &sz)) goto connected;
-      else if (errno != ENOTCONN) goto err;
-      sz = sizeof(err);
-      if (!getsockopt(fd, SOL_SOCKET, SO_ERROR, &err, &sz)) errno = err;
-      goto err;
-    }
+end:
+  return (t);
+}
+
+static void t_inject(tunnel *t, buf *b)
+{
+  IF_TRACING(T_TUNNEL, {
+    trace(T_TUNNEL, "tun-java: inject decrypted packet");
+    trace_block(T_PACKET, "tunnel: packet contents", BBASE(b), BLEN(b));
+  })
+  DISCARD(write(t->f.fd, BBASE(b), BLEN(b)));
+}
+
+static void t_destroy(tunnel *t)
+  { sel_rmfile(&t->f); close(t->f.fd); DESTROY(t); }
+
+static const struct tunnel_ops tun_java = {
+  "java", 0,
+  /*      init */ t_init,
+  /*    create */ t_create,
+  /* setifname */ 0,
+  /*    inject */ t_inject,
+  /*   destroy */ t_destroy
+};
+
+
+JNIEXPORT jint JNICALL JNIFUNC(open_1tun)(JNIEnv *jni, jobject cls)
+{
+  int ret = -1;
+  int fd = -1;
+  struct ifreq iff;
+
+  if ((fd = open("/dev/net/tun", O_RDWR)) < 0) {
+    except_syserror(jni, SYSERR, errno, "failed to open tunnel device");
+    goto end;
   }
 
-connected:
-  if (set_nonblocking(jni, fd, nb) < 0) goto end;
-  conn.fd = fd; fd = -1;
-  conn.f = 0;
-  ret = wrap(jni, &conn_type, &conn);
-  goto end;
+  if (set_nonblocking(jni, fd, 1) || set_closeonexec(jni, fd)) goto end;
+
+  memset(&iff, 0, sizeof(iff));
+  iff.ifr_name[0] = 0;
+  iff.ifr_flags = IFF_TUN | IFF_NO_PI;
+  if (ioctl(fd, TUNSETIFF, &iff) < 0) {
+    except_syserror(jni, SYSERR, errno, "failed to configure tunnel device");
+    goto end;
+  }
+
+  ret = fd; fd = -1;
 
-err:
-  except_syserror(jni, SYSERR, errno,
-                 "failed to connect to Unix-domain socket `%s'", pathstr);
 end:
   if (fd != -1) close(fd);
-  put_cstring(jni, path, pathstr);
   return (ret);
 }
 
+/*----- A custom noise source ---------------------------------------------*/
+
+static void javanoise(rand_pool *r)
+{
+  JNIEnv *jni = jni_tripe;
+  jclass cls;
+  jmethodID mid;
+  jbyteArray v;
+  jbyte *p;
+  jsize n;
+
+  noise_devrandom(r);
+
+  assert(jni);
+  cls = (*jni)->FindClass(jni, RANDCLS); assert(cls);
+  mid = (*jni)->GetStaticMethodID(jni, cls, "getSeed", "(I)[B"); assert(mid);
+  v = (*jni)->CallStaticObjectMethod(jni, cls, mid, 32);
+  if (v) {
+    n = (*jni)->GetArrayLength(jni, v);
+    p = (*jni)->GetByteArrayElements(jni, v, 0);
+    rand_add(r, p, n, n);
+    (*jni)->ReleaseByteArrayElements(jni, v, p, JNI_ABORT);
+  }
+  if ((*jni)->ExceptionOccurred(jni)) {
+    (*jni)->ExceptionDescribe(jni);
+    (*jni)->ExceptionClear(jni);
+  }
+}
+
+static const rand_source javasource = { javanoise, noise_timer };
+
+/*----- Embedding the TrIPE server ----------------------------------------*/
+
+static void lock_tripe(JNIEnv *jni)
+{
+  jclass cls = (*jni)->FindClass(jni, LOCKCLS); assert(cls);
+  (*jni)->MonitorEnter(jni, cls);
+}
+
+static void unlock_tripe(JNIEnv *jni)
+{
+  jclass cls = (*jni)->FindClass(jni, LOCKCLS); assert(cls);
+  (*jni)->MonitorExit(jni, cls);
+}
+
+#define STATES(_)                                                      \
+       _(INIT)                                                         \
+       _(RESOLVE)                                                      \
+       _(KEYS)                                                         \
+       _(BIND)                                                         \
+       _(READY)                                                        \
+       _(RUNNING)
+
+enum {
+#define DEFTAG(st) st,
+  STATES(DEFTAG)
+#undef DEFTAG
+  MAXSTATE
+};
+
+static const char *statetab[] = {
+#define DEFNAME(st) #st,
+  STATES(DEFNAME)
+#undef DEFNAME
+};
+
+static unsigned state = INIT;
+static int clientsk = -1;
+
+static const char *statename(unsigned st)
+{
+  if (st >= MAXSTATE) return ("<invalid>");
+  else return (statetab[st]);
+}
+
+static int ensure_state(JNIEnv *jni, unsigned want)
+{
+  unsigned cur;
+
+  lock_tripe(jni);
+  cur = state;
+  unlock_tripe(jni);
+
+  if (cur != want) {
+    except(jni, STERR, "server is in state %s (%u), not %s (%u)",
+          statename(cur), cur, statename(want), want);
+    return (-1);
+  }
+  return (0);
+}
+
+JNIEXPORT void JNICALL JNIFUNC(base_1init)(JNIEnv *jni, jobject cls)
+{
+  int fd[2];
+  int i;
+
+  for (i = 0; i < N(fd); i++) fd[i] = -1;
+
+  lock_tripe(jni);
+  jni_tripe = jni;
+  if (ensure_state(jni, INIT)) goto end;
+
+  if (socketpair(PF_UNIX, SOCK_STREAM, 0, fd)) {
+    except_syserror(jni, SYSERR, errno, "failed to create socket pair");
+    goto end;
+  }
+
+  clientsk = fd[0]; fd[0] = -1;
+
+  rand_noisesrc(RAND_GLOBAL, &javasource);
+  rand_seed(RAND_GLOBAL, MAXHASHSZ);
+  lp_init();
+  a_create(fd[1], fd[1], AF_NOTE | AF_WARN | AF_TRACE); fd[1] = -1;
+  a_switcherr();
+  p_addtun(&tun_java); p_setdflttun(&tun_java);
+  p_init();
+  kx_init();
+
+  state++;
+
+end:
+  for (i = 0; i < N(fd); i++) if (fd[i] != -1) close(fd[i]);
+  jni_tripe = 0;
+  unlock_tripe(jni);
+}
+
+JNIEXPORT void JNICALL JNIFUNC(setup_1resolver)(JNIEnv *jni, jobject cls)
+{
+  lock_tripe(jni);
+  if (ensure_state(jni, RESOLVE)) goto end;
+
+  if (a_init())
+    { except(jni, INITERR, "failed to initialize resolver"); return; }
+
+  state++;
+
+end:
+  unlock_tripe(jni);
+}
+
+JNIEXPORT void JNICALL JNIFUNC(load_1keys)(JNIEnv *jni, jobject cls,
+                                         jobject privstr, jobject pubstr,
+                                         jobject tagstr)
+{
+  const char *priv = 0, *pub = 0, *tag = 0;
+
+  lock_tripe(jni);
+  if (ensure_state(jni, KEYS)) return;
+
+  priv = get_cstring(jni, privstr); if (!priv) goto end;
+  pub = get_cstring(jni, pubstr); if (!pub) goto end;
+  tag = get_cstring(jni, tagstr); if (!tag) goto end;
+
+  if (km_init(priv, pub, tag))
+    { except(jni, INITERR, "failed to load initial keys"); goto end; }
+
+  state++;
+
+end:
+  put_cstring(jni, privstr, priv);
+  put_cstring(jni, pubstr, pub);
+  put_cstring(jni, tagstr, tag);
+  unlock_tripe(jni);
+}
+
+JNIEXPORT void JNICALL JNIFUNC(unload_1keys)(JNIEnv *jni, jobject cls)
+{
+  lock_tripe(jni);
+  if (ensure_state(jni, KEYS + 1)) goto end;
+
+  km_clear();
+
+  state--;
+
+end:
+  unlock_tripe(jni);
+}
+
+JNIEXPORT void JNICALL JNIFUNC(bind)(JNIEnv *jni, jobject cls,
+                                    jbyteArray hoststr, jbyteArray svcstr)
+{
+  const char *host = 0, *svc = 0;
+  struct addrinfo hint, *ai = 0;
+  int err;
+
+  lock_tripe(jni);
+  if (ensure_state(jni, BIND)) goto end;
+
+  if (hoststr) { host = get_cstring(jni, hoststr); if (!host) goto end; }
+  svc = get_cstring(jni, svcstr); if (!svc) goto end;
+
+  hint.ai_socktype = SOCK_DGRAM;
+  hint.ai_family = AF_UNSPEC;
+  hint.ai_protocol = IPPROTO_UDP;
+  hint.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
+  err = getaddrinfo(host, svc, &hint, &ai);
+  if (err) {
+    except(jni, NAMEERR, "failed to resolve %c%s%c, port `%s': %s",
+          host ? '`' : '<', host ? host : "nil", host ? '\'' : '>',
+          svc, gai_strerror(err));
+    goto end;
+  }
+
+  if (p_bind(ai))
+    { except(jni, INITERR, "failed to bind master socket"); goto end; }
+
+  state++;
+
+end:
+  if (ai) freeaddrinfo(ai);
+  put_cstring(jni, hoststr, host);
+  put_cstring(jni, svcstr, svc);
+  unlock_tripe(jni);
+}
+
+JNIEXPORT void JNICALL JNIFUNC(unbind)(JNIEnv *jni, jobject cls)
+{
+  lock_tripe(jni);
+  if (ensure_state(jni, BIND + 1)) goto end;
+
+  p_unbind();
+
+  state--;
+
+end:
+  unlock_tripe(jni);
+}
+
+JNIEXPORT void JNICALL JNIFUNC(mark)(JNIEnv *jni, jobject cls, jint seq)
+{
+  lock_tripe(jni);
+  a_notify("MARK", "%d", seq, A_END);
+  unlock_tripe(jni);
+}
+
+JNIEXPORT void JNICALL JNIFUNC(run)(JNIEnv *jni, jobject cls)
+{
+  lock_tripe(jni);
+  if (ensure_state(jni, READY)) goto end;
+  assert(!jni_tripe);
+  jni_tripe = jni;
+  state = RUNNING;
+  unlock_tripe(jni);
+
+  lp_run();
+
+  lock_tripe(jni);
+  jni_tripe = 0;
+  state = READY;
+
+end:
+  unlock_tripe(jni);
+}
+
 static int check_buffer_bounds(JNIEnv *jni, const char *what,
                               jbyteArray buf, jint start, jint len)
 {
@@ -1324,18 +1664,18 @@ static int check_buffer_bounds(JNIEnv *jni, const char *what,
 }
 
 JNIEXPORT void JNICALL JNIFUNC(send)(JNIEnv *jni, jobject cls,
-                                    wrapper wconn, jbyteArray buf,
+                                    jbyteArray buf,
                                     jint start, jint len,
                                     wrapper wtrig)
 {
-  struct conn conn;
   struct trigger trig;
   int rc, maxfd;
   ssize_t n;
   fd_set rfds, wfds;
   jbyte *p = 0;
 
-  if (unwrap(jni, &conn, &conn_type, wconn)) goto end;
+  if (ensure_state(jni, RUNNING)) goto end;
+
   if (unwrap(jni, &trig, &trigger_type, wtrig)) goto end;
   if (check_buffer_bounds(jni, "send", buf, start, len)) goto end;
 
@@ -1343,14 +1683,14 @@ JNIEXPORT void JNICALL JNIFUNC(send)(JNIEnv *jni, jobject cls,
   if (!p) goto end;
 
   maxfd = trig.rfd;
-  if (maxfd < conn.fd) maxfd = conn.fd;
+  if (maxfd < clientsk) maxfd = clientsk;
   while (len) {
     FD_ZERO(&rfds); FD_SET(trig.rfd, &rfds);
-    FD_ZERO(&wfds); FD_SET(conn.fd, &wfds);
+    FD_ZERO(&wfds); FD_SET(clientsk, &wfds);
     rc = select(maxfd + 1, &rfds, &wfds, 0, 0); if (rc < 0) goto err;
     if (FD_ISSET(trig.rfd, &rfds)) break;
-    if (FD_ISSET(conn.fd, &wfds)) {
-      n = send(conn.fd, p + start, len, 0);
+    if (FD_ISSET(clientsk, &wfds)) {
+      n = send(clientsk, p + start, len, 0);
       if (n >= 0) { start += n; len -= n; }
       else if (errno != EAGAIN && errno != EWOULDBLOCK) goto err;
     }
@@ -1365,18 +1705,24 @@ end:
 }
 
 JNIEXPORT jint JNICALL JNIFUNC(recv)(JNIEnv *jni, jobject cls,
-                                    wrapper wconn, jbyteArray buf,
+                                    jbyteArray buf,
                                     jint start, jint len,
                                     wrapper wtrig)
 {
-  struct conn conn;
   struct trigger trig;
   int maxfd;
   fd_set rfds;
   jbyte *p = 0;
   jint rc = -1;
 
-  if (unwrap(jni, &conn, &conn_type, wconn)) goto end;
+  lock_tripe(jni);
+  if (clientsk == -1) {
+    except(jni, STERR, "client connection not established");
+    unlock_tripe(jni);
+    goto end;
+  }
+  unlock_tripe(jni);
+
   if (unwrap(jni, &trig, &trigger_type, wtrig)) goto end;
   if (check_buffer_bounds(jni, "send", buf, start, len)) goto end;
 
@@ -1384,15 +1730,15 @@ JNIEXPORT jint JNICALL JNIFUNC(recv)(JNIEnv *jni, jobject cls,
   if (!p) goto end;
 
   maxfd = trig.rfd;
-  if (maxfd < conn.fd) maxfd = conn.fd;
+  if (maxfd < clientsk) maxfd = clientsk;
   for (;;) {
-    FD_ZERO(&rfds); FD_SET(trig.rfd, &rfds); FD_SET(conn.fd, &rfds);
+    FD_ZERO(&rfds); FD_SET(trig.rfd, &rfds); FD_SET(clientsk, &rfds);
     rc = select(maxfd + 1, &rfds, 0, 0, 0); if (rc < 0) goto err;
     if (FD_ISSET(trig.rfd, &rfds)) {
       break;
     }
-    if (FD_ISSET(conn.fd, &rfds)) {
-      rc = recv(conn.fd, p + start, len, 0);
+    if (FD_ISSET(clientsk, &rfds)) {
+      rc = recv(clientsk, p + start, len, 0);
       if (rc >= 0) break;
       else if (errno != EAGAIN && errno != EWOULDBLOCK) goto err;
     }
@@ -1407,28 +1753,4 @@ end:
   return (rc);
 }
 
-JNIEXPORT void JNICALL JNIFUNC(closeconn)(JNIEnv *jni, jobject cls,
-                                         wrapper wconn, jint how)
-{
-  struct conn conn;
-  int rc;
-
-  if (unwrap(jni, &conn, &conn_type, wconn)) goto end;
-  if (conn.fd == -1) goto end;
-
-  how &= CF_CLOSEMASK&~conn.f;
-  conn.f |= how;
-  if ((conn.f&CF_CLOSEMASK) == CF_CLOSEMASK) {
-    close(conn.fd);
-    conn.fd = -1;
-  } else {
-    if (how&CF_CLOSERD) shutdown(conn.fd, SHUT_RD);
-    if (how&CF_CLOSEWR) shutdown(conn.fd, SHUT_WR);
-  }
-  rc = update_wrapper(jni, &conn_type, wconn, &conn); assert(!rc);
-
-end:
-  return;
-}
-
 /*----- That's all, folks -------------------------------------------------*/
index b9595ec..bdbede9 100644 (file)
@@ -120,7 +120,7 @@ private val DEFAULTS: Seq[(String, Config => String)] =
       "sig-fresh" -> { _ => "always" },
       "fingerprint-hash" -> { _("hash") });
 
-private def parseConfig(file: File): Config = {
+private def parseConfig(file: File): HashMap[String, String] = {
 
   /* Build the new configuration in a temporary place. */
   var m = HashMap[String, String]();
@@ -131,7 +131,7 @@ private def parseConfig(file: File): Config = {
     for (line <- lines(in)) {
       line match {
        case RX_COMMENT() => ok;
-       case RX_KEYVAL(key, value) => m += key -> value;
+       case RX_KEYVAL(key, value) => m(key) = value;
        case _ =>
          throw new ConfigSyntaxError(file.getPath, lno,
                                      "failed to parse line");
@@ -150,7 +150,7 @@ private def readConfig(file: File): Config = {
   /* Fill in defaults where things have been missed out. */
   for ((key, dflt) <- DEFAULTS) {
     if (!(m contains key)) {
-      try { m += key -> dflt(m); }
+      try { m(key) = dflt(m); }
       catch {
        case e: DefaultFailed =>
          throw new ConfigDefaultFailed(file.getPath, key,
index 6931431..402bf1d 100644 (file)
--- a/sys.scala
+++ b/sys.scala
@@ -28,7 +28,7 @@ package uk.org.distorted.tripe; package object sys {
 /*----- Imports -----------------------------------------------------------*/
 
 import scala.collection.convert.decorateAsJava._;
-import scala.collection.mutable.HashSet;
+import scala.collection.mutable.{HashMap, HashSet};
 
 import java.io.{BufferedReader, BufferedWriter, Closeable, File,
                FileDescriptor, FileInputStream, FileOutputStream,
@@ -124,7 +124,7 @@ import StringImplicits._;
 /*----- Main code ---------------------------------------------------------*/
 
 /* Import the native code library. */
-System.loadLibrary("toy");
+System.loadLibrary("tripe");
 
 /* Native types.
  *
@@ -810,15 +810,15 @@ private final val maxTriggers = 2;
 private var nTriggers = 0;
 private var triggers: List[Wrapper] = Nil;
 
-@native protected def makeTrigger(): Wrapper;
-@native protected def destroyTrigger(trig: Wrapper);
-@native protected def resetTrigger(trig: Wrapper);
+@native protected def make_trigger(): Wrapper;
+@native protected def destroy_trigger(trig: Wrapper);
+@native protected def reset_trigger(trig: Wrapper);
 @native protected def trigger(trig: Wrapper);
 
 private def getTrigger(): Wrapper = {
   triggerLock synchronized {
     if (nTriggers == 0)
-      makeTrigger()
+      make_trigger()
     else {
       val trig = triggers.head;
       triggers = triggers.tail;
@@ -829,10 +829,10 @@ private def getTrigger(): Wrapper = {
 }
 
 private def putTrigger(trig: Wrapper) {
-  resetTrigger(trig);
+  reset_trigger(trig);
   triggerLock synchronized {
     if (nTriggers >= maxTriggers)
-      destroyTrigger(trig);
+      destroy_trigger(trig);
     else {
       triggers ::= trig;
       nTriggers += 1;
@@ -859,59 +859,69 @@ def interruptWithTrigger[T](body: Wrapper => T): T = {
   };
 }
 
-/*----- Connecting to a server --------------------------------------------*/
+/*----- Glue for the VPN server -------------------------------------------*/
 
-/* Primitive operations. */
-final val CF_CLOSERD = 1;
-final val CF_CLOSEWR = 2;
-final val CF_CLOSEMASK = CF_CLOSERD | CF_CLOSEWR;
-@native protected def connect(path: CString, trig: Wrapper): Wrapper;
-@native protected def send(conn: Wrapper, buf: CString,
-                          start: Int, len: Int, trig: Wrapper);
-@native protected def recv(conn: Wrapper, buf: CString,
-                          start: Int, len: Int, trig: Wrapper): Int;
-@native def closeconn(conn: Wrapper, how: Int);
-
-class Connection(path: String) extends Closeable {
-
-  /* The underlying primitive connection. */
-  private[this] val conn = interruptWithTrigger { trig =>
-    connect(path.toCString, trig);
-  };
-
-  /* Alternative constructors. */
-  def this(file: File) { this(file.getPath); }
+/* The lock class.  This is only a class because they're much easier to find
+ * than loose objects through JNI.
+ */
+private class ServerLock;
 
-  /* Cleanup.*/
-  override def close() { closeconn(conn, CF_CLOSEMASK); }
-  override protected def finalize() { super.finalize(); close(); }
+/* Exceptions. */
+class NameResolutionException(msg: String) extends Exception(msg);
+class InitializationException(msg: String) extends Exception(msg);
 
-  class Input private[Connection] extends InputStream {
-    /* An input stream which reads from the connection. */
+/* Primitive operations. */
+@native protected def open_tun(): Int;
+@native protected def base_init();
+@native protected def setup_resolver();
+@native def load_keys(priv: CString, pub: CString, tag: CString);
+@native def unload_keys();
+@native def bind(host: CString, svc: CString);
+@native def unbind();
+@native def mark(seq: Int);
+@native def run();
+@native protected def send(buf: CString, start: Int, len: Int,
+                          trig: Wrapper);
+@native protected def recv(buf: CString, start: Int, len: Int,
+                          trig: Wrapper): Int;
+
+base_init();
+setup_resolver();
+
+/* Tunnel descriptor plumbing. */
+val pending = HashMap[String, Int]();
+
+def getTunnelFd(peer: CString): Int =
+  pending synchronized { pending(peer.toJString) };
+def storeTunnelFd(peer: String, fd: Int)
+  { pending synchronized { pending(peer) = fd; } }
+def withdrawTunnelFd(peer: String)
+  { pending synchronized { pending -= peer; } }
+def withTunnelFd[T](peer: String, fd: Int)(body: => T): T = {
+  storeTunnelFd(peer, fd);
+  try { body } finally { withdrawTunnelFd(peer); }
+}
 
-    override def read(): Int = {
-      val buf = new Array[Byte](1);
-      val n = read(buf, 0, 1);
-      if (n < 0) -1 else buf(0)&0xff;
-    }
-    override def read(buf: Array[Byte]): Int =
-      read(buf, 0, buf.length);
-    override def read(buf: Array[Byte], start: Int, len: Int) =
-      interruptWithTrigger { trig => recv(conn, buf, start, len, trig); };
-    override def close() { closeconn(conn, CF_CLOSERD); }
+/* Server I/O. */
+lazy val serverInput: InputStream = new InputStream {
+  override def read(): Int = {
+    val buf = new Array[Byte](1);
+    val n = read(buf, 0, 1);
+    if (n < 0) -1 else buf(0)&0xff;
   }
-  lazy val input = new Input;
-
-  class Output private[Connection] extends OutputStream {
-    /* An output stream which writes to the connection. */
+  override def read(buf: Array[Byte]): Int =
+    read(buf, 0, buf.length);
+  override def read(buf: Array[Byte], start: Int, len: Int) =
+    interruptWithTrigger { trig => recv(buf, start, len, trig); };
+  override def close() { }
+}
 
-    override def write(b: Int) { write(Array[Byte](b.toByte), 0, 1); }
-    override def write(buf: Array[Byte]) { write(buf, 0, buf.length); }
-    override def write(buf: Array[Byte], start: Int, len: Int)
-      { interruptWithTrigger { trig => send(conn, buf, start, len, trig); } }
-    override def close() { closeconn(conn, CF_CLOSEWR); }
-  }
-  lazy val output = new Output;
+lazy val serverOutput: OutputStream = new OutputStream {
+  override def write(b: Int) { write(Array[Byte](b.toByte), 0, 1); }
+  override def write(buf: Array[Byte]) { write(buf, 0, buf.length); }
+  override def write(buf: Array[Byte], start: Int, len: Int)
+  { interruptWithTrigger { trig => send(buf, start, len, trig); } }
+  override def close() { }
 }
 
 /*----- Crypto-library hacks ----------------------------------------------*/