chiark / gitweb /
saner memory management (and better performance), improved handling of case when...
authorstevenj <stevenj@alum.mit.edu>
Wed, 29 Aug 2007 04:45:08 +0000 (00:45 -0400)
committerstevenj <stevenj@alum.mit.edu>
Wed, 29 Aug 2007 04:45:08 +0000 (00:45 -0400)
darcs-hash:20070829044508-c8de0-635e1db270782d6b6741c27994cdc596c4cdde2c.gz

cdirect/cdirect.c
cdirect/redblack.c
cdirect/redblack.h
cdirect/redblack_test.c

index 05ec2c71e5af9633d34d3629f5421127d2fb55e3..5b7200c545db72b76fe062fb1275dda767cf4af7 100644 (file)
@@ -33,7 +33,6 @@
 typedef struct {
      int n; /* dimension */
      int L; /* RECT_LEN(n) */
-     double *rects; /* the hyper-rectangles */
      double magic_eps; /* Jones' epsilon parameter (1e-4 is recommended) */
      int which_diam; /* which measure of hyper-rectangle diam to use:
                        0 = Jones, 1 = Gablonsky */
@@ -45,12 +44,14 @@ typedef struct {
      nlopt_stopping *stop; /* stopping criteria */
      nlopt_func f; void *f_data;
      double *work; /* workspace, of length >= 2*n */
-     int *iwork, iwork_len; /* workspace, of length iwork_len >= n */
+     int *iwork; /* workspace, length >= n */
      double fmin, *xmin; /* minimum so far */
      
-     /* red-black tree of hyperrect indices, sorted by (d,f) in
+     /* red-black tree of hyperrects, sorted by (d,f) in
        lexographical order */
      rb_tree rtree;
+     double **hull; /* array to store convex hull */
+     int hull_len; /* allocated length of hull array */
 } params;
 
 /***************************************************************************/
@@ -74,16 +75,7 @@ static double rect_diameter(int n, const double *w, const params *p)
      }
 }
 
-static double *alloc_rects(int n, int *Na, double *rects, int newN)
-{
-     if (newN <= *Na)
-         return rects;
-     else {
-         (*Na) += newN;
-         return realloc(rects, sizeof(double) * RECT_LEN(n) * (*Na));
-     }
-}
-#define ALLOC_RECTS(n, Nap, rects, newN) if (!(rects = alloc_rects(n, Nap, rects, newN))) return NLOPT_OUT_OF_MEMORY
+#define ALLOC_RECT(rect, L) if (!(rect = (double*) malloc(sizeof(double)*(L)))) return NLOPT_OUT_OF_MEMORY
 
 static double *fv_qsort = 0;
 static int sort_fv_compare(const void *a_, const void *b_)
@@ -116,20 +108,19 @@ static double function_eval(const double *x, params *p) {
      p->stop->nevals++;
      return f;
 }
-#define FUNCTION_EVAL(fv,x,p) fv = function_eval(x, p); if (p->fmin < p->stop->fmin_max) return NLOPT_FMIN_MAX_REACHED; else if (nlopt_stop_evals((p)->stop)) return NLOPT_MAXEVAL_REACHED; else if (nlopt_stop_time((p)->stop)) return NLOPT_MAXTIME_REACHED
+#define FUNCTION_EVAL(fv,x,p,freeonerr) fv = function_eval(x, p); if (p->fmin < p->stop->fmin_max) { free(freeonerr); return NLOPT_FMIN_MAX_REACHED; } else if (nlopt_stop_evals((p)->stop)) { free(freeonerr); return NLOPT_MAXEVAL_REACHED; } else if (nlopt_stop_time((p)->stop)) { free(freeonerr); return NLOPT_MAXTIME_REACHED; }
 
 #define THIRD (0.3333333333333333333333)
 
 #define EQUAL_SIDE_TOL 5e-2 /* tolerance to equate side sizes */
 
 /* divide rectangle idiv in the list p->rects */
-static nlopt_result divide_rect(int *N, int *Na, int idiv, params *p)
+static nlopt_result divide_rect(double *rdiv, params *p)
 {
      int i;
      const const int n = p->n;
      const int L = p->L;
-     double *r = p->rects;
-     double *c = r + L*idiv + 2; /* center of rect to divide */
+     double *c = rdiv + 2; /* center of rect to divide */
      double *w = c + n; /* widths of rect to divide */
      double wmax = w[0];
      int imax = 0, nlongest = 0;
@@ -150,9 +141,9 @@ static nlopt_result divide_rect(int *N, int *Na, int idiv, params *p)
               if (wmax - w[i] <= wmax * EQUAL_SIDE_TOL) {
                    double csave = c[i];
                    c[i] = csave - w[i] * THIRD;
-                   FUNCTION_EVAL(fv[2*i], c, p);
+                   FUNCTION_EVAL(fv[2*i], c, p, 0);
                    c[i] = csave + w[i] * THIRD;
-                   FUNCTION_EVAL(fv[2*i+1], c, p);
+                   FUNCTION_EVAL(fv[2*i+1], c, p, 0);
                    c[i] = csave;
               }
               else {
@@ -160,23 +151,23 @@ static nlopt_result divide_rect(int *N, int *Na, int idiv, params *p)
               }
          }
          sort_fv(n, fv, isort);
-         ALLOC_RECTS(n, Na, r, (*N)+2*nlongest); 
-         p->rects = r; c = r + L*idiv + 2; w = c + n;
+         if (!(node = rb_tree_find(&p->rtree, rdiv)))
+              return NLOPT_FAILURE;
          for (i = 0; i < nlongest; ++i) {
               int k;
-              if (!(node = rb_tree_find_exact(&p->rtree, idiv)))
-                   return NLOPT_FAILURE;
               w[isort[i]] *= THIRD;
-              r[L*idiv] = rect_diameter(n, w, p);
-              rb_tree_resort(&p->rtree, node);
+              rdiv[0] = rect_diameter(n, w, p);
+              node = rb_tree_resort(&p->rtree, node);
               for (k = 0; k <= 1; ++k) {
-                   r[L*(*N)] = r[L*idiv];
-                   memcpy(r + L*(*N) + 2, c, sizeof(double) * 2*n);
-                   r[L*(*N) + 2 + isort[i]] += w[isort[i]] * (2*k-1);
-                   r[L*(*N) + 1] = fv[2*isort[i]+k];
-                   if (!rb_tree_insert(&p->rtree, *N))
+                   double *rnew;
+                   ALLOC_RECT(rnew, L);
+                   memcpy(rnew, rdiv, sizeof(double) * L);
+                   rnew[2 + isort[i]] += w[isort[i]] * (2*k-1);
+                   rnew[1] = fv[2*isort[i]+k];
+                   if (!rb_tree_insert(&p->rtree, rnew)) {
+                        free(rnew);
                         return NLOPT_FAILURE;
-                   ++(*N);
+                   }
               }
          }
      }
