chiark / gitweb /
curveopt: new approach
[moebius3.git] / symbolic.py
1
2 from __future__ import print_function
3
4 from sympy.vector.vector import *
5 from sympy import *
6 import itertools
7 from sympy.utilities.lambdify import lambdify, implemented_function
8 import numpy as np
9
10 from moedebug import *
11
12 # When cse() is called for something containing BaseVector, it
13 # produces infinite recursion.
14 #def cse(x, *a, **kw): return ((), (x,))
15
16 def dprint(*args):
17   if not dbg_enabled(): return
18   print(*args)
19
20 def dbg(*args):
21   if not dbg_enabled(): return
22   for vn in args:
23     print('\n    ' + vn + '\n')
24     pprint(eval(vn))
25     print('\n          =\n')
26     pprint(cse(eval(vn)))
27
28 def sqnorm(v): return v & v
29
30 N = CoordSysCartesian('N')
31
32 calculated = False
33
34 def vector_symbols(vnames):
35   out = []
36   for vname in vnames.split(' '):
37     v = Vector.zero
38     for cname in 'i j k'.split(' '):
39       v += getattr(N, cname) * symbols(vname + '_' + cname)
40     out.append(v)
41   return out
42
43 A, B, C, D = vector_symbols('A B C D')
44 p = vector_symbols('p')
45
46 E, H = vector_symbols('E H')
47 F0, G0 = vector_symbols('F0 G0')
48 En, Hn = vector_symbols('En Hn')
49
50 EFl, HGl = symbols('EFl HGl')
51
52 def vector_component(v, ix):
53   return v.components[N.base_vectors()[ix]]
54
55 # x array in numerical algorithm has:
56 #    N x 3    coordinates of points 0..N-3
57 #    1        EFl = length parameter |EF| for point 1
58 #    1        HGl = length parameter |HG| for point N-2
59
60 # fixed array in numerical algorithm has:
61 #    4 x 3    E, En, H, Hn
62
63 #def subst_vect():
64
65 iterations = []
66
67 class SomeIteration():
68   def __init__(ar, names, size, expr):
69     ar.names_string = names
70     ar.names = names.split(' ')
71     ar.name = ar.names[0]
72     ar.size = size
73     ar.expr = expr
74     if dbg_enabled():
75       print('\n   ' + ar.name + '\n')
76       print(expr)
77     iterations.append(ar)
78
79   def gen_calculate_cost(ar):
80     ar._gen_array()
81     cprint('for (P=0; P<(%s); P++) {' % ar.size)
82     ar._cassign()
83     cprint('}')
84
85 class ScalarArray(SomeIteration):
86   def _gen_array(ar):
87     cprint('double A_%s[%s];' % (ar.name, ar.size))
88   def gen_references(ar):
89     for ai in range(0, len(ar.names)):
90       ar._gen_reference(ai, ar.names[ai])
91   def _gen_reference(ar, ai, an):
92     cprintraw('#define %s A_%s[P%+d]' % (an, ar.name, ai))
93   def _cassign(ar):
94     cassign(ar.expr, ar.name, 'tmp_'+ar.name)
95   def s(ar):
96     return symbols(ar.names_string)
97
98 class CoordArray(ScalarArray):
99   def _gen_array(ar):
100     cprint('double A_%s[%s][3];' % (ar.name, ar.size))
101   def _gen_reference(ar, ai, an):
102     ScalarArray._gen_reference(ar, ai, an)
103     gen_point_coords_macro(an)
104   def _cassign(ar):
105     cassign_vector(ar.expr, ar.name, 'tmp_'+ar.name)
106   def s(ar):
107     return vector_symbols(ar.names_string)
108
109 class CostComponent(SomeIteration):
110   def __init__(cc, size, expr):
111     cc.size = size
112     cc.expr = expr
113     iterations.append(cc)
114   def gen_references(cc): pass
115   def _gen_array(cc): pass
116   def _cassign(cc):
117     cassign(cc.expr, 'P_cost', 'tmp_cost')
118     cprint('cost += P_cost;')
119
120 def calculate():
121   global calculated
122   if calculated: return
123
124   # ---------- actual cost computation formulae ----------
125
126   global F, G
127   F = E + En * EFl
128   G = H + Hn * HGl
129
130   global a,b, al,bl, au,bu
131   a,  b  = CoordArray ('a_ b_',   'NP-1', B-A           ).s() # [mm]
132   al, bl = ScalarArray('al bl',   'NP-1', a.magnitude() ).s() # [mm]
133   au, bu = CoordArray ('au bu',   'NP-1', a / al        ).s() # [1]
134
135   tan_theta = (au ^ bu) / (au & bu)     # [1]     bending
136   curvature = tan_theta / sqrt(al * bl) # [1/mm]  bending per unit length
137
138   global mu, nu
139   mu, nu = CoordArray ('mu nu', 'NP-2', curvature ).s() # [1/mm]
140
141   CostComponent('NP-3', sqnorm(mu - nu)) # [1/mm^2]
142
143   d_density = 1/al - 1/bl # [1/mm]
144   CostComponent('NP-2', pow(d_density, 2)) # [1/mm^2]
145
146   # ---------- end of cost computation formulae ----------
147
148   calculated = True
149
150 def ourccode(*a, **kw):
151   return ccode(*a, user_functions={'sinc':'sinc'}, **kw)
152
153 def cprintraw(*s):
154   print(*s)
155
156 def cprint(s):
157   for l in s.split('\n'):
158     cprintraw(l, '\\')
159
160 def cse_prep_cprint(v, tmp_prefix):
161   # => v, but also having cprint'd the common subexpression assignments
162   sym_iter = map((lambda i: symbols('%s%d' % (tmp_prefix,i))),
163                  itertools.count())
164   (defs, vs) = cse(v, symbols=sym_iter)
165   for defname, defval in defs:
166     cprint('double '+ourccode(defval, assign_to=defname))
167   return vs[0]
168
169 def cassign(v, assign_to, tmp_prefix):
170   v = cse_prep_cprint(v, tmp_prefix)
171   cprint(ourccode(v, assign_to=assign_to))
172
173 def cassign_vector(v, assign_to, tmp_prefix):
174   ijk = 'i j k'.split(' ')
175   for ii in range(0, len(ijk)):
176     x = v & getattr(N, ijk[ii])
177     cassign(x, '%s[%d]' % (assign_to, ii), '%s_%s' % (tmp_prefix, ijk[ii]))
178
179 def gen_diff(current, smalls):
180   global j
181   if not smalls:
182     j = zeros(len(params),0)
183     for param in params:
184       global d
185       paramv = eval(param)
186       d = diff(current, paramv)
187       dbg('d')
188       j = j.row_join(d)
189     dbg('j')
190     j = cse_prep_cprint(j, 'jtmp')
191     for ix in range(0, j.cols):
192       cprint(ourccode(j.col(ix), 'J_COL'))
193       cprint('J_END_COL(%d)' % ix)
194   else:
195     small = smalls[0]
196     smalls = smalls[1:]
197     cprint('if (!IS_SMALL(' + ourccode(small) + ')) {')
198     gen_diff(current, smalls)
199     cprint('} else { /* %s small */' % small)
200     gen_diff(current.replace(
201       sinc(small),
202       1 - small*small/factorial(3) - small**4/factorial(5),
203       ),
204       smalls
205     )
206     cprint('} /* %s small */' % small)
207
208 def gen_misc():
209   cprintraw('// AUTOGENERATED - DO NOT EDIT\n')
210
211 def gen_point_coords_macro(macro_basename):
212   ijk = 'i j k'.split(' ')
213   for ii in range(0, len(ijk)):
214     cprintraw('#define %s_%s (%s[%d])'
215               % (macro_basename, ijk[ii], macro_basename, ii))
216
217 def gen_point_index_macro(macro_basename, c_array_name, base_index):
218   cprintraw('#define %s (&%s[%s])'
219             % (macro_basename, c_array_name, base_index))
220   gen_point_coords_macro(macro_basename)
221
222 def gen_point_references():
223   abcd = 'A B C D'.split(' ')
224
225   gen_point_index_macro('E',  'INPUT', '3*0')
226   gen_point_index_macro('F0', 'INPUT', '3*1')
227   gen_point_index_macro('G0', 'INPUT', '3*(NP-2)')
228   gen_point_index_macro('H',  'INPUT', '3*(NP-1)')
229   cprintraw(         '#define NINPUT  ( 3*(NP-0) )')
230
231   gen_point_index_macro('En', 'PREP', '3*0')
232   gen_point_index_macro('Hn', 'PREP', '3*1')
233   cprintraw(         '#define NPREP   (3*2)')
234
235   cprintraw('#define NX_DIRECT 3*(NP-4)')
236   cprint('#define POINT(PP) (')
237   cprint(' (PP) == 0    ? E :')
238   cprint(' (PP) == 1    ? F :')
239   cprint(' (PP) == NP-2 ? G :')
240   cprint(' (PP) == NP-1 ? H :')
241   cprint(' &X[3*((PP)-2)]')
242   cprintraw(')')
243
244   cprintraw('#define EFl X[ NX_DIRECT + 0 ]')
245   cprintraw('#define HGl X[ NX_DIRECT + 1 ]')
246   cprintraw('#define NX   ( NX_DIRECT + 2 )')
247
248   for ai in range(0, len(abcd)):
249     cprintraw('#define %s POINT(P%+d)' % (abcd[ai], ai))
250     gen_point_coords_macro(abcd[ai])
251
252   for si in iterations:
253     si.gen_references()
254
255   cprintraw('')
256
257 def gen_prepare():
258   cprint('#define PREPARE')
259   cprint('memcpy(X, &INPUT[3*2], sizeof(double) * NX_DIRECT);')
260   for EH,EHs,FG0,FGs in ((E,'E', F0,'F'),
261                          (H,'H', G0,'G')):
262     EFHGv = FG0 - EH
263     EFHGl = EFHGv.magnitude()
264     cassign_vector(EFHGv/EFHGl, EHs+'n', 'tmp_'+EHs)
265     cassign(EFHGl, EHs+FGs+'l', 'tmp_l'+EHs)
266   cprintraw('')
267
268 def gen_calculate_FG():
269   cprintraw('#define DECLARE_F_G double F[3], G[3];')
270   cprint('#define CALCULATE_F_G')
271   cassign_vector(F,'F','tmp_F')
272   cassign_vector(G,'G','tmp_G')
273   cprintraw('')
274
275 def gen_calculate_cost():
276   cprint('#define CALCULATE_COST')
277   cprint('double cost=0, P_cost;')
278   for si in iterations:
279     si.gen_calculate_cost()
280   cprintraw('')
281
282 def gen_C():
283   gen_misc()
284   gen_point_references()
285   gen_prepare()
286   gen_calculate_FG()
287   gen_calculate_cost()
288
289 def get_python():
290   # https://github.com/sympy/sympy/issues/13642
291   # "lambdify sinc gives wrong answer!"
292   out = q_sqparm
293   sinc_fixed = Function('sinc_fixed')
294   implemented_function(sinc_fixed, lambda x: np.sinc(x/np.pi))
295   out = out.subs(sinc,sinc_fixed)
296   p = list(map(eval,params))
297   return lambdify(p, out)