chiark / gitweb /
curveopt: show it in a better way
[moebius3.git] / symbolic.py
1
2 from __future__ import print_function
3
4 from sympy.vector.vector import *
5 from sympy import *
6 import itertools
7 from sympy.utilities.lambdify import lambdify, implemented_function
8 import numpy as np
9
10 from moedebug import *
11
12 # When cse() is called for something containing BaseVector, it
13 # produces infinite recursion.
14 #def cse(x, *a, **kw): return ((), (x,))
15
16 def dprint(*args):
17   if not dbg_enabled(): return
18   print(*args)
19
20 def dbg(*args):
21   if not dbg_enabled(): return
22   for vn in args:
23     print('\n    ' + vn + '\n')
24     pprint(eval(vn))
25     print('\n          =\n')
26     pprint(cse(eval(vn)))
27
28 def sqnorm(v): return v & v
29
30 N = CoordSysCartesian('N')
31
32 calculated = False
33
34 def vector_symbols(vnames):
35   out = []
36   for vname in vnames.split(' '):
37     v = Vector.zero
38     for cname in 'i j k'.split(' '):
39       v += getattr(N, cname) * symbols(vname + '_' + cname)
40     out.append(v)
41   return out
42
43 A, B, C, D = vector_symbols('A B C D')
44 p = vector_symbols('p')
45
46 E, H = vector_symbols('E H')
47 F0, G0 = vector_symbols('F0 G0')
48 En, Hn = vector_symbols('En Hn')
49
50 EFlq, HGlq = symbols('EFlq HGlq')
51
52 def vector_component(v, ix):
53   return v.components[N.base_vectors()[ix]]
54
55 # x array in numerical algorithm has:
56 #    N x 3    coordinates of points 0..N-3
57 #    1        EFlq = sqrt of length parameter |EF| for point 1
58 #    1        HGlq = sqrt of length parameter |HG| for point N-2
59
60 # fixed array in numerical algorithm has:
61 #    4 x 3    E, En, H, Hn
62
63 #def subst_vect():
64
65 iterations = []
66
67 class SomeIteration():
68   def __init__(ar, names, size, expr):
69     ar.names_string = names
70     ar.names = names.split(' ')
71     ar.name = ar.names[0]
72     ar.size = size
73     ar.expr = expr
74     if dbg_enabled():
75       print('\n   ' + ar.name + '\n')
76       print(expr)
77     iterations.append(ar)
78
79   def gen_calculate_cost(ar):
80     ar._gen_array()
81     cprint('for (P=0; P<(%s); P++) {' % ar.size)
82     ar._cassign()
83     cprint('}')
84
85 class ScalarArray(SomeIteration):
86   def _gen_array(ar):
87     cprint('double A_%s[%s];' % (ar.name, ar.size))
88   def gen_references(ar):
89     for ai in range(0, len(ar.names)):
90       ar._gen_reference(ai, ar.names[ai])
91   def _gen_reference(ar, ai, an):
92     cprintraw('#define %s A_%s[P%+d]' % (an, ar.name, ai))
93   def _cassign(ar):
94     cassign(ar.expr, ar.name, 'tmp_'+ar.name)
95   def s(ar):
96     return symbols(ar.names_string)
97
98 class CoordArray(ScalarArray):
99   def _gen_array(ar):
100     cprint('double A_%s[%s][3];' % (ar.name, ar.size))
101   def _gen_reference(ar, ai, an):
102     ScalarArray._gen_reference(ar, ai, an)
103     gen_point_coords_macro(an)
104   def _cassign(ar):
105     cassign_vector(ar.expr, ar.name, 'tmp_'+ar.name)
106   def s(ar):
107     return vector_symbols(ar.names_string)
108
109 class CostComponent(SomeIteration):
110   def __init__(cc, size, expr):
111     cc.size = size
112     cc.expr = expr
113     iterations.append(cc)
114   def gen_references(cc): pass
115   def _gen_array(cc): pass
116   def _cassign(cc):
117     cassign(cc.expr, 'P_cost', 'tmp_cost')
118     cprint('cost += P_cost;')
119
120 def calculate():
121   global calculated
122   if calculated: return
123
124   # ---------- actual cost computation formulae ----------
125
126   global F, G
127   F = E + En * pow(EFlq, 2)
128   G = H + Hn * pow(HGlq, 2)
129
130   global a,b, al,bl, au,bu
131   a,  b  = CoordArray ('a_ b_',   'NP-1', B-A           ).s() # [mm]
132   al, bl = ScalarArray('al bl',   'NP-1', a.magnitude() ).s() # [mm]
133   au, bu = CoordArray ('au bu',   'NP-1', a / al        ).s() # [1]
134
135   tan_theta = (au ^ bu) / (au & bu)     # [1]     bending
136
137   global mu, nu
138   mu, nu = CoordArray ('mu nu', 'NP-2', tan_theta ).s() # [1]
139
140   CostComponent('NP-3', sqnorm(mu - nu)) # [1]
141
142   dl2 = pow(al - bl, 2) # [mm^2]
143   CostComponent('NP-2', dl2 / (al*bl)) # [1]
144
145   # ---------- end of cost computation formulae ----------
146
147   calculated = True
148
149 def ourccode(*a, **kw):
150   return ccode(*a, user_functions={'sinc':'sinc'}, **kw)
151
152 def cprintraw(*s):
153   print(*s)
154
155 def cprint(s):
156   for l in s.split('\n'):
157     cprintraw(l, '\\')
158
159 def cse_prep_cprint(v, tmp_prefix):
160   # => v, but also having cprint'd the common subexpression assignments
161   sym_iter = map((lambda i: symbols('%s%d' % (tmp_prefix,i))),
162                  itertools.count())
163   (defs, vs) = cse(v, symbols=sym_iter)
164   for defname, defval in defs:
165     cprint('double '+ourccode(defval, assign_to=defname))
166   return vs[0]
167
168 def cassign(v, assign_to, tmp_prefix):
169   v = cse_prep_cprint(v, tmp_prefix)
170   cprint(ourccode(v, assign_to=assign_to))
171
172 def cassign_vector(v, assign_to, tmp_prefix):
173   ijk = 'i j k'.split(' ')
174   for ii in range(0, len(ijk)):
175     x = v & getattr(N, ijk[ii])
176     cassign(x, '%s[%d]' % (assign_to, ii), '%s_%s' % (tmp_prefix, ijk[ii]))
177
178 def gen_diff(current, smalls):
179   global j
180   if not smalls:
181     j = zeros(len(params),0)
182     for param in params:
183       global d
184       paramv = eval(param)
185       d = diff(current, paramv)
186       dbg('d')
187       j = j.row_join(d)
188     dbg('j')
189     j = cse_prep_cprint(j, 'jtmp')
190     for ix in range(0, j.cols):
191       cprint(ourccode(j.col(ix), 'J_COL'))
192       cprint('J_END_COL(%d)' % ix)
193   else:
194     small = smalls[0]
195     smalls = smalls[1:]
196     cprint('if (!IS_SMALL(' + ourccode(small) + ')) {')
197     gen_diff(current, smalls)
198     cprint('} else { /* %s small */' % small)
199     gen_diff(current.replace(
200       sinc(small),
201       1 - small*small/factorial(3) - small**4/factorial(5),
202       ),
203       smalls
204     )
205     cprint('} /* %s small */' % small)
206
207 def gen_misc():
208   cprintraw('// AUTOGENERATED - DO NOT EDIT\n')
209
210 def gen_point_coords_macro(macro_basename):
211   ijk = 'i j k'.split(' ')
212   for ii in range(0, len(ijk)):
213     cprintraw('#define %s_%s (%s[%d])'
214               % (macro_basename, ijk[ii], macro_basename, ii))
215
216 def gen_point_index_macro(macro_basename, c_array_name, base_index):
217   cprintraw('#define %s (&%s[%s])'
218             % (macro_basename, c_array_name, base_index))
219   gen_point_coords_macro(macro_basename)
220
221 def gen_point_references():
222   abcd = 'A B C D'.split(' ')
223
224   gen_point_index_macro('E',  'INPUT', '3*0')
225   gen_point_index_macro('F0', 'INPUT', '3*1')
226   gen_point_index_macro('G0', 'INPUT', '3*(NP-2)')
227   gen_point_index_macro('H',  'INPUT', '3*(NP-1)')
228   cprintraw(         '#define NINPUT  ( 3*(NP-0) )')
229
230   gen_point_index_macro('En', 'PREP', '3*0')
231   gen_point_index_macro('Hn', 'PREP', '3*1')
232   cprintraw(         '#define NPREP   (3*2)')
233
234   cprintraw('#define NX_DIRECT 3*(NP-4)')
235   cprint('#define POINT(PP) (')
236   cprint(' (PP) == 0    ? E :')
237   cprint(' (PP) == 1    ? F :')
238   cprint(' (PP) == NP-2 ? G :')
239   cprint(' (PP) == NP-1 ? H :')
240   cprint(' &X[3*((PP)-2)]')
241   cprintraw(')')
242
243   cprintraw('#define EFlq X[ NX_DIRECT + 0 ]')
244   cprintraw('#define HGlq X[ NX_DIRECT + 1 ]')
245   cprintraw('#define NX    ( NX_DIRECT + 2 )')
246
247   for ai in range(0, len(abcd)):
248     cprintraw('#define %s POINT(P%+d)' % (abcd[ai], ai))
249     gen_point_coords_macro(abcd[ai])
250
251   for si in iterations:
252     si.gen_references()
253
254   cprintraw('')
255
256 def gen_prepare():
257   cprint('#define PREPARE')
258   cprint('memcpy(X, &INPUT[3*2], sizeof(double) * NX_DIRECT);')
259   for EH,EHs,FG0,FGs in ((E,'E', F0,'F'),
260                          (H,'H', G0,'G')):
261     EFHGv = FG0 - EH
262     EFHGl = EFHGv.magnitude()
263     EFHGlq = sqrt(EFHGl)
264     cassign_vector(EFHGv/EFHGl, EHs+'n', 'tmp_'+EHs)
265     cassign(EFHGlq, EHs+FGs+'lq', 'tmp_l'+EHs)
266   cprintraw('')
267
268 def gen_calculate_FG():
269   cprintraw('#define DECLARE_F_G double F[3], G[3];')
270   cprint('#define CALCULATE_F_G')
271   cassign_vector(F,'F','tmp_F')
272   cassign_vector(G,'G','tmp_G')
273   cprintraw('')
274
275 def gen_calculate_cost():
276   cprint('#define CALCULATE_COST')
277   cprint('double cost=0, P_cost;')
278   for si in iterations:
279     si.gen_calculate_cost()
280   cprintraw('')
281
282 def gen_C():
283   gen_misc()
284   gen_point_references()
285   gen_prepare()
286   gen_calculate_FG()
287   gen_calculate_cost()
288
289 def get_python():
290   # https://github.com/sympy/sympy/issues/13642
291   # "lambdify sinc gives wrong answer!"
292   out = q_sqparm
293   sinc_fixed = Function('sinc_fixed')
294   implemented_function(sinc_fixed, lambda x: np.sinc(x/np.pi))
295   out = out.subs(sinc,sinc_fixed)
296   p = list(map(eval,params))
297   return lambdify(p, out)