@@ -193,21 +184,21 @@ static nlopt_result divide_rect(int *N, int *Na, int idiv, params *p)
          }
          else
               i = imax; /* trisect longest side */
-         ALLOC_RECTS(n, Na, r, (*N)+2);
-          p->rects = r; c = r + L*idiv + 2; w = c + n;
-         if (!(node = rb_tree_find_exact(&p->rtree, idiv)))
+         if (!(node = rb_tree_find(&p->rtree, rdiv)))
               return NLOPT_FAILURE;
          w[i] *= THIRD;
-         r[L*idiv] = rect_diameter(n, w, p);
-         rb_tree_resort(&p->rtree, node);
+         rdiv[0] = rect_diameter(n, w, p);
+         node = rb_tree_resort(&p->rtree, node);
          for (k = 0; k <= 1; ++k) {
-              r[L*(*N)] = r[L*idiv];
-              memcpy(r + L*(*N) + 2, c, sizeof(double) * 2*n);
-              r[L*(*N) + 2 + i] += w[i] * (2*k-1);
-              FUNCTION_EVAL(r[L*(*N) + 1], r + L*(*N) + 2, p);
-              if (!rb_tree_insert(&p->rtree, *N))
+              double *rnew;
+              ALLOC_RECT(rnew, L);
+              memcpy(rnew, rdiv, sizeof(double) * L);
+              rnew[2 + i] += w[i] * (2*k-1);
+              FUNCTION_EVAL(rnew[1], rnew + 2, p, rnew);
+              if (!rb_tree_insert(&p->rtree, rnew)) {
+                   free(rnew);
                    return NLOPT_FAILURE;
-              ++(*N);
+              }
          }
      }
      return NLOPT_SUCCESS;
@@ -217,65 +208,62 @@ static nlopt_result divide_rect(int *N, int *Na, int idiv, params *p)
 /* O(N log N) convex hull algorithm, used later to find the potentially
    optimal points */
 
-/* Find the lower convex hull of a set of points (xy[s*i], xy[s*i+1]), where
-   0 <= i < N and s >= 2.
+/* Find the lower convex hull of a set of points (x,y) stored in a rb-tree
+   of pointers to {x,y} arrays sorted in lexographic order by (x,y).
 
    Unlike standard convex hulls, we allow redundant points on the hull.
 
-   The return value is the number of points in the hull, with indices
-   stored in ihull.  ihull should point to arrays of length >= N.
-   rb_tree should be a red-black tree of indices (keys == i) sorted
-   in lexographic order by (xy[s*i], xy[s*i+1]).
+   The return value is the number of points in the hull, with pointers
+   stored in hull[i] (should be an array of length >= t->N).
 */
-static int convex_hull(int N, double *xy, int s, int *ihull, rb_tree *t)
+static int convex_hull(rb_tree *t, double **hull)
 {
-     int nhull;
+     int nhull = 0;
      double minslope;
      double xmin, xmax, yminmin, ymaxmin;
      rb_node *n, *nmax;
 
-     if (N == 0) return 0;
-     
      /* Monotone chain algorithm [Andrew, 1979]. */
 
      n = rb_tree_min(t);
+     if (!n) return 0;
      nmax = rb_tree_max(t);
 
-     xmin = xy[s*(n->k)];
-     yminmin = xy[s*(n->k)+1];
-     xmax = xy[s*(nmax->k)];
+     xmin = n->k[0];
+     yminmin = n->k[1];
+     xmax = nmax->k[0];
 
-     ihull[nhull = 1] = n->k;
+     hull[nhull++] = n->k;
      if (xmin == xmax) return nhull;
 
      /* set nmax = min mode with x == xmax */
-     while (xy[s*(nmax->k)] == xmax)
-         nmax = rb_tree_pred(t, nmax); /* non-NULL since xmin != xmax */
-     nmax = rb_tree_succ(t, nmax);
+     while (nmax->k[0] == xmax)
+         nmax = rb_tree_pred(nmax); /* non-NULL since xmin != xmax */
+     nmax = rb_tree_succ(nmax);
 
-     ymaxmin = xy[s*(nmax->k)+1];
+     ymaxmin = nmax->k[1];
      minslope = (ymaxmin - yminmin) / (xmax - xmin);
 
      /* set n = first node with x != xmin */
-     while (xy[s*(n->k)] == xmin)
-         n = rb_tree_succ(t, n); /* non-NULL since xmin != xmax */
+     while (n->k[0] == xmin)
+         n = rb_tree_succ(n); /* non-NULL since xmin != xmax */
 
-     for (; n != nmax; n = rb_tree_succ(t, n)) { 
-         int k = n->k;
-         if (xy[s*k+1] > yminmin + (xy[s*k] - xmin) * minslope)
+     for (; n != nmax; n = rb_tree_succ(n)) { 
+         double *k = n->k;
+         if (k[1] > yminmin + (k[0] - xmin) * minslope)
               continue;
          /* remove points until we are making a "left turn" to k */
          while (nhull > 1) {
-              int t1 = ihull[nhull - 1], t2 = ihull[nhull - 2];
+              double *t1 = hull[nhull - 1], *t2 = hull[nhull - 2];
               /* cross product (t1-t2) x (k-t2) > 0 for a left turn: */
-              if ((xy[s*t1]-xy[s*t2]) * (xy[s*k+1]-xy[s*t2+1])
-                  - (xy[s*t1+1]-xy[s*t2+1]) * (xy[s*k]-xy[s*t2]) >= 0)
+              if ((t1[0]-t2[0]) * (k[1]-t2[1])
+                  - (t1[1]-t2[1]) * (k[0]-t2[0]) >= 0)
                    break;
               --nhull;
          }
-         ihull[nhull++] = k;
+         hull[nhull++] = k;
      }
-     ihull[nhull++] = nmax->k;
+     hull[nhull++] = nmax->k;
      return nhull;
 }
 
@@ -291,41 +279,34 @@ static int small(double *w, params *p)
      return 1;
 }
 
