From a4c19a8bc0b800a7876f185ff9854e3a3f37cf93 Mon Sep 17 00:00:00 2001 From: stevenj Date: Wed, 29 Aug 2007 00:45:08 -0400 Subject: [PATCH] saner memory management (and better performance), improved handling of case when no potentially optimal rects are found (choose largest rect with smallest function val, greatly improving many test cases!) darcs-hash:20070829044508-c8de0-635e1db270782d6b6741c27994cdc596c4cdde2c.gz --- cdirect/cdirect.c | 240 +++++++++++++++++------------------ cdirect/redblack.c | 274 ++++++++++++++++------------------------ cdirect/redblack.h | 36 +++--- cdirect/redblack_test.c | 104 +++++++++------ 4 files changed, 302 insertions(+), 352 deletions(-) diff --git a/cdirect/cdirect.c b/cdirect/cdirect.c index 05ec2c7..5b7200c 100644 --- a/cdirect/cdirect.c +++ b/cdirect/cdirect.c @@ -33,7 +33,6 @@ typedef struct { int n; /* dimension */ int L; /* RECT_LEN(n) */ - double *rects; /* the hyper-rectangles */ double magic_eps; /* Jones' epsilon parameter (1e-4 is recommended) */ int which_diam; /* which measure of hyper-rectangle diam to use: 0 = Jones, 1 = Gablonsky */ @@ -45,12 +44,14 @@ typedef struct { nlopt_stopping *stop; /* stopping criteria */ nlopt_func f; void *f_data; double *work; /* workspace, of length >= 2*n */ - int *iwork, iwork_len; /* workspace, of length iwork_len >= n */ + int *iwork; /* workspace, length >= n */ double fmin, *xmin; /* minimum so far */ - /* red-black tree of hyperrect indices, sorted by (d,f) in + /* red-black tree of hyperrects, sorted by (d,f) in lexographical order */ rb_tree rtree; + double **hull; /* array to store convex hull */ + int hull_len; /* allocated length of hull array */ } params; /***************************************************************************/ @@ -74,16 +75,7 @@ static double rect_diameter(int n, const double *w, const params *p) } } -static double *alloc_rects(int n, int *Na, double *rects, int newN) -{ - if (newN <= *Na) - return rects; - else { - (*Na) += newN; - return realloc(rects, sizeof(double) * RECT_LEN(n) * (*Na)); - } -} -#define ALLOC_RECTS(n, Nap, rects, newN) if (!(rects = alloc_rects(n, Nap, rects, newN))) return NLOPT_OUT_OF_MEMORY +#define ALLOC_RECT(rect, L) if (!(rect = (double*) malloc(sizeof(double)*(L)))) return NLOPT_OUT_OF_MEMORY static double *fv_qsort = 0; static int sort_fv_compare(const void *a_, const void *b_) @@ -116,20 +108,19 @@ static double function_eval(const double *x, params *p) { p->stop->nevals++; return f; } -#define FUNCTION_EVAL(fv,x,p) fv = function_eval(x, p); if (p->fmin < p->stop->fmin_max) return NLOPT_FMIN_MAX_REACHED; else if (nlopt_stop_evals((p)->stop)) return NLOPT_MAXEVAL_REACHED; else if (nlopt_stop_time((p)->stop)) return NLOPT_MAXTIME_REACHED +#define FUNCTION_EVAL(fv,x,p,freeonerr) fv = function_eval(x, p); if (p->fmin < p->stop->fmin_max) { free(freeonerr); return NLOPT_FMIN_MAX_REACHED; } else if (nlopt_stop_evals((p)->stop)) { free(freeonerr); return NLOPT_MAXEVAL_REACHED; } else if (nlopt_stop_time((p)->stop)) { free(freeonerr); return NLOPT_MAXTIME_REACHED; } #define THIRD (0.3333333333333333333333) #define EQUAL_SIDE_TOL 5e-2 /* tolerance to equate side sizes */ /* divide rectangle idiv in the list p->rects */ -static nlopt_result divide_rect(int *N, int *Na, int idiv, params *p) +static nlopt_result divide_rect(double *rdiv, params *p) { int i; const const int n = p->n; const int L = p->L; - double *r = p->rects; - double *c = r + L*idiv + 2; /* center of rect to divide */ + double *c = rdiv + 2; /* center of rect to divide */ double *w = c + n; /* widths of rect to divide */ double wmax = w[0]; int imax = 0, nlongest = 0; @@ -150,9 +141,9 @@ static nlopt_result divide_rect(int *N, int *Na, int idiv, params *p) if (wmax - w[i] <= wmax * EQUAL_SIDE_TOL) { double csave = c[i]; c[i] = csave - w[i] * THIRD; - FUNCTION_EVAL(fv[2*i], c, p); + FUNCTION_EVAL(fv[2*i], c, p, 0); c[i] = csave + w[i] * THIRD; - FUNCTION_EVAL(fv[2*i+1], c, p); + FUNCTION_EVAL(fv[2*i+1], c, p, 0); c[i] = csave; } else { @@ -160,23 +151,23 @@ static nlopt_result divide_rect(int *N, int *Na, int idiv, params *p) } } sort_fv(n, fv, isort); - ALLOC_RECTS(n, Na, r, (*N)+2*nlongest); - p->rects = r; c = r + L*idiv + 2; w = c + n; + if (!(node = rb_tree_find(&p->rtree, rdiv))) + return NLOPT_FAILURE; for (i = 0; i < nlongest; ++i) { int k; - if (!(node = rb_tree_find_exact(&p->rtree, idiv))) - return NLOPT_FAILURE; w[isort[i]] *= THIRD; - r[L*idiv] = rect_diameter(n, w, p); - rb_tree_resort(&p->rtree, node); + rdiv[0] = rect_diameter(n, w, p); + node = rb_tree_resort(&p->rtree, node); for (k = 0; k <= 1; ++k) { - r[L*(*N)] = r[L*idiv]; - memcpy(r + L*(*N) + 2, c, sizeof(double) * 2*n); - r[L*(*N) + 2 + isort[i]] += w[isort[i]] * (2*k-1); - r[L*(*N) + 1] = fv[2*isort[i]+k]; - if (!rb_tree_insert(&p->rtree, *N)) + double *rnew; + ALLOC_RECT(rnew, L); + memcpy(rnew, rdiv, sizeof(double) * L); + rnew[2 + isort[i]] += w[isort[i]] * (2*k-1); + rnew[1] = fv[2*isort[i]+k]; + if (!rb_tree_insert(&p->rtree, rnew)) { + free(rnew); return NLOPT_FAILURE; - ++(*N); + } } } } @@ -193,21 +184,21 @@ static nlopt_result divide_rect(int *N, int *Na, int idiv, params *p) } else i = imax; /* trisect longest side */ - ALLOC_RECTS(n, Na, r, (*N)+2); - p->rects = r; c = r + L*idiv + 2; w = c + n; - if (!(node = rb_tree_find_exact(&p->rtree, idiv))) + if (!(node = rb_tree_find(&p->rtree, rdiv))) return NLOPT_FAILURE; w[i] *= THIRD; - r[L*idiv] = rect_diameter(n, w, p); - rb_tree_resort(&p->rtree, node); + rdiv[0] = rect_diameter(n, w, p); + node = rb_tree_resort(&p->rtree, node); for (k = 0; k <= 1; ++k) { - r[L*(*N)] = r[L*idiv]; - memcpy(r + L*(*N) + 2, c, sizeof(double) * 2*n); - r[L*(*N) + 2 + i] += w[i] * (2*k-1); - FUNCTION_EVAL(r[L*(*N) + 1], r + L*(*N) + 2, p); - if (!rb_tree_insert(&p->rtree, *N)) + double *rnew; + ALLOC_RECT(rnew, L); + memcpy(rnew, rdiv, sizeof(double) * L); + rnew[2 + i] += w[i] * (2*k-1); + FUNCTION_EVAL(rnew[1], rnew + 2, p, rnew); + if (!rb_tree_insert(&p->rtree, rnew)) { + free(rnew); return NLOPT_FAILURE; - ++(*N); + } } } return NLOPT_SUCCESS; @@ -217,65 +208,62 @@ static nlopt_result divide_rect(int *N, int *Na, int idiv, params *p) /* O(N log N) convex hull algorithm, used later to find the potentially optimal points */ -/* Find the lower convex hull of a set of points (xy[s*i], xy[s*i+1]), where - 0 <= i < N and s >= 2. +/* Find the lower convex hull of a set of points (x,y) stored in a rb-tree + of pointers to {x,y} arrays sorted in lexographic order by (x,y). Unlike standard convex hulls, we allow redundant points on the hull. - The return value is the number of points in the hull, with indices - stored in ihull. ihull should point to arrays of length >= N. - rb_tree should be a red-black tree of indices (keys == i) sorted - in lexographic order by (xy[s*i], xy[s*i+1]). + The return value is the number of points in the hull, with pointers + stored in hull[i] (should be an array of length >= t->N). */ -static int convex_hull(int N, double *xy, int s, int *ihull, rb_tree *t) +static int convex_hull(rb_tree *t, double **hull) { - int nhull; + int nhull = 0; double minslope; double xmin, xmax, yminmin, ymaxmin; rb_node *n, *nmax; - if (N == 0) return 0; - /* Monotone chain algorithm [Andrew, 1979]. */ n = rb_tree_min(t); + if (!n) return 0; nmax = rb_tree_max(t); - xmin = xy[s*(n->k)]; - yminmin = xy[s*(n->k)+1]; - xmax = xy[s*(nmax->k)]; + xmin = n->k[0]; + yminmin = n->k[1]; + xmax = nmax->k[0]; - ihull[nhull = 1] = n->k; + hull[nhull++] = n->k; if (xmin == xmax) return nhull; /* set nmax = min mode with x == xmax */ - while (xy[s*(nmax->k)] == xmax) - nmax = rb_tree_pred(t, nmax); /* non-NULL since xmin != xmax */ - nmax = rb_tree_succ(t, nmax); + while (nmax->k[0] == xmax) + nmax = rb_tree_pred(nmax); /* non-NULL since xmin != xmax */ + nmax = rb_tree_succ(nmax); - ymaxmin = xy[s*(nmax->k)+1]; + ymaxmin = nmax->k[1]; minslope = (ymaxmin - yminmin) / (xmax - xmin); /* set n = first node with x != xmin */ - while (xy[s*(n->k)] == xmin) - n = rb_tree_succ(t, n); /* non-NULL since xmin != xmax */ + while (n->k[0] == xmin) + n = rb_tree_succ(n); /* non-NULL since xmin != xmax */ - for (; n != nmax; n = rb_tree_succ(t, n)) { - int k = n->k; - if (xy[s*k+1] > yminmin + (xy[s*k] - xmin) * minslope) + for (; n != nmax; n = rb_tree_succ(n)) { + double *k = n->k; + if (k[1] > yminmin + (k[0] - xmin) * minslope) continue; /* remove points until we are making a "left turn" to k */ while (nhull > 1) { - int t1 = ihull[nhull - 1], t2 = ihull[nhull - 2]; + double *t1 = hull[nhull - 1], *t2 = hull[nhull - 2]; /* cross product (t1-t2) x (k-t2) > 0 for a left turn: */ - if ((xy[s*t1]-xy[s*t2]) * (xy[s*k+1]-xy[s*t2+1]) - - (xy[s*t1+1]-xy[s*t2+1]) * (xy[s*k]-xy[s*t2]) >= 0) + if ((t1[0]-t2[0]) * (k[1]-t2[1]) + - (t1[1]-t2[1]) * (k[0]-t2[0]) >= 0) break; --nhull; } - ihull[nhull++] = k; + hull[nhull++] = k; } - ihull[nhull++] = nmax->k; + hull[nhull++] = nmax->k; return nhull; } @@ -291,41 +279,34 @@ static int small(double *w, params *p) return 1; } -static nlopt_result divide_good_rects(int *N, int *Na, params *p) +static nlopt_result divide_good_rects(params *p) { const int n = p->n; - const int L = p->L; - int *ihull, nhull, i, xtol_reached = 1, divided_some = 0; - double *r = p->rects; + double **hull; + int nhull, i, xtol_reached = 1, divided_some = 0; double magic_eps = p->magic_eps; - if (p->iwork_len < *N) { - p->iwork_len = p->iwork_len + *N; - p->iwork = (int *) realloc(p->iwork, sizeof(int) * p->iwork_len); - if (!p->iwork) - return NLOPT_OUT_OF_MEMORY; + if (p->hull_len < p->rtree.N) { + p->hull_len += p->rtree.N; + p->hull = (double **) realloc(p->hull, sizeof(double*)*p->hull_len); + if (!p->hull) return NLOPT_OUT_OF_MEMORY; } - ihull = p->iwork; - nhull = convex_hull(*N, r, L, ihull, &p->rtree); + nhull = convex_hull(&p->rtree, hull = p->hull); divisions: for (i = 0; i < nhull; ++i) { double K1 = -HUGE_VAL, K2 = -HUGE_VAL, K; if (i > 0) - K1 = (r[L*ihull[i]+1] - r[L*ihull[i-1]+1]) / - (r[L*ihull[i]] - r[L*ihull[i-1]]); + K1 = (hull[i][1] - hull[i-1][1]) / (hull[i][0] - hull[i-1][0]); if (i < nhull-1) - K1 = (r[L*ihull[i]+1] - r[L*ihull[i+1]+1]) / - (r[L*ihull[i]] - r[L*ihull[i+1]]); + K1 = (hull[i][1] - hull[i+1][1]) / (hull[i][0] - hull[i+1][0]); K = MAX(K1, K2); - if (r[L*ihull[i]+1] - K * r[L*ihull[i]] + if (hull[i][1] - K * hull[i][0] <= p->fmin - magic_eps * fabs(p->fmin)) { /* "potentially optimal" rectangle, so subdivide */ divided_some = 1; - nlopt_result ret; - ret = divide_rect(N, Na, ihull[i], p); - r = p->rects; /* may have grown */ + nlopt_result ret = divide_rect(hull[i], p); if (ret != NLOPT_SUCCESS) return ret; - xtol_reached = xtol_reached && small(r + L*ihull[i] + 2+n, p); + xtol_reached = xtol_reached && small(hull[i] + 2+n, p); } } if (!divided_some) { @@ -333,13 +314,19 @@ static nlopt_result divide_good_rects(int *N, int *Na, params *p) magic_eps = 0; goto divisions; /* try again */ } - else { /* WTF? divide largest rectangle */ - double wmax = r[0]; - int imax = 0; - for (i = 1; i < *N; ++i) - if (r[L*i] > wmax) - wmax = r[L*(imax=i)]; - return divide_rect(N, Na, imax, p); + else { /* WTF? divide largest rectangle with smallest f */ + /* (note that this code actually gets called from time + to time, and the heuristic here seems to work well, + but I don't recall this situation being discussed in + the references?) */ + rb_node *max = rb_tree_max(&p->rtree); + rb_node *pred = max; + double wmax = max->k[0]; + do { /* note: this loop is O(N) worst-case time */ + max = pred; + pred = rb_tree_pred(max); + } while (pred && pred->k[0] == wmax); + return divide_rect(max->k, p); } } return xtol_reached ? NLOPT_XTOL_REACHED : NLOPT_SUCCESS; @@ -348,18 +335,13 @@ static nlopt_result divide_good_rects(int *N, int *Na, params *p) /***************************************************************************/ /* lexographic sort order (d,f) of hyper-rects, for red-black tree */ -static int hyperrect_compare(int i, int j, void *p_) +static int hyperrect_compare(double *a, double *b) { - params *p = (params *) p_; - int L = p->L; - double *r = p->rects; - double di = r[i*L], dj = r[j*L], fi, fj; - if (di < dj) return -1; - if (dj < di) return +1; - fi = r[i*L+1]; fj = r[j*L+1]; - if (fi < fj) return -1; - if (fj < fi) return +1; - return 0; + if (a[0] < b[0]) return -1; + if (a[0] > b[0]) return +1; + if (a[1] < b[1]) return -1; + if (a[1] > b[1]) return +1; + return (int) (a - b); /* tie-breaker */ } /***************************************************************************/ @@ -372,7 +354,8 @@ nlopt_result cdirect_unscaled(int n, nlopt_func f, void *f_data, double magic_eps, int which_alg) { params p; - int Na = 100, N = 1, i, x_center = 1; + int i, x_center = 1; + double *rnew; nlopt_result ret = NLOPT_OUT_OF_MEMORY; p.magic_eps = magic_eps; @@ -388,38 +371,41 @@ nlopt_result cdirect_unscaled(int n, nlopt_func f, void *f_data, p.fmin = f(n, x, NULL, f_data); stop->nevals++; p.work = 0; p.iwork = 0; - p.rects = 0; + p.hull = 0; + + rb_tree_init(&p.rtree, hyperrect_compare); - if (!rb_tree_init(&p.rtree, hyperrect_compare, &p)) goto done; - p.work = (double *) malloc(sizeof(double) * 2*n); + p.work = (double *) malloc(sizeof(double) * (2*n)); if (!p.work) goto done; - p.rects = (double *) malloc(sizeof(double) * Na * RECT_LEN(n)); - if (!p.rects) goto done; - p.iwork = (int *) malloc(sizeof(int) * (p.iwork_len = Na + n)); + p.iwork = (int *) malloc(sizeof(int) * n); if (!p.iwork) goto done; + p.hull_len = 128; /* start with a reasonable number */ + p.hull = (double **) malloc(sizeof(double *) * p.hull_len); + if (!p.hull) goto done; + if (!(rnew = (double *) malloc(sizeof(double) * p.L))) goto done; for (i = 0; i < n; ++i) { - p.rects[2+i] = 0.5 * (lb[i] + ub[i]); + rnew[2+i] = 0.5 * (lb[i] + ub[i]); x_center = x_center - && (fabs(p.rects[2+i]-x[i]) < 1e-13*(1+fabs(x[i]))); - p.rects[2+n+i] = ub[i] - lb[i]; + && (fabs(rnew[2+i]-x[i]) < 1e-13*(1+fabs(x[i]))); + rnew[2+n+i] = ub[i] - lb[i]; } - p.rects[0] = rect_diameter(n, p.rects+2+n, &p); + rnew[0] = rect_diameter(n, rnew+2+n, &p); if (x_center) - p.rects[1] = p.fmin; /* avoid computing f(center) twice */ + rnew[1] = p.fmin; /* avoid computing f(center) twice */ else - p.rects[1] = function_eval(p.rects+2, &p); - if (!rb_tree_insert(&p.rtree, 0)) { - ret = NLOPT_FAILURE; + rnew[1] = function_eval(rnew+2, &p); + if (!rb_tree_insert(&p.rtree, rnew)) { + free(rnew); goto done; } - ret = divide_rect(&N, &Na, 0, &p); + ret = divide_rect(rnew, &p); if (ret != NLOPT_SUCCESS) goto done; while (1) { double fmin0 = p.fmin; - ret = divide_good_rects(&N, &Na, &p); + ret = divide_good_rects(&p); if (ret != NLOPT_SUCCESS) goto done; if (nlopt_stop_f(p.stop, p.fmin, fmin0)) { ret = NLOPT_FTOL_REACHED; @@ -428,9 +414,9 @@ nlopt_result cdirect_unscaled(int n, nlopt_func f, void *f_data, } done: - rb_tree_destroy(&p.rtree); + rb_tree_destroy_with_keys(&p.rtree); + free(p.hull); free(p.iwork); - free(p.rects); free(p.work); *fmin = p.fmin; diff --git a/cdirect/redblack.c b/cdirect/redblack.c index 2ddcebb..44faf2d 100644 --- a/cdirect/redblack.c +++ b/cdirect/redblack.c @@ -4,62 +4,48 @@ #include #include "redblack.h" -int rb_tree_init(rb_tree *t, rb_compare compare, void *c_data) { - t->compare = compare; t->c_data = c_data; - t->nil.c = BLACK; t->nil.l = t->nil.r = t->nil.p = &t->nil; t->nil.k = -1; - t->root = &t->nil; +/* it is convenient to use an explicit node for NULL nodes ... we need + to be careful never to change this node indirectly via one of our + pointers! */ +rb_node nil = {&nil, &nil, &nil, 0, BLACK}; +#define NIL (&nil) + +void rb_tree_init(rb_tree *t, rb_compare compare) { + t->compare = compare; + t->root = NIL; t->N = 0; - t->Nalloc = 100; /* allocate some space to start with */ - t->nodes = (rb_node*) malloc(sizeof(rb_node) * t->Nalloc); - return t->nodes != NULL; +} + +static void destroy(rb_node *n) +{ + if (n != NIL) { + destroy(n->l); destroy(n->r); + free(n); + } } void rb_tree_destroy(rb_tree *t) { - t->root = 0; t->N = 0; t->Nalloc = 0; - free(t->nodes); t->nodes = 0; + destroy(t->root); + t->root = NIL; } -/* in our application, we can optimize memory allocation because - we never delete two nodes in a row (we always add a node after - deleting)... or rather, we never delete but the value of - the key sometimes changes. ... this means we can just - allocate a linear, exponentially growing stack (nodes) of - nodes, and don't have to worry about holes in the stack ... - otherwise, alloc1 should be replaced by an implementation that - malloc's each node separately */ -static rb_node *alloc1(rb_tree *t, int k) +void rb_tree_destroy_with_keys(rb_tree *t) { - rb_node *nil = &t->nil; - rb_node *n; - if (t->Nalloc == t->N) { /* grow allocation */ - rb_node *old_nodes = t->nodes; - ptrdiff_t change; - int i; - t->Nalloc = 2*t->Nalloc + 1; - t->nodes = (rb_node*) realloc(t->nodes, sizeof(rb_node) * t->Nalloc); - if (!t->nodes) return NULL; - change = t->nodes - old_nodes; - if (t->root != nil) t->root += change; - for (i = 0; i < t->N; ++i) { /* shift all pointers, ugh */ - if (t->nodes[i].p != nil) t->nodes[i].p += change; - if (t->nodes[i].r != nil) t->nodes[i].r += change; - if (t->nodes[i].l != nil) t->nodes[i].l += change; - } + rb_node *n = rb_tree_min(t); + while (n) { + free(n->k); n->k = NULL; + n = rb_tree_succ(n); } - n = t->nodes + t->N++; - n->k = k; - n->p = n->l = n->r = nil; - return n; + rb_tree_destroy(t); } static void rotate_left(rb_node *p, rb_tree *t) { - rb_node *nil = &t->nil; - rb_node *n = p->r; /* must be non-NULL */ + rb_node *n = p->r; /* must be non-NIL */ p->r = n->l; n->l = p; - if (p->p != nil) { + if (p->p != NIL) { if (p == p->p->l) p->p->l = n; else p->p->r = n; } @@ -67,16 +53,15 @@ static void rotate_left(rb_node *p, rb_tree *t) t->root = n; n->p = p->p; p->p = n; - if (p->r != nil) p->r->p = p; + if (p->r != NIL) p->r->p = p; } static void rotate_right(rb_node *p, rb_tree *t) { - rb_node *nil = &t->nil; - rb_node *n = p->l; /* must be non-NULL */ + rb_node *n = p->l; /* must be non-NIL */ p->l = n->r; n->r = p; - if (p->p != nil) { + if (p->p != NIL) { if (p == p->p->l) p->p->l = n; else p->p->r = n; } @@ -84,26 +69,26 @@ static void rotate_right(rb_node *p, rb_tree *t) t->root = n; n->p = p->p; p->p = n; - if (p->l != nil) p->l->p = p; + if (p->l != NIL) p->l->p = p; } static void insert_node(rb_tree *t, rb_node *n) { - rb_node *nil = &t->nil; rb_compare compare = t->compare; - void *c_data = t->c_data; - int k = n->k; + rb_key k = n->k; rb_node *p = t->root; n->c = RED; - if (p == nil) { + n->p = n->l = n->r = NIL; + t->N++; + if (p == NIL) { t->root = n; n->c = BLACK; return; } /* insert (RED) node into tree */ while (1) { - if (compare(k, p->k, c_data) <= 0) { /* k <= p->k */ - if (p->l != nil) + if (compare(k, p->k) <= 0) { /* k <= p->k */ + if (p->l != NIL) p = p->l; else { p->l = n; @@ -112,7 +97,7 @@ static void insert_node(rb_tree *t, rb_node *n) } } else { - if (p->r != nil) + if (p->r != NIL) p = p->r; else { p->r = n; @@ -124,10 +109,10 @@ static void insert_node(rb_tree *t, rb_node *n) fixtree: if (n->p->c == RED) { /* red cannot have red child */ rb_node *u = p == p->p->l ? p->p->r : p->p->l; - if (u != nil && u->c == RED) { + if (u != NIL && u->c == RED) { p->c = u->c = BLACK; n = p->p; - if ((p = n->p) != nil) { + if ((p = n->p) != NIL) { n->c = RED; goto fixtree; } @@ -152,93 +137,64 @@ static void insert_node(rb_tree *t, rb_node *n) } } -int rb_tree_insert(rb_tree *t, int k) +rb_node *rb_tree_insert(rb_tree *t, rb_key k) { - rb_node *n = alloc1(t, k); - if (!n) return 0; + rb_node *n = (rb_node *) malloc(sizeof(rb_node)); + if (!n) return NULL; + n->k = k; insert_node(t, n); - return 1; + return n; } static int check_node(rb_node *n, int *nblack, rb_tree *t) { - rb_node *nil = &t->nil; int nbl, nbr; rb_compare compare = t->compare; - void *c_data = t->c_data; - if (n == nil) { *nblack = 0; return 1; } - if (n->r != nil && n->r->p != n) return 0; - if (n->r != nil && compare(n->r->k, n->k, c_data) < 0) + if (n == NIL) { *nblack = 0; return 1; } + if (n->r != NIL && n->r->p != n) return 0; + if (n->r != NIL && compare(n->r->k, n->k) < 0) return 0; - if (n->l != nil && n->l->p != n) return 0; - if (n->l != nil && compare(n->l->k, n->k, c_data) > 0) + if (n->l != NIL && n->l->p != n) return 0; + if (n->l != NIL && compare(n->l->k, n->k) > 0) return 0; if (n->c == RED) { - if (n->r != nil && n->r->c == RED) return 0; - if (n->l != nil && n->l->c == RED) return 0; + if (n->r != NIL && n->r->c == RED) return 0; + if (n->l != NIL && n->l->c == RED) return 0; } if (!(check_node(n->r, &nbl, t) && check_node(n->l, &nbr, t))) return 0; if (nbl != nbr) return 0; - *nblack = nbl + n->c == BLACK; + *nblack = nbl + (n->c == BLACK); return 1; } int rb_tree_check(rb_tree *t) { - rb_node *nil = &t->nil; int nblack; - if (nil->c != BLACK) return 0; - if (t->root == nil) return 1; + if (nil.c != BLACK) return 0; + if (nil.p != NIL || nil.r != NIL || nil.l != NIL) return 0; + if (t->root == NIL) return 1; if (t->root->c != BLACK) return 0; return check_node(t->root, &nblack, t); } -rb_node *rb_tree_find(rb_tree *t, int k) +rb_node *rb_tree_find(rb_tree *t, rb_key k) { - rb_node *nil = &t->nil; rb_compare compare = t->compare; - void *c_data = t->c_data; rb_node *p = t->root; - while (p != nil) { - int comp = compare(k, p->k, c_data); + while (p != NIL) { + int comp = compare(k, p->k); if (!comp) return p; p = comp <= 0 ? p->l : p->r; } return NULL; } -/* like rb_tree_find, but guarantees that returned node n will have - n->k == k (may not be true above if compare(k,k') == 0 for some k != k') */ -rb_node *rb_tree_find_exact(rb_tree *t, int k) -{ - rb_node *nil = &t->nil; - rb_compare compare = t->compare; - void *c_data = t->c_data; - rb_node *p = t->root; - while (p != nil) { - int comp = compare(k, p->k, c_data); - if (!comp) break; - p = comp <= 0 ? p->l : p->r; - } - if (p == nil) - return NULL; - while (p->l != nil && !compare(k, p->l->k, c_data)) p = p->l; - if (p->l != nil) p = p->l; - do { - if (p->k == k) return p; - p = rb_tree_succ(t, p); - } while (p && compare(p->k, k, c_data) <= 0); - return NULL; -} - /* find greatest point in subtree p that is <= k */ -static rb_node *find_le(rb_node *p, int k, rb_tree *t) +static rb_node *find_le(rb_node *p, rb_key k, rb_tree *t) { - rb_node *nil = &t->nil; rb_compare compare = t->compare; - void *c_data = t->c_data; - while (p != nil) { - if (compare(p->k, k, c_data) <= 0) { /* p->k <= k */ + while (p != NIL) { + if (compare(p->k, k) <= 0) { /* p->k <= k */ rb_node *r = find_le(p->r, k, t); if (r) return r; else return p; @@ -250,19 +206,17 @@ static rb_node *find_le(rb_node *p, int k, rb_tree *t) } /* find greatest point in t <= k */ -rb_node *rb_tree_find_le(rb_tree *t, int k) +rb_node *rb_tree_find_le(rb_tree *t, rb_key k) { return find_le(t->root, k, t); } /* find least point in subtree p that is > k */ -static rb_node *find_gt(rb_node *p, int k, rb_tree *t) +static rb_node *find_gt(rb_node *p, rb_key k, rb_tree *t) { - rb_node *nil = &t->nil; rb_compare compare = t->compare; - void *c_data = t->c_data; - while (p != nil) { - if (compare(p->k, k, c_data) > 0) { /* p->k > k */ + while (p != NIL) { + if (compare(p->k, k) > 0) { /* p->k > k */ rb_node *l = find_gt(p->l, k, t); if (l) return l; else return p; @@ -274,64 +228,60 @@ static rb_node *find_gt(rb_node *p, int k, rb_tree *t) } /* find least point in t > k */ -rb_node *rb_tree_find_gt(rb_tree *t, int k) +rb_node *rb_tree_find_gt(rb_tree *t, rb_key k) { return find_gt(t->root, k, t); } rb_node *rb_tree_min(rb_tree *t) { - rb_node *nil = &t->nil; rb_node *n = t->root; - while (n != nil && n->l != nil) + while (n != NIL && n->l != NIL) n = n->l; - return(n == nil ? NULL : n); + return(n == NIL ? NULL : n); } rb_node *rb_tree_max(rb_tree *t) { - rb_node *nil = &t->nil; rb_node *n = t->root; - while (n != nil && n->r != nil) + while (n != NIL && n->r != NIL) n = n->r; - return(n == nil ? NULL : n); + return(n == NIL ? NULL : n); } -rb_node *rb_tree_succ(rb_tree *t, rb_node *n) +rb_node *rb_tree_succ(rb_node *n) { - rb_node *nil = &t->nil; if (!n) return NULL; - if (n->r == nil) { + if (n->r == NIL) { rb_node *prev; do { prev = n; n = n->p; - } while (prev == n->r && n != nil); - return n == nil ? NULL : n; + } while (prev == n->r && n != NIL); + return n == NIL ? NULL : n; } else { n = n->r; - while (n->l != nil) + while (n->l != NIL) n = n->l; return n; } } -rb_node *rb_tree_pred(rb_tree *t, rb_node *n) +rb_node *rb_tree_pred(rb_node *n) { - rb_node *nil = &t->nil; if (!n) return NULL; - if (n->l == nil) { + if (n->l == NIL) { rb_node *prev; do { prev = n; n = n->p; - } while (prev == n->l && n != nil); - return n == nil ? NULL : n; + } while (prev == n->l && n != NIL); + return n == NIL ? NULL : n; } else { n = n->l; - while (n->r != nil) + while (n->r != NIL) n = n->r; return n; } @@ -339,85 +289,85 @@ rb_node *rb_tree_pred(rb_tree *t, rb_node *n) rb_node *rb_tree_remove(rb_tree *t, rb_node *n) { - rb_node *nil = &t->nil; - rb_node *m; - if (n->l != nil && n->r != nil) { + rb_key k = n->k; + rb_node *m, *mp; + if (n->l != NIL && n->r != NIL) { rb_node *lmax = n->l; - while (lmax->r != nil) lmax = lmax->r; + while (lmax->r != NIL) lmax = lmax->r; n->k = lmax->k; n = lmax; } - m = n->l != nil? n->l : n->r; - if (n->p != nil) { + m = n->l != NIL ? n->l : n->r; + if (n->p != NIL) { if (n->p->r == n) n->p->r = m; else n->p->l = m; } else t->root = m; - m->p = n->p; + mp = n->p; + if (m != NIL) m->p = mp; if (n->c == BLACK) { if (m->c == RED) m->c = BLACK; else { deleteblack: - if (m->p != nil) { - rb_node *s = m == m->p->l ? m->p->r : m->p->l; + if (mp != NIL) { + rb_node *s = m == mp->l ? mp->r : mp->l; if (s->c == RED) { - m->p->c = RED; + mp->c = RED; s->c = BLACK; - if (m == m->p->l) rotate_left(m->p, t); - else rotate_right(m->p, t); - s = m == m->p->l ? m->p->r : m->p->l; + if (m == mp->l) rotate_left(mp, t); + else rotate_right(mp, t); + s = m == mp->l ? mp->r : mp->l; } - if (m->p->c == BLACK && s->c == BLACK + if (mp->c == BLACK && s->c == BLACK && s->l->c == BLACK && s->r->c == BLACK) { - if (s != nil) s->c = RED; - m = m->p; + if (s != NIL) s->c = RED; + m = mp; mp = m->p; goto deleteblack; } - else if (m->p->c == RED && s->c == BLACK && + else if (mp->c == RED && s->c == BLACK && s->l->c == BLACK && s->r->c == BLACK) { - if (s != nil) s->c = RED; - m->p->c = BLACK; + if (s != NIL) s->c = RED; + mp->c = BLACK; } else { - if (m == m->p->l && s->c == BLACK && + if (m == mp->l && s->c == BLACK && s->l->c == RED && s->r->c == BLACK) { s->c = RED; s->l->c = BLACK; rotate_right(s, t); - s = m == m->p->l ? m->p->r : m->p->l; + s = m == mp->l ? mp->r : mp->l; } - else if (m == m->p->r && s->c == BLACK && + else if (m == mp->r && s->c == BLACK && s->r->c == RED && s->l->c == BLACK) { s->c = RED; s->r->c = BLACK; rotate_left(s, t); - s = m == m->p->l ? m->p->r : m->p->l; + s = m == mp->l ? mp->r : mp->l; } - s->c = m->p->c; - m->p->c = BLACK; - if (m == m->p->l) { + s->c = mp->c; + mp->c = BLACK; + if (m == mp->l) { s->r->c = BLACK; - rotate_left(m->p, t); + rotate_left(mp, t); } else { s->l->c = BLACK; - rotate_right(m->p, t); + rotate_right(mp, t); } } } } } + t->N--; + n->k = k; /* n may have changed during remove */ return n; /* the node that was deleted may be different from initial n */ } rb_node *rb_tree_resort(rb_tree *t, rb_node *n) { - int k = n->k; n = rb_tree_remove(t, n); - n->p = n->l = n->r = &t->nil; - n->k = k; /* n may have changed during remove */ insert_node(t, n); return n; } diff --git a/cdirect/redblack.h b/cdirect/redblack.h index fdc507a..8fc9ae4 100644 --- a/cdirect/redblack.h +++ b/cdirect/redblack.h @@ -6,44 +6,38 @@ extern "C" { #endif /* __cplusplus */ +typedef double *rb_key; /* key type ... double* is convenient for us, + but of course this could be cast to anything + desired (although void* would look more generic) */ + typedef enum { RED, BLACK } rb_color; typedef struct rb_node_s { struct rb_node_s *p, *r, *l; /* parent, right, left */ - int k; /* key/data ... for DIRECT, an index into our hyperrect array */ + rb_key k; /* key (and data) */ rb_color c; } rb_node; -typedef int (*rb_compare)(int k1, int k2, void *c_data); +typedef int (*rb_compare)(rb_key k1, rb_key k2); typedef struct { - rb_compare compare; void *c_data; + rb_compare compare; rb_node *root; int N; /* number of nodes */ - - /* in our application, we can optimize memory allocation because - we never delete two nodes in a row (we always add a node after - deleting)... or rather, we never delete but the value of - the key sometimes changes. ... this means we can just - allocate a linear, exponentially growing stack (nodes) of - nodes, and don't have to worry about holes in the stack */ - rb_node *nodes; /* allocated data of nodes, in some order */ - int Nalloc; /* number of allocated nodes */ - rb_node nil; /* explicit node for NULL nodes, for convenience */ } rb_tree; -extern int rb_tree_init(rb_tree *t, rb_compare compare, void *c_data); +extern void rb_tree_init(rb_tree *t, rb_compare compare); extern void rb_tree_destroy(rb_tree *t); -extern int rb_tree_insert(rb_tree *t, int k); +extern void rb_tree_destroy_with_keys(rb_tree *t); +extern rb_node *rb_tree_insert(rb_tree *t, rb_key k); extern int rb_tree_check(rb_tree *t); -extern rb_node *rb_tree_find(rb_tree *t, int k); -extern rb_node *rb_tree_find_exact(rb_tree *t, int k); -extern rb_node *rb_tree_find_le(rb_tree *t, int k); -extern rb_node *rb_tree_find_gt(rb_tree *t, int k); +extern rb_node *rb_tree_find(rb_tree *t, rb_key k); +extern rb_node *rb_tree_find_le(rb_tree *t, rb_key k); +extern rb_node *rb_tree_find_gt(rb_tree *t, rb_key k); extern rb_node *rb_tree_resort(rb_tree *t, rb_node *n); extern rb_node *rb_tree_min(rb_tree *t); extern rb_node *rb_tree_max(rb_tree *t); -extern rb_node *rb_tree_succ(rb_tree *t, rb_node *n); -extern rb_node *rb_tree_pred(rb_tree *t, rb_node *n); +extern rb_node *rb_tree_succ(rb_node *n); +extern rb_node *rb_tree_pred(rb_node *n); /* To change a key, use rb_tree_find+resort. Removing a node currently wastes memory unless you change the allocation scheme diff --git a/cdirect/redblack_test.c b/cdirect/redblack_test.c index 003f90d..fba0f7b 100644 --- a/cdirect/redblack_test.c +++ b/cdirect/redblack_test.c @@ -1,37 +1,39 @@ #include #include +#include #include #include "redblack.h" -static int comp(int k1, int k2, void *dummy) +static int comp(rb_key k1, rb_key k2) { - (void) dummy; - return k1 - k2; + if (*k1 < *k2) return -1; + if (*k1 > *k2) return +1; + return 0; } int main(int argc, char **argv) { int N, M; int *k; + double kd; rb_tree t; rb_node *n; int i, j; - if (argc != 2) { - fprintf(stderr, "Usage: redblack_test Ntest\n"); + if (argc < 2) { + fprintf(stderr, "Usage: redblack_test Ntest [rand seed]\n"); return 1; } N = atoi(argv[1]); k = (int *) malloc(N * sizeof(int)); - if (!rb_tree_init(&t, comp, NULL)) { - fprintf(stderr, "error in rb_tree_init\n"); - return 1; - } + rb_tree_init(&t, comp); - srand((unsigned) time(NULL)); + srand((unsigned) (argc > 2 ? atoi(argv[2]) : time(NULL))); for (i = 0; i < N; ++i) { - if (!rb_tree_insert(&t, k[i] = rand() % N)) { + double *newk = (double *) malloc(sizeof(double)); + *newk = (k[i] = rand() % N); + if (!rb_tree_insert(&t, newk)) { fprintf(stderr, "error in rb_tree_insert\n"); return 1; } @@ -41,20 +43,27 @@ int main(int argc, char **argv) } } - for (i = 0; i < N; ++i) - if (!rb_tree_find(&t, k[i]) || !rb_tree_find_exact(&t, k[i])) { + if (t.N != N) { + fprintf(stderr, "incorrect N (%d) in tree (vs. %d)\n", t.N, N); + return 1; + } + + for (i = 0; i < N; ++i) { + kd = k[i]; + if (!rb_tree_find(&t, &kd)) { fprintf(stderr, "rb_tree_find lost %d!\n", k[i]); return 1; } - + } + n = rb_tree_min(&t); for (i = 0; i < N; ++i) { if (!n) { fprintf(stderr, "not enough successors %d\n!", i); return 1; } - printf("%d: %d\n", i, n->k); - n = rb_tree_succ(&t, n); + printf("%d: %g\n", i, n->k[0]); + n = rb_tree_succ(n); } if (n) { fprintf(stderr, "too many successors!\n"); @@ -67,8 +76,8 @@ int main(int argc, char **argv) fprintf(stderr, "not enough predecessors %d\n!", i); return 1; } - printf("%d: %d\n", i, n->k); - n = rb_tree_pred(&t, n); + printf("%d: %g\n", i, n->k[0]); + n = rb_tree_pred(n); } if (n) { fprintf(stderr, "too many predecessors!\n"); @@ -83,17 +92,19 @@ int main(int argc, char **argv) if (j-- == 0) break; if (i >= N) abort(); - if (!(n = rb_tree_find(&t, k[i])) || !rb_tree_find_exact(&t, k[i])) { + kd = k[i]; + if (!(n = rb_tree_find(&t, &kd))) { fprintf(stderr, "rb_tree_find lost %d!\n", k[i]); return 1; } - n->k = knew; + n->k[0] = knew; if (!rb_tree_resort(&t, n)) { fprintf(stderr, "error in rb_tree_resort\n"); return 1; } if (!rb_tree_check(&t)) { - fprintf(stderr, "rb_tree_check_failed after change!\n"); + fprintf(stderr, "rb_tree_check_failed after change %d!\n", + N - M + 1); return 1; } k[i] = -1 - knew; @@ -104,26 +115,25 @@ int main(int argc, char **argv) return 1; } + for (i = 0; i < N; ++i) + k[i] = -1 - k[i]; /* undo negation above */ + for (i = 0; i < N; ++i) { - k[i] = -1 - k[i]; - /* rescale keys by 100 to add more space between them */ - k[i] *= 100; - t.nodes[i].k *= 100; - } - - for (i = 0; i < N; ++i) { - int k = rand() % (N * 150) - N*25; - rb_node *le = rb_tree_find_le(&t, k); - rb_node *gt = rb_tree_find_gt(&t, k); - rb_node *n = rb_tree_min(&t); - printf("%d <= %d < %d\n", le? le->k:-999999, k, gt? gt->k:999999); - if (n->k > k) { + rb_node *le, *gt; + kd = 0.01 * (rand() % (N * 150) - N*25); + le = rb_tree_find_le(&t, &kd); + gt = rb_tree_find_gt(&t, &kd); + n = rb_tree_min(&t); + double lek = le ? le->k[0] : -HUGE_VAL; + double gtk = gt ? gt->k[0] : +HUGE_VAL; + printf("%g <= %g < %g\n", lek, kd, gtk); + if (n->k[0] > kd) { if (le) { - fprintf(stderr, "found invalid le %d for %d\n", le->k, k); + fprintf(stderr, "found invalid le %g for %g\n", lek, kd); return 1; } if (gt != n) { - fprintf(stderr, "gt is not first node for k=%d\n", k); + fprintf(stderr, "gt is not first node for k=%g\n", kd); return 1; } } @@ -131,14 +141,16 @@ int main(int argc, char **argv) rb_node *succ = n; do { n = succ; - succ = rb_tree_succ(&t, n); - } while (succ && succ->k <= k); + succ = rb_tree_succ(n); + } while (succ && succ->k[0] <= kd); if (n != le) { - fprintf("rb_tree_find_le gave wrong result for k=%d\n", k); + fprintf(stderr, + "rb_tree_find_le gave wrong result for k=%g\n",kd); return 1; } if (succ != gt) { - fprintf("rb_tree_find_gt gave wrong result for k=%d\n", k); + fprintf(stderr, + "rb_tree_find_gt gave wrong result for k=%g\n",kd); return 1; } } @@ -151,11 +163,14 @@ int main(int argc, char **argv) if (j-- == 0) break; if (i >= N) abort(); - if (!(n = rb_tree_find(&t, k[i])) || !rb_tree_find_exact(&t, k[i])) { + kd = k[i]; + if (!(n = rb_tree_find(&t, &kd))) { fprintf(stderr, "rb_tree_find lost %d!\n", k[i]); return 1; } - rb_tree_remove(&t, n); + n = rb_tree_remove(&t, n); + free(n->k); + free(n); if (!rb_tree_check(&t)) { fprintf(stderr, "rb_tree_check_failed after remove!\n"); return 1; @@ -163,6 +178,11 @@ int main(int argc, char **argv) k[i] = -1 - k[i]; } + if (t.N != 0) { + fprintf(stderr, "nonzero N (%d) in tree at end\n", t.N); + return 1; + } + rb_tree_destroy(&t); free(k); return 0; -- 2.30.2