chiark / gitweb /
curveopt: wip
[moebius3.git] / symbolic.py
1
2 from __future__ import print_function
3
4 from sympy.vector.vector import *
5 from sympy import *
6 import itertools
7 from sympy.utilities.lambdify import lambdify, implemented_function
8 import numpy as np
9
10 from moedebug import *
11
12 def dprint(*args):
13   if not dbg_enabled(): return
14   print(*args)
15
16 def dbg(*args):
17   if not dbg_enabled(): return
18   for vn in args:
19     print('\n    ' + vn + '\n')
20     pprint(eval(vn))
21     print('\n          =\n')
22     pprint(cse(eval(vn)))
23
24 N = CoordSysCartesian('N')
25
26 calculated = False
27
28 def vector_symbols(vnames):
29   out = []
30   for vname in vnames.split(' '):
31     v = Vector.zero
32     for cname in 'i j k'.split(' '):
33       v += getattr(N, cname) * symbols(vname + '_' + cname)
34     out.append(v)
35   return out
36
37 A, B, C, D = vector_symbols('A B C D')
38 p = vector_symbols('p')
39
40 E, H = vector_symbols('E H')
41 F0, G0 = vector_symbols('F0 G0')
42 En, Hn = vector_symbols('En Hn')
43
44 EFl, HGl = symbols('EFl HGl')
45
46 def vector_component(v, ix):
47   return v.components[N.base_vectors()[ix]]
48
49 # x array in numerical algorithm has:
50 #    N x 3    coordinates of points 0..N-3
51 #    1        EFl = length parameter |EF| for point 1
52 #    1        HGl = length parameter |HG| for point N-2
53
54 # fixed array in numerical algorithm has:
55 #    4 x 3    E, En, H, Hn
56
57 #def subst_vect():
58
59 def calculate():
60   global calculated
61   if calculated: return
62
63   Q = B + 0.5 * (B - A)
64   R = C + 0.5 * (C - D)
65   QR = R - Q
66   BC = C - B
67   global cost_ABCD
68   cost_ABCD = (QR & QR) / (BC & BC)
69   dbg('cost_ABCD')
70   dprint(A)
71
72   global F, G
73   F = E + En * EFl
74   G = H + Hn * HGl
75
76   # diff_A_i = cost_ABCD.diff(vector_component(A, 0))
77   # global diff_A_i
78   # dbg('diff_A_i')
79
80   calculated = True
81
82 def ourccode(*a, **kw):
83   return ccode(*a, user_functions={'sinc':'sinc'}, **kw)
84
85 def cprintraw(*s):
86   print(*s)
87
88 def cprint(s):
89   for l in s.split('\n'):
90     cprintraw(l, '\\')
91
92 def cse_prep_cprint(v, tmp_prefix):
93   # => v, but also having cprint'd the common subexpression assignments
94   sym_iter = map((lambda i: symbols('%s%d' % (tmp_prefix,i))),
95                  itertools.count())
96   (defs, vs) = cse(v, symbols=sym_iter)
97   for defname, defval in defs:
98     cprint('double '+ourccode(defval, assign_to=defname))
99   return vs[0]
100
101 def cassign(v, assign_to, tmp_prefix):
102   v = cse_prep_cprint(v, tmp_prefix)
103   cprint(ourccode(v, assign_to=assign_to))
104
105 def cassign_vector(v, assign_to, tmp_prefix):
106   ijk = 'i j k'.split(' ')
107   for ii in range(0, len(ijk)):
108     x = v & getattr(N, ijk[ii])
109     cassign(x, '%s[%d]' % (assign_to, ii), '%s_%s' % (tmp_prefix, ijk[ii]))
110
111 def gen_diff(current, smalls):
112   global j
113   if not smalls:
114     j = zeros(len(params),0)
115     for param in params:
116       global d
117       paramv = eval(param)
118       d = diff(current, paramv)
119       dbg('d')
120       j = j.row_join(d)
121     dbg('j')
122     j = cse_prep_cprint(j, 'jtmp')
123     for ix in range(0, j.cols):
124       cprint(ourccode(j.col(ix), 'J_COL'))
125       cprint('J_END_COL(%d)' % ix)
126   else:
127     small = smalls[0]
128     smalls = smalls[1:]
129     cprint('if (!IS_SMALL(' + ourccode(small) + ')) {')
130     gen_diff(current, smalls)
131     cprint('} else { /* %s small */' % small)
132     gen_diff(current.replace(
133       sinc(small),
134       1 - small*small/factorial(3) - small**4/factorial(5),
135       ),
136       smalls
137     )
138     cprint('} /* %s small */' % small)
139
140 def gen_misc():
141   cprintraw('// AUTOGENERATED - DO NOT EDIT\n')
142
143 def gen_point_coords_macro(macro_basename):
144   ijk = 'i j k'.split(' ')
145   for ii in range(0, len(ijk)):
146     cprintraw('#define %s_%s (%s[%d])'
147               % (macro_basename, ijk[ii], macro_basename, ii))
148
149 def gen_point_index_macro(macro_basename, c_array_name, base_index):
150   cprintraw('#define %s (&%s[%s])'
151             % (macro_basename, c_array_name, base_index))
152   gen_point_coords_macro(macro_basename)
153
154 def gen_point_references():
155   abcd = 'A B C D'.split(' ')
156
157   gen_point_index_macro('E',  'INPUT', '3*0')
158   gen_point_index_macro('F0', 'INPUT', '3*1')
159   gen_point_index_macro('G0', 'INPUT', '3*(NP-2)')
160   gen_point_index_macro('H',  'INPUT', '3*(NP-1)')
161   cprintraw(         '#define NINPUT  ( 3*(NP-0) )')
162
163   gen_point_index_macro('En', 'PREP', '3*0')
164   gen_point_index_macro('Hn', 'PREP', '3*1')
165   cprintraw(         '#define NPREP   (3*2)')
166
167   cprintraw('#define NX_DIRECT 3*(NP-4)')
168   cprint('#define POINT(PP) (')
169   cprint(' (PP) == 0    ? E :')
170   cprint(' (PP) == 1    ? F :')
171   cprint(' (PP) == NP-2 ? G :')
172   cprint(' (PP) == NP-1 ? H :')
173   cprint(' &X[3*((PP)-2)]')
174   cprintraw(')')
175
176   cprintraw('#define EFl X[ NX_DIRECT + 0 ]')
177   cprintraw('#define HGl X[ NX_DIRECT + 1 ]')
178   cprintraw('#define NX   ( NX_DIRECT + 2 )')
179
180   for ai in range(0, len(abcd)):
181     cprintraw('#define %s POINT(P%+d)' % (abcd[ai], ai))
182     gen_point_coords_macro(abcd[ai])
183
184   cprintraw('')
185
186 def gen_prepare():
187   cprint('#define PREPARE')
188   cprint('memcpy(X, &INPUT[3*2], sizeof(double) * NX_DIRECT);')
189   for EH,EHs,FG0,FGs in ((E,'E', F0,'F'),
190                          (H,'H', G0,'G')):
191     EFHGv = FG0 - EH
192     EFHGl = EFHGv.magnitude()
193     cassign_vector(EFHGv/EFHGl, EHs+'n', 'tmp_'+EHs)
194     cassign(EFHGl, EHs+FGs+'l', 'tmp_l'+EHs)
195   cprintraw('')
196
197 def gen_calculate_FG():
198   cprintraw('#define DECLARE_F_G double F[3], G[3];')
199   cprint('#define CALCULATE_F_G')
200   cassign_vector(F,'F','tmp_F')
201   cassign_vector(G,'G','tmp_G')
202   cprintraw('')
203
204 def gen_calculate_cost():
205   cprint('#define CALCULATE_COST')
206   cassign(cost_ABCD,'P_cost','tmp_P_cost')
207   cprintraw('')
208
209 def gen_C():
210   gen_misc()
211   gen_point_references()
212   gen_prepare()
213   gen_calculate_FG()
214   gen_calculate_cost()
215
216 def get_python():
217   # https://github.com/sympy/sympy/issues/13642
218   # "lambdify sinc gives wrong answer!"
219   out = q_sqparm
220   sinc_fixed = Function('sinc_fixed')
221   implemented_function(sinc_fixed, lambda x: np.sinc(x/np.pi))
222   out = out.subs(sinc,sinc_fixed)
223   p = list(map(eval,params))
224   return lambdify(p, out)