-static nlopt_result divide_good_rects(int *N, int *Na, params *p)
+static nlopt_result divide_good_rects(params *p)
 {
      const int n = p->n;
-     const int L = p->L;
-     int *ihull, nhull, i, xtol_reached = 1, divided_some = 0;
-     double *r = p->rects;
+     double **hull;
+     int nhull, i, xtol_reached = 1, divided_some = 0;
      double magic_eps = p->magic_eps;
 
-     if (p->iwork_len < *N) {
-         p->iwork_len = p->iwork_len + *N;
-         p->iwork = (int *) realloc(p->iwork, sizeof(int) * p->iwork_len);
-         if (!p->iwork)
-              return NLOPT_OUT_OF_MEMORY;
+     if (p->hull_len < p->rtree.N) {
+         p->hull_len += p->rtree.N;
+         p->hull = (double **) realloc(p->hull, sizeof(double*)*p->hull_len);
+         if (!p->hull) return NLOPT_OUT_OF_MEMORY;
      }
-     ihull = p->iwork;
-     nhull = convex_hull(*N, r, L, ihull, &p->rtree);
+     nhull = convex_hull(&p->rtree, hull = p->hull);
  divisions:
      for (i = 0; i < nhull; ++i) {
          double K1 = -HUGE_VAL, K2 = -HUGE_VAL, K;
          if (i > 0)
-              K1 = (r[L*ihull[i]+1] - r[L*ihull[i-1]+1]) /
-                   (r[L*ihull[i]] - r[L*ihull[i-1]]);
+              K1 = (hull[i][1] - hull[i-1][1]) / (hull[i][0] - hull[i-1][0]);
          if (i < nhull-1)
-              K1 = (r[L*ihull[i]+1] - r[L*ihull[i+1]+1]) /
-                   (r[L*ihull[i]] - r[L*ihull[i+1]]);
+              K1 = (hull[i][1] - hull[i+1][1]) / (hull[i][0] - hull[i+1][0]);
          K = MAX(K1, K2);
-         if (r[L*ihull[i]+1] - K * r[L*ihull[i]]
+         if (hull[i][1] - K * hull[i][0]
              <= p->fmin - magic_eps * fabs(p->fmin)) {
               /* "potentially optimal" rectangle, so subdivide */
               divided_some = 1;
-              nlopt_result ret;
-              ret = divide_rect(N, Na, ihull[i], p);
-              r = p->rects; /* may have grown */
+              nlopt_result ret = divide_rect(hull[i], p);
               if (ret != NLOPT_SUCCESS) return ret;
-              xtol_reached = xtol_reached && small(r + L*ihull[i] + 2+n, p);
+              xtol_reached = xtol_reached && small(hull[i] + 2+n, p);
          }
      }
      if (!divided_some) {
@@ -333,13 +314,19 @@ static nlopt_result divide_good_rects(int *N, int *Na, params *p)
               magic_eps = 0;
               goto divisions; /* try again */
          }
-         else { /* WTF? divide largest rectangle */
-              double wmax = r[0];
-              int imax = 0;
-              for (i = 1; i < *N; ++i)
-                   if (r[L*i] > wmax)
-                        wmax = r[L*(imax=i)];
-              return divide_rect(N, Na, imax, p);
+         else { /* WTF? divide largest rectangle with smallest f */
+              /* (note that this code actually gets called from time
+                 to time, and the heuristic here seems to work well,
+                 but I don't recall this situation being discussed in
+                 the references?) */
+              rb_node *max = rb_tree_max(&p->rtree);
+              rb_node *pred = max;
+              double wmax = max->k[0];
+              do { /* note: this loop is O(N) worst-case time */
+                   max = pred;
+                   pred = rb_tree_pred(max);
+              } while (pred && pred->k[0] == wmax);
+              return divide_rect(max->k, p);
          }
      }
      return xtol_reached ? NLOPT_XTOL_REACHED : NLOPT_SUCCESS;
@@ -348,18 +335,13 @@ static nlopt_result divide_good_rects(int *N, int *Na, params *p)
 /***************************************************************************/
 
 /* lexographic sort order (d,f) of hyper-rects, for red-black tree */
-static int hyperrect_compare(int i, int j, void *p_)
+static int hyperrect_compare(double *a, double *b)
 {
-     params *p = (params *) p_;
-     int L = p->L;
-     double *r = p->rects;
-     double di = r[i*L], dj = r[j*L], fi, fj;
-     if (di < dj) return -1;
-     if (dj < di) return +1;
-     fi = r[i*L+1]; fj = r[j*L+1];
-     if (fi < fj) return -1;
-     if (fj < fi) return +1;
-     return 0;
+     if (a[0] < b[0]) return -1;
+     if (a[0] > b[0]) return +1;
+     if (a[1] < b[1]) return -1;
+     if (a[1] > b[1]) return +1;
+     return (int) (a - b); /* tie-breaker */
 }
 
 /***************************************************************************/
@@ -372,7 +354,8 @@ nlopt_result cdirect_unscaled(int n, nlopt_func f, void *f_data,
                              double magic_eps, int which_alg)
 {
      params p;
-     int Na = 100, N = 1, i, x_center = 1;
+     int i, x_center = 1;
+     double *rnew;
      nlopt_result ret = NLOPT_OUT_OF_MEMORY;
 
      p.magic_eps = magic_eps;
@@ -388,38 +371,41 @@ nlopt_result cdirect_unscaled(int n, nlopt_func f, void *f_data,
      p.fmin = f(n, x, NULL, f_data); stop->nevals++;
      p.work = 0;
      p.iwork = 0;
-     p.rects = 0;
+     p.hull = 0;
+
+     rb_tree_init(&p.rtree, hyperrect_compare);
 
-     if (!rb_tree_init(&p.rtree, hyperrect_compare, &p)) goto done;
-     p.work = (double *) malloc(sizeof(double) * 2*n);
+     p.work = (double *) malloc(sizeof(double) * (2*n));
      if (!p.work) goto done;
-     p.rects = (double *) malloc(sizeof(double) * Na * RECT_LEN(n));
-     if (!p.rects) goto done;
-     p.iwork = (int *) malloc(sizeof(int) * (p.iwork_len = Na + n));
+     p.iwork = (int *) malloc(sizeof(int) * n);
      if (!p.iwork) goto done;
+     p.hull_len = 128; /* start with a reasonable number */
+     p.hull = (double **) malloc(sizeof(double *) * p.hull_len);
+     if (!p.hull) goto done;
 
+     if (!(rnew = (double *) malloc(sizeof(double) * p.L))) goto done;
      for (i = 0; i < n; ++i) {
-         p.rects[2+i] = 0.5 * (lb[i] + ub[i]);
+         rnew[2+i] = 0.5 * (lb[i] + ub[i]);
          x_center = x_center
-              && (fabs(p.rects[2+i]-x[i]) < 1e-13*(1+fabs(x[i])));
-         p.rects[2+n+i] = ub[i] - lb[i];
+              && (fabs(rnew[2+i]-x[i]) < 1e-13*(1+fabs(x[i])));
+         rnew[2+n+i] = ub[i] - lb[i];
      }
-     p.rects[0] = rect_diameter(n, p.rects+2+n, &p);
+     rnew[0] = rect_diameter(n, rnew+2+n, &p);
      if (x_center)
-         p.rects[1] = p.fmin; /* avoid computing f(center) twice */
+         rnew[1] = p.fmin; /* avoid computing f(center) twice */
      else
-         p.rects[1] = function_eval(p.rects+2, &p);
-     if (!rb_tree_insert(&p.rtree, 0)) {
-         ret = NLOPT_FAILURE;
+         rnew[1] = function_eval(rnew+2, &p);
+     if (!rb_tree_insert(&p.rtree, rnew)) {
+         free(rnew);
          goto done;
      }
 
-     ret = divide_rect(&N, &Na, 0, &p);
+     ret = divide_rect(rnew, &p);
      if (ret != NLOPT_SUCCESS) goto done;
 
      while (1) {
          double fmin0 = p.fmin;
-         ret = divide_good_rects(&N, &Na, &p);
+         ret = divide_good_rects(&p);
          if (ret != NLOPT_SUCCESS) goto done;
          if (nlopt_stop_f(p.stop, p.fmin, fmin0)) {
               ret = NLOPT_FTOL_REACHED;
@@ -428,9 +414,9 @@ nlopt_result cdirect_unscaled(int n, nlopt_func f, void *f_data,
      }
 
  done:
-     rb_tree_destroy(&p.rtree);
+     rb_tree_destroy_with_keys(&p.rtree);
+     free(p.hull);
      free(p.iwork);
-     free(p.rects);
      free(p.work);
              
      *fmin = p.fmin;
index 2ddcebba21731e8b958da7e593f02c77a0328d47..44faf2d2f9267e4e8264fc636a6a97f643f2b788 100644 (file)
@@ -4,62 +4,48 @@
 #include <stdlib.h>
 #include "redblack.h"
 
-int rb_tree_init(rb_tree *t, rb_compare compare, void *c_data) {
-     t->compare = compare; t->c_data = c_data;
-     t->nil.c = BLACK; t->nil.l = t->nil.r = t->nil.p = &t->nil; t->nil.k = -1;
-     t->root = &t->nil;
+/* it is convenient to use an explicit node for NULL nodes ... we need
+   to be careful never to change this node indirectly via one of our
+   pointers!  */
+rb_node nil = {&nil, &nil, &nil, 0, BLACK};
+#define NIL (&nil)
+
+void rb_tree_init(rb_tree *t, rb_compare compare) {
+     t->compare = compare;
+     t->root = NIL;
      t->N = 0;
-     t->Nalloc = 100; /* allocate some space to start with */
-     t->nodes = (rb_node*) malloc(sizeof(rb_node) * t->Nalloc);
-     return t->nodes != NULL;
+}
+
+static void destroy(rb_node *n)
+{
+     if (n != NIL) {
+         destroy(n->l); destroy(n->r);
+         free(n);
+     }
 }
 
 void rb_tree_destroy(rb_tree *t)
 {
-     t->root = 0; t->N = 0; t->Nalloc = 0;
-     free(t->nodes); t->nodes = 0;
+     destroy(t->root);
+     t->root = NIL;
 }
 
-/* in our application, we can optimize memory allocation because
-   we never delete two nodes in a row (we always add a node after
-   deleting)... or rather, we never delete but the value of
-   the key sometimes changes.  ... this means we can just
-   allocate a linear, exponentially growing stack (nodes) of
-   nodes, and don't have to worry about holes in the stack ...
-   otherwise, alloc1 should be replaced by an implementation that
-   malloc's each node separately */
-static rb_node *alloc1(rb_tree *t, int k)
+void rb_tree_destroy_with_keys(rb_tree *t)
 {
-     rb_node *nil = &t->nil;
-     rb_node *n;
-     if (t->Nalloc == t->N) { /* grow allocation */
-         rb_node *old_nodes = t->nodes;
-         ptrdiff_t change;
-         int i;
-         t->Nalloc = 2*t->Nalloc + 1;
-         t->nodes = (rb_node*) realloc(t->nodes, sizeof(rb_node) * t->Nalloc);
-         if (!t->nodes) return NULL;
-         change = t->nodes - old_nodes;
-         if (t->root != nil) t->root += change;
-         for (i = 0; i < t->N; ++i) { /* shift all pointers, ugh */
-              if (t->nodes[i].p != nil) t->nodes[i].p += change;
-              if (t->nodes[i].r != nil) t->nodes[i].r += change;
-              if (t->nodes[i].l != nil) t->nodes[i].l += change;
-         }
+     rb_node *n = rb_tree_min(t);
+     while (n) {
+         free(n->k); n->k = NULL;
+         n = rb_tree_succ(n);
      }
-     n = t->nodes + t->N++;
-     n->k = k;
-     n->p = n->l = n->r = nil;
-     return n;
+     rb_tree_destroy(t);
 }
 
 static void rotate_left(rb_node *p, rb_tree *t)
 {
-     rb_node *nil = &t->nil;
-     rb_node *n = p->r; /* must be non-NULL */
+     rb_node *n = p->r; /* must be non-NIL */
      p->r = n->l;
      n->l = p;
-     if (p->p != nil) {
+     if (p->p != NIL) {
          if (p == p->p->l) p->p->l = n;
          else p->p->r = n;
      }
@@ -67,16 +53,15 @@ static void rotate_left(rb_node *p, rb_tree *t)
          t->root = n;
      n->p = p->p;
      p->p = n;
-     if (p->r != nil) p->r->p = p;
+     if (p->r != NIL) p->r->p = p;
 }
 
 static void rotate_right(rb_node *p, rb_tree *t)
 {
-     rb_node *nil = &t->nil;
-     rb_node *n = p->l; /* must be non-NULL */
+     rb_node *n = p->l; /* must be non-NIL */
      p->l = n->r;
      n->r = p;
-     if (p->p != nil) {
+     if (p->p != NIL) {
          if (p == p->p->l) p->p->l = n;
          else p->p->r = n;
      }
@@ -84,26 +69,26 @@ static void rotate_right(rb_node *p, rb_tree *t)
          t->root = n;
      n->p = p->p;
      p->p = n;
-     if (p->l != nil) p->l->p = p;
+     if (p->l != NIL) p->l->p = p;
 }
 
 static void insert_node(rb_tree *t, rb_node *n)
 {
-     rb_node *nil = &t->nil;
      rb_compare compare = t->compare;
-     void *c_data = t->c_data;
-     int k = n->k;
+     rb_key k = n->k;
      rb_node *p = t->root;
      n->c = RED;
-     if (p == nil) {
+     n->p = n->l = n->r = NIL;
+     t->N++;
+     if (p == NIL) {
          t->root = n;
          n->c = BLACK;
          return;
      }
      /* insert (RED) node into tree */
      while (1) {
-         if (compare(k, p->k, c_data) <= 0) { /* k <= p->k */
-              if (p->l != nil)
+         if (compare(k, p->k) <= 0) { /* k <= p->k */
+              if (p->l != NIL)
                    p = p->l;
               else {
                    p->l = n;
@@ -112,7 +97,7 @@ static void insert_node(rb_tree *t, rb_node *n)
               }
          }
          else {
-              if (p->r != nil)
+              if (p->r != NIL)
                    p = p->r;
               else {
                    p->r = n;
@@ -124,10 +109,10 @@ static void insert_node(rb_tree *t, rb_node *n)
  fixtree:
      if (n->p->c == RED) { /* red cannot have red child */
          rb_node *u = p == p->p->l ? p->p->r : p->p->l;
-         if (u != nil && u->c == RED) {
+         if (u != NIL && u->c == RED) {
               p->c = u->c = BLACK;
               n = p->p;
-              if ((p = n->p) != nil) {
+              if ((p = n->p) != NIL) {
                    n->c = RED;
                    goto fixtree;
               }
@@ -152,93 +137,64 @@ static void insert_node(rb_tree *t, rb_node *n)
      }
 }
 
-int rb_tree_insert(rb_tree *t, int k)
+rb_node *rb_tree_insert(rb_tree *t, rb_key k)
 {
-     rb_node *n = alloc1(t, k);
-     if (!n) return 0;
+     rb_node *n = (rb_node *) malloc(sizeof(rb_node));
+     if (!n) return NULL;
+     n->k = k;
      insert_node(t, n);
-     return 1;
+     return n;
 }
 
 static int check_node(rb_node *n, int *nblack, rb_tree *t)
 {
-     rb_node *nil = &t->nil;
      int nbl, nbr;
      rb_compare compare = t->compare;
-     void *c_data = t->c_data;
-     if (n == nil) { *nblack = 0; return 1; }
-     if (n->r != nil && n->r->p != n) return 0;
-     if (n->r != nil && compare(n->r->k, n->k, c_data) < 0)
+     if (n == NIL) { *nblack = 0; return 1; }
+     if (n->r != NIL && n->r->p != n) return 0;
+     if (n->r != NIL && compare(n->r->k, n->k) < 0)
          return 0;
-     if (n->l != nil && n->l->p != n) return 0;
-     if (n->l != nil && compare(n->l->k, n->k, c_data) > 0)
+     if (n->l != NIL && n->l->p != n) return 0;
+     if (n->l != NIL && compare(n->l->k, n->k) > 0)
          return 0;
      if (n->c == RED) {
-         if (n->r != nil && n->r->c == RED) return 0;
-         if (n->l != nil && n->l->c == RED) return 0;
+         if (n->r != NIL && n->r->c == RED) return 0;
+         if (n->l != NIL && n->l->c == RED) return 0;
      }
      if (!(check_node(n->r, &nbl, t) && check_node(n->l, &nbr, t))) 
          return 0;
      if (nbl != nbr) return 0;
-     *nblack = nbl + n->c == BLACK;
+     *nblack = nbl + (n->c == BLACK);
      return 1;
 }
 int rb_tree_check(rb_tree *t)
 {
-     rb_node *nil = &t->nil;
      int nblack;
-     if (nil->c != BLACK) return 0;
-     if (t->root == nil) return 1;
+     if (nil.c != BLACK) return 0;
+     if (nil.p != NIL || nil.r != NIL || nil.l != NIL) return 0;
+     if (t->root == NIL) return 1;
      if (t->root->c != BLACK) return 0;
      return check_node(t->root, &nblack, t);
 }
 
-rb_node *rb_tree_find(rb_tree *t, int k)
+rb_node *rb_tree_find(rb_tree *t, rb_key k)
 {
-     rb_node *nil = &t->nil;
      rb_compare compare = t->compare;
-     void *c_data = t->c_data;
      rb_node *p = t->root;
-     while (p != nil) {
-         int comp = compare(k, p->k, c_data);
+     while (p != NIL) {
+         int comp = compare(k, p->k);
          if (!comp) return p;
          p = comp <= 0 ? p->l : p->r;
      }
      return NULL;
 }
 
-/* like rb_tree_find, but guarantees that returned node n will have
-   n->k == k (may not be true above if compare(k,k') == 0 for some k != k') */
-rb_node *rb_tree_find_exact(rb_tree *t, int k)
-{
-     rb_node *nil = &t->nil;
-     rb_compare compare = t->compare;
-     void *c_data = t->c_data;
-     rb_node *p = t->root;
-     while (p != nil) {
-         int comp = compare(k, p->k, c_data);
-         if (!comp) break;
-         p = comp <= 0 ? p->l : p->r;
-     }
-     if (p == nil)
-         return NULL;
-     while (p->l != nil && !compare(k, p->l->k, c_data)) p = p->l;
-     if (p->l != nil) p = p->l;
-     do {
-         if (p->k == k) return p;
-         p = rb_tree_succ(t, p);
-     } while (p && compare(p->k, k, c_data) <= 0);
-     return NULL;
-}
-
 /* find greatest point in subtree p that is <= k */
-static rb_node *find_le(rb_node *p, int k, rb_tree *t)
+static rb_node *find_le(rb_node *p, rb_key k, rb_tree *t)
 {
-     rb_node *nil = &t->nil;
      rb_compare compare = t->compare;
-     void *c_data = t->c_data;
-     while (p != nil) {
-         if (compare(p->k, k, c_data) <= 0) { /* p->k <= k */
+     while (p != NIL) {
+         if (compare(p->k, k) <= 0) { /* p->k <= k */
               rb_node *r = find_le(p->r, k, t);
               if (r) return r;
               else return p;
@@ -250,19 +206,17 @@ static rb_node *find_le(rb_node *p, int k, rb_tree *t)
 }
 
 /* find greatest point in t <= k */
-rb_node *rb_tree_find_le(rb_tree *t, int k)
+rb_node *rb_tree_find_le(rb_tree *t, rb_key k)
 {
      return find_le(t->root, k, t);
 }
 
 /* find least point in subtree p that is > k */
-static rb_node *find_gt(rb_node *p, int k, rb_tree *t)
+static rb_node *find_gt(rb_node *p, rb_key k, rb_tree *t)
 {
-     rb_node *nil = &t->nil;
      rb_compare compare = t->compare;
-     void *c_data = t->c_data;
-     while (p != nil) {
-         if (compare(p->k, k, c_data) > 0) { /* p->k > k */
+     while (p != NIL) {
+         if (compare(p->k, k) > 0) { /* p->k > k */
               rb_node *l = find_gt(p->l, k, t);
               if (l) return l;
               else return p;
@@ -274,64 +228,60 @@ static rb_node *find_gt(rb_node *p, int k, rb_tree *t)
 }
 
 /* find least point in t > k */
-rb_node *rb_tree_find_gt(rb_tree *t, int k)
+rb_node *rb_tree_find_gt(rb_tree *t, rb_key k)
 {
      return find_gt(t->root, k, t);
 }
 
 rb_node *rb_tree_min(rb_tree *t)
 {
-     rb_node *nil = &t->nil;
      rb_node *n = t->root;
-     while (n != nil && n->l != nil)
+     while (n != NIL && n->l != NIL)
          n = n->l;
-     return(n == nil ? NULL : n);
+     return(n == NIL ? NULL : n);
 }
 
 rb_node *rb_tree_max(rb_tree *t)
 {
-     rb_node *nil = &t->nil;
      rb_node *n = t->root;
-     while (n != nil && n->r != nil)
+     while (n != NIL && n->r != NIL)
          n = n->r;
-     return(n == nil ? NULL : n);
+     return(n == NIL ? NULL : n);
 }
 
-rb_node *rb_tree_succ(rb_tree *t, rb_node *n)
+rb_node *rb_tree_succ(rb_node *n)
 {
-     rb_node *nil = &t->nil;
      if (!n) return NULL;
-     if (n->r == nil) {
+     if (n->r == NIL) {
          rb_node *prev;
          do {
               prev = n;
               n = n->p;
-         } while (prev == n->r && n != nil);
-         return n == nil ? NULL : n;
+         } while (prev == n->r && n != NIL);
+         return n == NIL ? NULL : n;
      }
      else {
          n = n->r;
-         while (n->l != nil)
+         while (n->l != NIL)
               n = n->l;
          return n;
      }
 }
 
-rb_node *rb_tree_pred(rb_tree *t, rb_node *n)
+rb_node *rb_tree_pred(rb_node *n)
 {
-     rb_node *nil = &t->nil;
      if (!n) return NULL;
-     if (n->l == nil) {
+     if (n->l == NIL) {
          rb_node *prev;
          do {
               prev = n;
               n = n->p;
-         } while (prev == n->l && n != nil);
-         return n == nil ? NULL : n;
+         } while (prev == n->l && n != NIL);
+         return n == NIL ? NULL : n;
      }
      else {
          n = n->l;
-         while (n->r != nil)
+         while (n->r != NIL)
               n = n->r;
          return n;
      }
@@ -339,85 +289,85 @@ rb_node *rb_tree_pred(rb_tree *t, rb_node *n)
 
 rb_node *rb_tree_remove(rb_tree *t, rb_node *n)
 {
-     rb_node *nil = &t->nil;
-     rb_node *m;
-     if (n->l != nil && n->r != nil) {
+     rb_key k = n->k;
+     rb_node *m, *mp;
+     if (n->l != NIL && n->r != NIL) {
          rb_node *lmax = n->l;
-         while (lmax->r != nil) lmax = lmax->r;
+         while (lmax->r != NIL) lmax = lmax->r;
          n->k = lmax->k;
          n = lmax;
      }
-     m = n->l != nil? n->l : n->r;
-     if (n->p != nil) {
+     m = n->l != NIL ? n->l : n->r;
+     if (n->p != NIL) {
          if (n->p->r == n) n->p->r = m;
          else n->p->l = m;
      }
      else
          t->root = m;
-     m->p = n->p;
+     mp = n->p;
+     if (m != NIL) m->p = mp;
      if (n->c == BLACK) {
          if (m->c == RED)
               m->c = BLACK;
          else {
          deleteblack:
-              if (m->p != nil) {
-                   rb_node *s = m == m->p->l ? m->p->r : m->p->l;
+              if (mp != NIL) {
+                   rb_node *s = m == mp->l ? mp->r : mp->l;
                    if (s->c == RED) {
-                        m->p->c = RED;
+                        mp->c = RED;
                         s->c = BLACK;
-                        if (m == m->p->l) rotate_left(m->p, t);
-                        else rotate_right(m->p, t);
-                        s = m == m->p->l ? m->p->r : m->p->l;
+                        if (m == mp->l) rotate_left(mp, t);
+                        else rotate_right(mp, t);
+                        s = m == mp->l ? mp->r : mp->l;
                    }
-                   if (m->p->c == BLACK && s->c == BLACK
+                   if (mp->c == BLACK && s->c == BLACK
                        && s->l->c == BLACK && s->r->c == BLACK) {
-                        if (s != nil) s->c = RED;
-                        m = m->p;
+                        if (s != NIL) s->c = RED;
+                        m = mp; mp = m->p;
                         goto deleteblack;
                    }
-                   else if (m->p->c == RED && s->c == BLACK &&
+                   else if (mp->c == RED && s->c == BLACK &&
                             s->l->c == BLACK && s->r->c == BLACK) {
-                        if (s != nil) s->c = RED;
-                        m->p->c = BLACK;
+                        if (s != NIL) s->c = RED;
+                        mp->c = BLACK;
                    }
                    else {
-                        if (m == m->p->l && s->c == BLACK &&
+                        if (m == mp->l && s->c == BLACK &&
                             s->l->c == RED && s->r->c == BLACK) {
                              s->c = RED;
                              s->l->c = BLACK;
                              rotate_right(s, t);
-                             s = m == m->p->l ? m->p->r : m->p->l;
+                             s = m == mp->l ? mp->r : mp->l;
                         }
-                        else if (m == m->p->r && s->c == BLACK &&
+                        else if (m == mp->r && s->c == BLACK &&
                                  s->r->c == RED && s->l->c == BLACK) {
                              s->c = RED;
                              s->r->c = BLACK;
                              rotate_left(s, t);
-                             s = m == m->p->l ? m->p->r : m->p->l;
+                             s = m == mp->l ? mp->r : mp->l;
                         }
-                        s->c = m->p->c;
-                        m->p->c = BLACK;
-                        if (m == m->p->l) {
+                        s->c = mp->c;
+                        mp->c = BLACK;
+                        if (m == mp->l) {
                              s->r->c = BLACK;
-                             rotate_left(m->p, t);
+                             rotate_left(mp, t);
                         }
                         else {
                              s->l->c = BLACK;
-                             rotate_right(m->p, t);
+                             rotate_right(mp, t);
                         }
                    }
               }
          }
      }
+     t->N--;
+     n->k = k; /* n may have changed during remove */
      return n; /* the node that was deleted may be different from initial n */
 }
 
 rb_node *rb_tree_resort(rb_tree *t, rb_node *n)
 {
-     int k = n->k;
      n = rb_tree_remove(t, n);
-     n->p = n->l = n->r = &t->nil;
-     n->k = k; /* n may have changed during remove */
      insert_node(t, n);
      return n;
 }
index fdc507a806065a720121618a96495dd76cb0a4ac..8fc9ae41f68da0485d5f67c91101dbe6838e17dd 100644 (file)
@@ -6,44 +6,38 @@ extern "C"
 {
 #endif /* __cplusplus */
 
+typedef double *rb_key; /* key type ... double* is convenient for us,
+                          but of course this could be cast to anything
+                          desired (although void* would look more generic) */
+
 typedef enum { RED, BLACK } rb_color;
 typedef struct rb_node_s {
      struct rb_node_s *p, *r, *l; /* parent, right, left */
-     int k; /* key/data ... for DIRECT, an index into our hyperrect array */
+     rb_key k; /* key (and data) */
      rb_color c;
 } rb_node;
 
-typedef int (*rb_compare)(int k1, int k2, void *c_data);
+typedef int (*rb_compare)(rb_key k1, rb_key k2);
 
 typedef struct {
-     rb_compare compare; void *c_data;
+     rb_compare compare;
      rb_node *root;
      int N; /* number of nodes */
-
-     /* in our application, we can optimize memory allocation because
-       we never delete two nodes in a row (we always add a node after
-       deleting)... or rather, we never delete but the value of
-        the key sometimes changes.  ... this means we can just
-       allocate a linear, exponentially growing stack (nodes) of
-       nodes, and don't have to worry about holes in the stack */
-     rb_node *nodes; /* allocated data of nodes, in some order */
-     int Nalloc; /* number of allocated nodes */
-     rb_node nil; /* explicit node for NULL nodes, for convenience */
 } rb_tree;
 
-extern int rb_tree_init(rb_tree *t, rb_compare compare, void *c_data);
+extern void rb_tree_init(rb_tree *t, rb_compare compare);
 extern void rb_tree_destroy(rb_tree *t);
-extern int rb_tree_insert(rb_tree *t, int k);
+extern void rb_tree_destroy_with_keys(rb_tree *t);
+extern rb_node *rb_tree_insert(rb_tree *t, rb_key k);
 extern int rb_tree_check(rb_tree *t);
-extern rb_node *rb_tree_find(rb_tree *t, int k);
-extern rb_node *rb_tree_find_exact(rb_tree *t, int k);
-extern rb_node *rb_tree_find_le(rb_tree *t, int k);
-extern rb_node *rb_tree_find_gt(rb_tree *t, int k);
+extern rb_node *rb_tree_find(rb_tree *t, rb_key k);
+extern rb_node *rb_tree_find_le(rb_tree *t, rb_key k);
+extern rb_node *rb_tree_find_gt(rb_tree *t, rb_key k);
 extern rb_node *rb_tree_resort(rb_tree *t, rb_node *n);
 extern rb_node *rb_tree_min(rb_tree *t);
 extern rb_node *rb_tree_max(rb_tree *t);
-extern rb_node *rb_tree_succ(rb_tree *t, rb_node *n);
-extern rb_node *rb_tree_pred(rb_tree *t, rb_node *n);
+extern rb_node *rb_tree_succ(rb_node *n);
+extern rb_node *rb_tree_pred(rb_node *n);
 
 /* To change a key, use rb_tree_find+resort.  Removing a node
    currently wastes memory unless you change the allocation scheme
index 003f90da2639fa0ff01ceac98ecb38a5e35bc587..fba0f7baef80d94f0cd19ad52d30af0fd17ea3cf 100644 (file)
@@ -1,37 +1,39 @@
 #include <stdio.h>
 #include <stdlib.h>
+#include <math.h>
 #include <time.h>
 #include "redblack.h"
 
-static int comp(int k1, int k2, void *dummy)
+static int comp(rb_key k1, rb_key k2)
 {
-     (void) dummy;
-     return k1 - k2;
+     if (*k1 < *k2) return -1;
+     if (*k1 > *k2) return +1;
+     return 0;
 }
 
 int main(int argc, char **argv)
 {
      int N, M;
      int *k;
+     double kd;
      rb_tree t;
      rb_node *n;
      int i, j;
 
-     if (argc != 2) {
-         fprintf(stderr, "Usage: redblack_test Ntest\n");
+     if (argc < 2) {
+         fprintf(stderr, "Usage: redblack_test Ntest [rand seed]\n");
          return 1;
      }
 
      N = atoi(argv[1]);
      k = (int *) malloc(N * sizeof(int));
-     if (!rb_tree_init(&t, comp, NULL)) {
-         fprintf(stderr, "error in rb_tree_init\n");
-         return 1;
-     }
+     rb_tree_init(&t, comp);
 
-     srand((unsigned) time(NULL));
+     srand((unsigned) (argc > 2 ? atoi(argv[2]) : time(NULL)));
      for (i = 0; i < N; ++i) {
-         if (!rb_tree_insert(&t, k[i] = rand() % N)) {
+         double *newk = (double *) malloc(sizeof(double));
+         *newk = (k[i] = rand() % N);
+         if (!rb_tree_insert(&t, newk)) {
               fprintf(stderr, "error in rb_tree_insert\n");
               return 1;
          }
@@ -41,20 +43,27 @@ int main(int argc, char **argv)
          }
      }
      
-     for (i = 0; i < N; ++i)
-         if (!rb_tree_find(&t, k[i]) || !rb_tree_find_exact(&t, k[i])) {
+     if (t.N != N) {
+         fprintf(stderr, "incorrect N (%d) in tree (vs. %d)\n", t.N, N);
+         return 1;
+     }
+
+     for (i = 0; i < N; ++i) {
+         kd = k[i];
+         if (!rb_tree_find(&t, &kd)) {
               fprintf(stderr, "rb_tree_find lost %d!\n", k[i]);
               return 1;
          }
-     
+     }
      n = rb_tree_min(&t);
      for (i = 0; i < N; ++i) {
          if (!n) {
               fprintf(stderr, "not enough successors %d\n!", i);
               return 1;
          }
-         printf("%d: %d\n", i, n->k);
-         n = rb_tree_succ(&t, n);
+         printf("%d: %g\n", i, n->k[0]);
+         n = rb_tree_succ(n);
      }
      if (n) {
          fprintf(stderr, "too many successors!\n");
@@ -67,8 +76,8 @@ int main(int argc, char **argv)
               fprintf(stderr, "not enough predecessors %d\n!", i);
               return 1;
          }
-         printf("%d: %d\n", i, n->k);
-         n = rb_tree_pred(&t, n);
+         printf("%d: %g\n", i, n->k[0]);
+         n = rb_tree_pred(n);
      }
      if (n) {
          fprintf(stderr, "too many predecessors!\n");
@@ -83,17 +92,19 @@ int main(int argc, char **argv)
                    if (j-- == 0)
                         break;
          if (i >= N) abort();
-         if (!(n = rb_tree_find(&t, k[i])) || !rb_tree_find_exact(&t, k[i])) {
+         kd = k[i];
+         if (!(n = rb_tree_find(&t, &kd))) {
                fprintf(stderr, "rb_tree_find lost %d!\n", k[i]);
                return 1;
           }
-         n->k = knew;
+         n->k[0] = knew;
          if (!rb_tree_resort(&t, n)) {
               fprintf(stderr, "error in rb_tree_resort\n");
               return 1;
          }
          if (!rb_tree_check(&t)) {
-              fprintf(stderr, "rb_tree_check_failed after change!\n");
+              fprintf(stderr, "rb_tree_check_failed after change %d!\n",
+                      N - M + 1);
               return 1;
          }
          k[i] = -1 - knew;
@@ -104,26 +115,25 @@ int main(int argc, char **argv)
          return 1;
      }
 
+     for (i = 0; i < N; ++i)
+         k[i] = -1 - k[i]; /* undo negation above */
+         
      for (i = 0; i < N; ++i) {
-         k[i] = -1 - k[i];
-         /* rescale keys by 100 to add more space between them */
-         k[i] *= 100;
-         t.nodes[i].k *= 100;
-     }
-
-     for (i = 0; i < N; ++i) {
-         int k = rand() % (N * 150) - N*25;
-         rb_node *le = rb_tree_find_le(&t, k);
-         rb_node *gt = rb_tree_find_gt(&t, k);
-         rb_node *n = rb_tree_min(&t);
-         printf("%d <= %d < %d\n", le? le->k:-999999, k, gt? gt->k:999999);
-         if (n->k > k) {
+         rb_node *le, *gt;
+         kd = 0.01 * (rand() % (N * 150) - N*25);
+         le = rb_tree_find_le(&t, &kd);
+         gt = rb_tree_find_gt(&t, &kd);
+         n = rb_tree_min(&t);
+         double lek = le ? le->k[0] : -HUGE_VAL;
+         double gtk = gt ? gt->k[0] : +HUGE_VAL;
+         printf("%g <= %g < %g\n", lek, kd, gtk);
+         if (n->k[0] > kd) {
               if (le) {
-                   fprintf(stderr, "found invalid le %d for %d\n", le->k, k);
+                   fprintf(stderr, "found invalid le %g for %g\n", lek, kd);
                    return 1;
               }
               if (gt != n) {
-                   fprintf(stderr, "gt is not first node for k=%d\n", k);
+                   fprintf(stderr, "gt is not first node for k=%g\n", kd);
                    return 1;
               }
          }
@@ -131,14 +141,16 @@ int main(int argc, char **argv)
               rb_node *succ = n;
               do {
                    n = succ;
-                   succ = rb_tree_succ(&t, n);
-              } while (succ && succ->k <= k);
+                   succ = rb_tree_succ(n);
+              } while (succ && succ->k[0] <= kd);
               if (n != le) {
-                   fprintf("rb_tree_find_le gave wrong result for k=%d\n", k);
+                   fprintf(stderr,
+                           "rb_tree_find_le gave wrong result for k=%g\n",kd);
                    return 1;
               }
               if (succ != gt) {
-                   fprintf("rb_tree_find_gt gave wrong result for k=%d\n", k);
+                   fprintf(stderr,
+                           "rb_tree_find_gt gave wrong result for k=%g\n",kd);
                    return 1;
               }
          }
@@ -151,11 +163,14 @@ int main(int argc, char **argv)
                    if (j-- == 0)
                         break;
          if (i >= N) abort();
-         if (!(n = rb_tree_find(&t, k[i])) || !rb_tree_find_exact(&t, k[i])) {
+         kd = k[i];
+         if (!(n = rb_tree_find(&t, &kd))) {
               fprintf(stderr, "rb_tree_find lost %d!\n", k[i]);
               return 1;
          }
-         rb_tree_remove(&t, n);
+         n = rb_tree_remove(&t, n);
+         free(n->k); 
+         free(n);
          if (!rb_tree_check(&t)) {
               fprintf(stderr, "rb_tree_check_failed after remove!\n");
               return 1;
@@ -163,6 +178,11 @@ int main(int argc, char **argv)
          k[i] = -1 - k[i];
      }
      
+     if (t.N != 0) {
+         fprintf(stderr, "nonzero N (%d) in tree at end\n", t.N);
+         return 1;
+     }
+
      rb_tree_destroy(&t);
      free(k);
      return 0;