chiark / gitweb /
curveopt: symbolic: wip
authorIan Jackson <ijackson@chiark.greenend.org.uk>
Sat, 7 Apr 2018 19:37:48 +0000 (20:37 +0100)
committerIan Jackson <ijackson@chiark.greenend.org.uk>
Sat, 7 Apr 2018 19:37:48 +0000 (20:37 +0100)
Signed-off-by: Ian Jackson <ijackson@chiark.greenend.org.uk>
findcurve.c
symbolic.py

index eec17d4da6ac715b5606819310c2ae220cb31534..65c5672eb7753c5a117d741c2723fda43c2f7a5b 100644 (file)
@@ -10,9 +10,8 @@
 
 #include "symbolic.c"
 
-#define X(i) x[i]
 #define J_END_COL(i) \
-  for (j=0; j<N; j++) gsl_matrix_set(J,i,j,J_COL[j]);
+  for (j=0; j<PN; j++) gsl_matrix_set(J,i,j,J_COL[j]);
 
 static inline _Bool IS_SMALL(double v) {
   return v < 1E6;
@@ -22,43 +21,33 @@ static inline double sinc(double x) {
   return gsl_sf_sinc(x / M_PI);
 }
 
-static double target[N];
+static double *INPUT; /* dyanmic array, on main's stack */
+static double PREP[NPREP];
+
+static void prepare(double X[] /* startpoint */) {
+  /* fills in PREP and startpoint */
+  PREPARE;
+}
 
 static double cb_Efunc(void *xp) {
   const double *X = xp;
+  double F[3], G[3];
+
+  CALCULATE_F_G;
 
   double e = 0;
-  for (P=0; P<N-3; P++) {
-    double d;
-    // A is point #p, B #p+1, C #p+2, etc.
-    if (P == 0) {
-    } else if (P==N-2) {
-    } else {
-      E_CALCULATE_MID;
-
-  double F[N], e;
-  int i;
-  X_EXTRACT;
-  F_POPULATE;
-  for (i=0, e=0; i<N; i++) {
-    double d = F[i] - target[i];
-    e += d*d;
+  for (P=0; P<NP-3; P++) {
+    double P_cost;
+    CALCULATE_COST;
+    e += P_cost;
   }
-  //printf("\n cb_Efunc %p %10.7f [", xp, e);
-  //for (i=0; i<N; i++) printf(" %10.7f,", x[i]);
-  //printf("]\n");
   return e;
 }
 
-static int cb_fdf(const gsl_vector *x, void *params,
-                 gsl_vector *f, gsl_matrix *J) {
-  
-}
-
 static void cb_step(const gsl_rng *rng, void *xp, double step_size) {
   double *x = xp;
   int i;
-  double step[N];
+  double step[NX];
   gsl_ran_dir_nd(rng, N,step);
   for (i=0; i<N; i++)
     x[i] += step_size * step[i];
@@ -106,20 +95,21 @@ static double scan1double(void) {
   return v;
 }
 
-int main(void) {
+int main(int argc, const char *const argv) {
   double epsilon;
   int i;
 
-  double startpoint[N];
+  NP = atoi(argv[1]);
 
   gsl_rng *rng = gsl_rng_alloc(gsl_rng_ranlxd2);
 
+  double input[NINPUT]; INPUT = input;
+  double startpoint[NX];
+
   for (;;) {
-    /* 2N+1 doubles: target, initial guess, epsilon for residual */
-    for (i=0; i<N; i++)
-      target[i] = scan1double();
-    for (i=0; i<N; i++)
-      startpoint[i] = scan1double();
+    /* NINPUT + 1 doubles: startpoint, epsilon for residual */
+    for (i=0; i<NINPUT; i++)
+      INPUT[i] = scan1double();
     epsilon = scan1double();
 
     gsl_rng_set(rng,0);
@@ -134,6 +124,8 @@ int main(void) {
       .step_size = 0.05,
     };
 
+    prepare(startpoint);
+
     gsl_siman_solve(rng,
                    startpoint,
                    cb_Efunc, cb_step, cb_metric,
index 06ef2c6937f485603183389036b2ae73a0450301..af5af8f1cbff32f3bd972f9f49976eab59761a16 100644 (file)
@@ -38,6 +38,7 @@ A, B, C, D = vector_symbols('A B C D')
 p = vector_symbols('p')
 
 E, H = vector_symbols('E H')
+F0, G0 = vector_symbols('F0 G0')
 En, Hn = vector_symbols('En Hn')
 
 EFl, GHl = symbols('EFl GHl')
@@ -68,6 +69,7 @@ def calculate():
   dbg('cost_ABCD')
   dprint(A)
 
+  global F, G
   F = E + En * EFl
   G = H + Hn * GHl
 
@@ -100,6 +102,12 @@ def cassign(v, assign_to, tmp_prefix):
   v = cse_prep_cprint(v, tmp_prefix)
   cprint(ourccode(v, assign_to=assign_to))
 
+def cassign_vector(v, assign_to, tmp_prefix):
+  ijk = 'i j k'.split(' ')
+  for ii in range(0, len(ijk)):
+    x = v & getattr(N, ijk[ii])
+    cassign(x, '%s[%d]' % (assign_to, ii), '%s_%s' % (tmp_prefix, ijk[ii]))
+
 def gen_diff(current, smalls):
   global j
   if not smalls:
@@ -145,46 +153,59 @@ def gen_point_index_macro(macro_basename, c_array_name, base_index):
 
 def gen_point_references():
   abcd = 'A B C D'.split(' ')
-  eh = 'E En H Hn'.split(' ')
-  for ehi in range(0, len(eh)):
-    gen_point_index_macro(eh[ehi], 'FIXED', ehi * 3)
 
+  gen_point_index_macro('E',  'INPUT', '3*0')
+  gen_point_index_macro('F0', 'INPUT', '3*1')
+  gen_point_index_macro('G0', 'INPUT', '3*(N-2)')
+  gen_point_index_macro('H',  'INPUT', '3*(N-1)')
+  cprintraw(    '#define NINPUT       ( 3*(N-0) )')
+
+  gen_point_index_macro('En', 'PREP', '3*0')
+  gen_point_index_macro('Hn', 'PREP', '3*1')
+  cprintraw(         '#define PREP_N (3*2)')
+
+  cprintraw('#define X_N_DIRECT 3*(N-4)')
   cprint('#define POINT(PP) (')
   cprint(' (PP) == 0   ? E :')
   cprint(' (PP) == 1   ? F :')
   cprint(' (PP) == N-2 ? G :')
   cprint(' (PP) == N-1 ? H :')
-  cprint(' &X[((PP)-2)*3]')
+  cprint(' &X[3*((PP)-2)]')
   cprintraw(')')
 
+  cprintraw('#define EFl X[ X_N_DIRECT + 0 ]')
+  cprintraw('#define GHl X[ X_N_DIRECT + 1 ]')
+  cprintraw('#define X_N  ( X_N_DIRECT + 2 )')
+
   for ai in range(0, len(abcd)):
     cprintraw('#define %s POINT(P%+d)' % (abcd[ai], ai))
     gen_point_coords_macro(abcd[ai])
 
-def gen_calculate_FG():
-  cprint('#define CALCULAGE_F_G')
-  cassign(E,'E','tmp_E')
-  cassign(E,'F','tmp_F')
+  cprintraw('')
 
-def gen_calculate_cost():
-  cprint('#define CALCULATE_COST')
-  cassign(cost_ABCD,'P_cost','tmp_P_cost')
+def gen_prepare():
+  cprint('#define PREPARE')
+  cprint('memcpy(X, &INPUT[3*2], sizeof(double) * X_N_DIRECT)')
+  cassign_vector((F0 - E).normalize(), 'En', 'tmp_En')
+  cassign_vector((G0 - H).normalize(), 'Hn', 'tmp_Hn')
   cprintraw('')
 
-def gen_f_populate():
-  cprint('#define F_POPULATE')
-  cassign(cost_ABCD,'F','ftmp')
+def gen_calculate_FG():
+  cprint('#define CALCULATE_F_G')
+  cassign_vector(F,'F','tmp_F')
+  cassign_vector(G,'G','tmp_G')
   cprintraw('')
 
-def gen_j_populate():
-  cprint('#define J_POPULATE')
-  gen_diff(result_dirnscaled, (sh*sh*la, th*th*la))
+def gen_calculate_cost():
+  cprint('#define CALCULATE_COST')
+  cassign(cost_ABCD,'P_cost','tmp_P_cost')
   cprintraw('')
 
 def gen_C():
   gen_misc()
   gen_point_references()
-  #gen_calculate_FG()
+  gen_prepare()
+  gen_calculate_FG()
   gen_calculate_cost()
 
 def get_python():